diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java index ecd9323f71..59e3ed597a 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -28,6 +28,7 @@ import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.util.internal.StringUtil; import java.net.URI; @@ -199,7 +200,35 @@ public abstract class WebSocketClientHandshaker { */ public final void finishHandshake(Channel channel, FullHttpResponse response) { verify(response); - setActualSubprotocol(response.headers().get(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL)); + + // Verify the subprotocol that we received from the server. + // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol + String receivedProtocol = response.headers().get(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL); + receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null; + String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : ""; + boolean protocolValid = false; + + if (expectedProtocol.isEmpty() && receivedProtocol == null) { + // No subprotocol required and none received + protocolValid = true; + setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested + } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) { + // We require a subprotocol and received one -> verify it + for (String protocol : StringUtil.split(expectedSubprotocol, ',')) { + if (protocol.trim().equals(receivedProtocol)) { + protocolValid = true; + setActualSubprotocol(receivedProtocol); + break; + } + } + } // else mixed cases - which are all errors + + if (!protocolValid) { + throw new WebSocketHandshakeException(String.format( + "Invalid subprotocol. Actual: %s. Expected one of: %s", + receivedProtocol, expectedSubprotocol)); + } + setHandshakeComplete(); ChannelPipeline p = channel.pipeline(); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java index 40dd991b3d..c818f7c95e 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java @@ -42,6 +42,11 @@ public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler { private final WebSocketClientHandshaker handshaker; private final boolean handleCloseFrames; + /** + * Returns the used handshaker + */ + public WebSocketClientHandshaker handshaker() { return handshaker; } + /** * Events that are fired to notify about handshake status */