Set the ORIGIN header from a custom headers if present (#9435)

Motivation:

Allow to set the ORIGIN header value from custom headers in WebSocketClientHandshaker

Modification:

Only override header if not present already

Result:

More flexible handshaker usage
This commit is contained in:
Andrey Mizurov 2019-08-11 09:22:17 +03:00 committed by Norman Maurer
parent fedcc40196
commit b8ac02d8ac
7 changed files with 37 additions and 14 deletions

View File

@ -189,10 +189,13 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET) headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) .set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2); .set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);
if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol(); String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);

View File

@ -223,8 +223,11 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) .set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol(); String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -225,8 +225,11 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) .set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol(); String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -226,8 +226,11 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) .set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}
String expectedSubprotocol = expectedSubprotocol(); String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {

View File

@ -39,12 +39,11 @@ public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTe
} }
@Override @Override
protected CharSequence[] getHandshakeHeaderNames() { protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] { return new CharSequence[] {
HttpHeaderNames.CONNECTION, HttpHeaderNames.CONNECTION,
HttpHeaderNames.UPGRADE, HttpHeaderNames.UPGRADE,
HttpHeaderNames.HOST, HttpHeaderNames.HOST,
HttpHeaderNames.ORIGIN,
HttpHeaderNames.SEC_WEBSOCKET_KEY1, HttpHeaderNames.SEC_WEBSOCKET_KEY1,
HttpHeaderNames.SEC_WEBSOCKET_KEY2, HttpHeaderNames.SEC_WEBSOCKET_KEY2,
}; };

View File

@ -40,13 +40,12 @@ public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTe
} }
@Override @Override
protected CharSequence[] getHandshakeHeaderNames() { protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] { return new CharSequence[] {
HttpHeaderNames.UPGRADE, HttpHeaderNames.UPGRADE,
HttpHeaderNames.CONNECTION, HttpHeaderNames.CONNECTION,
HttpHeaderNames.SEC_WEBSOCKET_KEY, HttpHeaderNames.SEC_WEBSOCKET_KEY,
HttpHeaderNames.HOST, HttpHeaderNames.HOST,
getOriginHeaderName(),
HttpHeaderNames.SEC_WEBSOCKET_VERSION, HttpHeaderNames.SEC_WEBSOCKET_VERSION,
}; };
} }

View File

@ -50,7 +50,7 @@ public abstract class WebSocketClientHandshakerTest {
protected abstract CharSequence getProtocolHeaderName(); protected abstract CharSequence getProtocolHeaderName();
protected abstract CharSequence[] getHandshakeHeaderNames(); protected abstract CharSequence[] getHandshakeRequiredHeaderNames();
@Test @Test
public void hostHeaderWs() { public void hostHeaderWs() {
@ -160,6 +160,19 @@ public abstract class WebSocketClientHandshakerTest {
testOriginHeader("//LOCALHOST/", "http://localhost"); testOriginHeader("//LOCALHOST/", "http://localhost");
} }
@Test
public void testSetOriginFromCustomHeaders() {
HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com");
WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null,
customHeaders, false);
FullHttpRequest request = handshaker.newHandshakeRequest();
try {
assertEquals("http://example.com", request.headers().get(getOriginHeaderName()));
} finally {
request.release();
}
}
private void testHostHeader(String uri, String expected) { private void testHostHeader(String uri, String expected) {
testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected); testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
} }
@ -325,7 +338,7 @@ public abstract class WebSocketClientHandshakerTest {
String bogusHeaderValue = "bogusHeaderValue"; String bogusHeaderValue = "bogusHeaderValue";
// add values for the headers that are reserved for use in the websockets handshake // add values for the headers that are reserved for use in the websockets handshake
for (CharSequence header : getHandshakeHeaderNames()) { for (CharSequence header : getHandshakeRequiredHeaderNames()) {
inputHeaders.add(header, bogusHeaderValue); inputHeaders.add(header, bogusHeaderValue);
} }
inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol); inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol);
@ -336,7 +349,7 @@ public abstract class WebSocketClientHandshakerTest {
HttpHeaders outputHeaders = request.headers(); HttpHeaders outputHeaders = request.headers();
// the header values passed in originally have been replaced with values generated by the Handshaker // the header values passed in originally have been replaced with values generated by the Handshaker
for (CharSequence header : getHandshakeHeaderNames()) { for (CharSequence header : getHandshakeRequiredHeaderNames()) {
assertEquals(1, outputHeaders.getAll(header).size()); assertEquals(1, outputHeaders.getAll(header).size());
assertNotEquals(bogusHeaderValue, outputHeaders.get(header)); assertNotEquals(bogusHeaderValue, outputHeaders.get(header));
} }