Set the ORIGIN header from a custom headers if present (#9435)

Motivation:

Allow to set the ORIGIN header value from custom headers in WebSocketClientHandshaker

Modification:

Only override header if not present already

Result:

More flexible handshaker usage
This commit is contained in:
Andrey Mizurov 2019-08-11 09:22:17 +03:00 committed by Norman Maurer
parent dcd145864a
commit 4340f6c116
7 changed files with 37 additions and 14 deletions

View File

@ -190,10 +190,13 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);
if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);

View File

@ -223,8 +223,11 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -225,8 +225,11 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -226,8 +226,11 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -39,12 +39,11 @@ public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTe
}
@Override
protected CharSequence[] getHandshakeHeaderNames() {
protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] {
HttpHeaderNames.CONNECTION,
HttpHeaderNames.UPGRADE,
HttpHeaderNames.HOST,
HttpHeaderNames.ORIGIN,
HttpHeaderNames.SEC_WEBSOCKET_KEY1,
HttpHeaderNames.SEC_WEBSOCKET_KEY2,
};

View File

@ -40,13 +40,12 @@ public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTe
}
@Override
protected CharSequence[] getHandshakeHeaderNames() {
protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] {
HttpHeaderNames.UPGRADE,
HttpHeaderNames.CONNECTION,
HttpHeaderNames.SEC_WEBSOCKET_KEY,
HttpHeaderNames.HOST,
getOriginHeaderName(),
HttpHeaderNames.SEC_WEBSOCKET_VERSION,
};
}

View File

@ -51,7 +51,7 @@ public abstract class WebSocketClientHandshakerTest {
protected abstract CharSequence getProtocolHeaderName();
protected abstract CharSequence[] getHandshakeHeaderNames();
protected abstract CharSequence[] getHandshakeRequiredHeaderNames();
@Test
public void hostHeaderWs() {
@ -161,6 +161,19 @@ public abstract class WebSocketClientHandshakerTest {
testOriginHeader("//LOCALHOST/", "http://localhost");
}
@Test
public void testSetOriginFromCustomHeaders() {
HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com");
WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null,
customHeaders, false);
FullHttpRequest request = handshaker.newHandshakeRequest();
try {
assertEquals("http://example.com", request.headers().get(getOriginHeaderName()));
} finally {
request.release();
}
}
private void testHostHeader(String uri, String expected) {
testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
}
@ -326,7 +339,7 @@ public abstract class WebSocketClientHandshakerTest {
String bogusHeaderValue = "bogusHeaderValue";
// add values for the headers that are reserved for use in the websockets handshake
for (CharSequence header : getHandshakeHeaderNames()) {
for (CharSequence header : getHandshakeRequiredHeaderNames()) {
inputHeaders.add(header, bogusHeaderValue);
}
inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol);
@ -337,7 +350,7 @@ public abstract class WebSocketClientHandshakerTest {
HttpHeaders outputHeaders = request.headers();
// the header values passed in originally have been replaced with values generated by the Handshaker
for (CharSequence header : getHandshakeHeaderNames()) {
for (CharSequence header : getHandshakeRequiredHeaderNames()) {
assertEquals(1, outputHeaders.getAll(header).size());
assertNotEquals(bogusHeaderValue, outputHeaders.get(header));
}