Set the ORIGIN header from a custom headers if present (#9435)
Motivation: Allow to set the ORIGIN header value from custom headers in WebSocketClientHandshaker Modification: Only override header if not present already Result: More flexible handshaker usage
This commit is contained in:
parent
fedcc40196
commit
b8ac02d8ac
@ -189,10 +189,13 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker {
|
|||||||
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET)
|
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET)
|
||||||
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
||||||
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
|
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
|
||||||
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
|
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
|
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);
|
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);
|
||||||
|
|
||||||
|
if (!headers.contains(HttpHeaderNames.ORIGIN)) {
|
||||||
|
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
|
||||||
|
}
|
||||||
|
|
||||||
String expectedSubprotocol = expectedSubprotocol();
|
String expectedSubprotocol = expectedSubprotocol();
|
||||||
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
||||||
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
|
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
|
||||||
|
@ -223,8 +223,11 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker {
|
|||||||
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
|
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
|
||||||
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
|
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
|
||||||
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
|
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
|
|
||||||
|
if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
|
||||||
|
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
|
||||||
|
}
|
||||||
|
|
||||||
String expectedSubprotocol = expectedSubprotocol();
|
String expectedSubprotocol = expectedSubprotocol();
|
||||||
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
||||||
|
@ -225,8 +225,11 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker {
|
|||||||
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
|
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
|
||||||
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
|
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
|
||||||
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
|
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
|
|
||||||
|
if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
|
||||||
|
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
|
||||||
|
}
|
||||||
|
|
||||||
String expectedSubprotocol = expectedSubprotocol();
|
String expectedSubprotocol = expectedSubprotocol();
|
||||||
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
||||||
|
@ -226,8 +226,11 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
|
|||||||
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
|
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
|
||||||
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
|
||||||
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
|
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
|
||||||
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
|
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
|
||||||
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
|
|
||||||
|
if (!headers.contains(HttpHeaderNames.ORIGIN)) {
|
||||||
|
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
|
||||||
|
}
|
||||||
|
|
||||||
String expectedSubprotocol = expectedSubprotocol();
|
String expectedSubprotocol = expectedSubprotocol();
|
||||||
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
|
||||||
|
@ -39,12 +39,11 @@ public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTe
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected CharSequence[] getHandshakeHeaderNames() {
|
protected CharSequence[] getHandshakeRequiredHeaderNames() {
|
||||||
return new CharSequence[] {
|
return new CharSequence[] {
|
||||||
HttpHeaderNames.CONNECTION,
|
HttpHeaderNames.CONNECTION,
|
||||||
HttpHeaderNames.UPGRADE,
|
HttpHeaderNames.UPGRADE,
|
||||||
HttpHeaderNames.HOST,
|
HttpHeaderNames.HOST,
|
||||||
HttpHeaderNames.ORIGIN,
|
|
||||||
HttpHeaderNames.SEC_WEBSOCKET_KEY1,
|
HttpHeaderNames.SEC_WEBSOCKET_KEY1,
|
||||||
HttpHeaderNames.SEC_WEBSOCKET_KEY2,
|
HttpHeaderNames.SEC_WEBSOCKET_KEY2,
|
||||||
};
|
};
|
||||||
|
@ -40,13 +40,12 @@ public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTe
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected CharSequence[] getHandshakeHeaderNames() {
|
protected CharSequence[] getHandshakeRequiredHeaderNames() {
|
||||||
return new CharSequence[] {
|
return new CharSequence[] {
|
||||||
HttpHeaderNames.UPGRADE,
|
HttpHeaderNames.UPGRADE,
|
||||||
HttpHeaderNames.CONNECTION,
|
HttpHeaderNames.CONNECTION,
|
||||||
HttpHeaderNames.SEC_WEBSOCKET_KEY,
|
HttpHeaderNames.SEC_WEBSOCKET_KEY,
|
||||||
HttpHeaderNames.HOST,
|
HttpHeaderNames.HOST,
|
||||||
getOriginHeaderName(),
|
|
||||||
HttpHeaderNames.SEC_WEBSOCKET_VERSION,
|
HttpHeaderNames.SEC_WEBSOCKET_VERSION,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ public abstract class WebSocketClientHandshakerTest {
|
|||||||
|
|
||||||
protected abstract CharSequence getProtocolHeaderName();
|
protected abstract CharSequence getProtocolHeaderName();
|
||||||
|
|
||||||
protected abstract CharSequence[] getHandshakeHeaderNames();
|
protected abstract CharSequence[] getHandshakeRequiredHeaderNames();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void hostHeaderWs() {
|
public void hostHeaderWs() {
|
||||||
@ -160,6 +160,19 @@ public abstract class WebSocketClientHandshakerTest {
|
|||||||
testOriginHeader("//LOCALHOST/", "http://localhost");
|
testOriginHeader("//LOCALHOST/", "http://localhost");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSetOriginFromCustomHeaders() {
|
||||||
|
HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com");
|
||||||
|
WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null,
|
||||||
|
customHeaders, false);
|
||||||
|
FullHttpRequest request = handshaker.newHandshakeRequest();
|
||||||
|
try {
|
||||||
|
assertEquals("http://example.com", request.headers().get(getOriginHeaderName()));
|
||||||
|
} finally {
|
||||||
|
request.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void testHostHeader(String uri, String expected) {
|
private void testHostHeader(String uri, String expected) {
|
||||||
testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
|
testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
|
||||||
}
|
}
|
||||||
@ -325,7 +338,7 @@ public abstract class WebSocketClientHandshakerTest {
|
|||||||
String bogusHeaderValue = "bogusHeaderValue";
|
String bogusHeaderValue = "bogusHeaderValue";
|
||||||
|
|
||||||
// add values for the headers that are reserved for use in the websockets handshake
|
// add values for the headers that are reserved for use in the websockets handshake
|
||||||
for (CharSequence header : getHandshakeHeaderNames()) {
|
for (CharSequence header : getHandshakeRequiredHeaderNames()) {
|
||||||
inputHeaders.add(header, bogusHeaderValue);
|
inputHeaders.add(header, bogusHeaderValue);
|
||||||
}
|
}
|
||||||
inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol);
|
inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol);
|
||||||
@ -336,7 +349,7 @@ public abstract class WebSocketClientHandshakerTest {
|
|||||||
HttpHeaders outputHeaders = request.headers();
|
HttpHeaders outputHeaders = request.headers();
|
||||||
|
|
||||||
// the header values passed in originally have been replaced with values generated by the Handshaker
|
// the header values passed in originally have been replaced with values generated by the Handshaker
|
||||||
for (CharSequence header : getHandshakeHeaderNames()) {
|
for (CharSequence header : getHandshakeRequiredHeaderNames()) {
|
||||||
assertEquals(1, outputHeaders.getAll(header).size());
|
assertEquals(1, outputHeaders.getAll(header).size());
|
||||||
assertNotEquals(bogusHeaderValue, outputHeaders.get(header));
|
assertNotEquals(bogusHeaderValue, outputHeaders.get(header));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user