diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index 0ec5647ed0..17673cca5b 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -88,6 +88,15 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { if (handshaker == null) { WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); } else { + // Ensure we set the handshaker and replace this handler before we + // trigger the actual handshake. Otherwise we may receive websocket bytes in this handler + // before we had a chance to replace it. + // + // See https://github.com/netty/netty/issues/9471. + WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); + ctx.pipeline().replace(this, "WS403Responder", + WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()); + final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); handshakeFuture.addListener((ChannelFutureListener) future -> { if (!future.isSuccess()) { @@ -104,9 +113,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelInboundHandler { } }); applyHandshakeTimeout(); - WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); - ctx.pipeline().replace(this, "WS403Responder", - WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()); } } finally { req.release(); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java index 0dd44ee7f7..6238d0eb45 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java @@ -62,6 +62,28 @@ public class WebSocketServerProtocolHandlerTest { assertFalse(ch.finish()); } + @Test + public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() throws Exception { + EmbeddedChannel ch = createChannel(new MockOutboundHandler()); + ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + // We should have removed the handler already. + assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class)); + } + } + }); + writeUpgradeRequest(ch); + + FullHttpResponse response = responses.remove(); + assertEquals(SWITCHING_PROTOCOLS, response.status()); + response.release(); + assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel())); + assertFalse(ch.finish()); + } + @Test public void testSubsequentHttpRequestsAfterUpgradeShouldReturn403() throws Exception { EmbeddedChannel ch = createChannel();