From 4340f6c116df6f54bd733995881c77f8d02152f6 Mon Sep 17 00:00:00 2001 From: Andrey Mizurov Date: Sun, 11 Aug 2019 09:22:17 +0300 Subject: [PATCH] Set the ORIGIN header from a custom headers if present (#9435) Motivation: Allow to set the ORIGIN header value from custom headers in WebSocketClientHandshaker Modification: Only override header if not present already Result: More flexible handshaker usage --- .../WebSocketClientHandshaker00.java | 5 ++++- .../WebSocketClientHandshaker07.java | 7 +++++-- .../WebSocketClientHandshaker08.java | 7 +++++-- .../WebSocketClientHandshaker13.java | 7 +++++-- .../WebSocketClientHandshaker00Test.java | 3 +-- .../WebSocketClientHandshaker07Test.java | 3 +-- .../WebSocketClientHandshakerTest.java | 19 ++++++++++++++++--- 7 files changed, 37 insertions(+), 14 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java index b8aefde4da..8c609ed24e 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java @@ -190,10 +190,13 @@ public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker { headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2); + if (!headers.contains(HttpHeaderNames.ORIGIN)) { + headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)); + } + String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java index c221e29562..a2fce7c911 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java @@ -223,8 +223,11 @@ public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker { headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) - .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + + if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + } String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java index 8e720f4e83..245eff8753 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java @@ -225,8 +225,11 @@ public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker { headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) - .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + + if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + } String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java index 99004d2e18..8efbb4a2e5 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java @@ -226,8 +226,11 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker { headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key) - .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)) - .set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)); + .set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + + if (!headers.contains(HttpHeaderNames.ORIGIN)) { + headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)); + } String expectedSubprotocol = expectedSubprotocol(); if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java index 9b0432a199..efbe22e2bf 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java @@ -39,12 +39,11 @@ public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTe } @Override - protected CharSequence[] getHandshakeHeaderNames() { + protected CharSequence[] getHandshakeRequiredHeaderNames() { return new CharSequence[] { HttpHeaderNames.CONNECTION, HttpHeaderNames.UPGRADE, HttpHeaderNames.HOST, - HttpHeaderNames.ORIGIN, HttpHeaderNames.SEC_WEBSOCKET_KEY1, HttpHeaderNames.SEC_WEBSOCKET_KEY2, }; diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java index acc10d7c24..4d9be6cb67 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java @@ -40,13 +40,12 @@ public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTe } @Override - protected CharSequence[] getHandshakeHeaderNames() { + protected CharSequence[] getHandshakeRequiredHeaderNames() { return new CharSequence[] { HttpHeaderNames.UPGRADE, HttpHeaderNames.CONNECTION, HttpHeaderNames.SEC_WEBSOCKET_KEY, HttpHeaderNames.HOST, - getOriginHeaderName(), HttpHeaderNames.SEC_WEBSOCKET_VERSION, }; } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java index ac6b6544d1..d92fce7378 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -51,7 +51,7 @@ public abstract class WebSocketClientHandshakerTest { protected abstract CharSequence getProtocolHeaderName(); - protected abstract CharSequence[] getHandshakeHeaderNames(); + protected abstract CharSequence[] getHandshakeRequiredHeaderNames(); @Test public void hostHeaderWs() { @@ -161,6 +161,19 @@ public abstract class WebSocketClientHandshakerTest { testOriginHeader("//LOCALHOST/", "http://localhost"); } + @Test + public void testSetOriginFromCustomHeaders() { + HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com"); + WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null, + customHeaders, false); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("http://example.com", request.headers().get(getOriginHeaderName())); + } finally { + request.release(); + } + } + private void testHostHeader(String uri, String expected) { testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected); } @@ -326,7 +339,7 @@ public abstract class WebSocketClientHandshakerTest { String bogusHeaderValue = "bogusHeaderValue"; // add values for the headers that are reserved for use in the websockets handshake - for (CharSequence header : getHandshakeHeaderNames()) { + for (CharSequence header : getHandshakeRequiredHeaderNames()) { inputHeaders.add(header, bogusHeaderValue); } inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol); @@ -337,7 +350,7 @@ public abstract class WebSocketClientHandshakerTest { HttpHeaders outputHeaders = request.headers(); // the header values passed in originally have been replaced with values generated by the Handshaker - for (CharSequence header : getHandshakeHeaderNames()) { + for (CharSequence header : getHandshakeRequiredHeaderNames()) { assertEquals(1, outputHeaders.getAll(header).size()); assertNotEquals(bogusHeaderValue, outputHeaders.get(header)); }