diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java index 42ef07d704..a224313652 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -33,6 +33,7 @@ import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.StringUtil; import java.net.URI; @@ -242,7 +243,7 @@ public abstract class WebSocketClientHandshaker { setHandshakeComplete(); - ChannelPipeline p = channel.pipeline(); + final ChannelPipeline p = channel.pipeline(); // Remove decompressor from pipeline if its in use HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class); if (decompressor != null) { @@ -262,13 +263,38 @@ public abstract class WebSocketClientHandshaker { throw new IllegalStateException("ChannelPipeline does not contain " + "a HttpRequestEncoder or HttpClientCodec"); } - p.replace(ctx.name(), "ws-decoder", newWebsocketDecoder()); + final HttpClientCodec codec = (HttpClientCodec) ctx.handler(); + // Remove the encoder part of the codec as the user may start writing frames after this method returns. + codec.removeOutboundHandler(); + + p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder()); + + // Delay the removal of the decoder so the user can setup the pipeline if needed to handle + // WebSocketFrame messages. + // See https://github.com/netty/netty/issues/4533 + channel.eventLoop().execute(new OneTimeTask() { + @Override + public void run() { + p.remove(codec); + } + }); } else { if (p.get(HttpRequestEncoder.class) != null) { + // Remove the encoder part of the codec as the user may start writing frames after this method returns. p.remove(HttpRequestEncoder.class); } - p.replace(ctx.name(), - "ws-decoder", newWebsocketDecoder()); + final ChannelHandlerContext context = ctx; + p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder()); + + // Delay the removal of the decoder so the user can setup the pipeline if needed to handle + // WebSocketFrame messages. + // See https://github.com/netty/netty/issues/4533 + channel.eventLoop().execute(new OneTimeTask() { + @Override + public void run() { + p.remove(context.handler()); + } + }); } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java index e45c92349e..2b81ad1f52 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -15,12 +15,27 @@ */ package io.netty.handler.codec.http.websocketx; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ThreadLocalRandom; import org.junit.Test; import java.net.URI; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public abstract class WebSocketClientHandshakerTest { protected abstract WebSocketClientHandshaker newHandshaker(URI uri); @@ -36,4 +51,105 @@ public abstract class WebSocketClientHandshakerTest { request.release(); } } + + @Test(timeout = 3000) + public void testHttpResponseAndFrameInSameBuffer() { + testHttpResponseAndFrameInSameBuffer(false); + } + + @Test(timeout = 3000) + public void testHttpResponseAndFrameInSameBufferCodec() { + testHttpResponseAndFrameInSameBuffer(true); + } + + private void testHttpResponseAndFrameInSameBuffer(boolean codec) { + String url = "ws://localhost:9999/ws"; + final WebSocketClientHandshaker shaker = newHandshaker(URI.create(url)); + final WebSocketClientHandshaker handshaker = new WebSocketClientHandshaker( + shaker.uri(), shaker.version(), null, HttpHeaders.EMPTY_HEADERS, Integer.MAX_VALUE) { + @Override + protected FullHttpRequest newHandshakeRequest() { + return shaker.newHandshakeRequest(); + } + + @Override + protected void verify(FullHttpResponse response) { + // Not do any verification, so we not need to care sending the correct headers etc in the test, + // which would just make things more complicated. + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return shaker.newWebsocketDecoder(); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return shaker.newWebSocketEncoder(); + } + }; + + byte[] data = new byte[24]; + ThreadLocalRandom.current().nextBytes(data); + + // Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these + // to test the actual handshaker. + WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(url, null, false); + WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(shaker.newHandshakeRequest()); + EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(), + socketServerHandshaker.newWebsocketDecoder()); + assertTrue(websocketChannel.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data)))); + + byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n".getBytes(CharsetUtil.US_ASCII); + + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(Unpooled.wrappedBuffer(bytes)); + compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + bytes.length); + for (;;) { + ByteBuf frameBytes = (ByteBuf) websocketChannel.readOutbound(); + if (frameBytes == null) { + break; + } + compositeByteBuf.addComponent(frameBytes); + compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + frameBytes.readableBytes()); + } + + EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE), + new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception { + handshaker.finishHandshake(ctx.channel(), msg); + ctx.pipeline().remove(this); + } + }); + if (codec) { + ch.pipeline().addFirst(new HttpClientCodec()); + } else { + ch.pipeline().addFirst(new HttpRequestEncoder(), new HttpResponseDecoder()); + } + // We need to first write the request as HttpClientCodec will fail if we receive a response before a request + // was written. + shaker.handshake(ch).syncUninterruptibly(); + for (;;) { + // Just consume the bytes, we are not interested in these. + ByteBuf buf = (ByteBuf) ch.readOutbound(); + if (buf == null) { + break; + } + buf.release(); + } + assertTrue(ch.writeInbound(compositeByteBuf)); + assertTrue(ch.finish()); + + BinaryWebSocketFrame frame = (BinaryWebSocketFrame) ch.readInbound(); + ByteBuf expect = Unpooled.wrappedBuffer(data); + try { + assertEquals(expect, frame.content()); + assertTrue(frame.isFinalFragment()); + assertEquals(0, frame.rsv()); + } finally { + expect.release(); + frame.release(); + } + } }