From 7a562943ad5dfc430c75de8cdeb2a2aa0a68244e Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Tue, 2 Feb 2016 10:39:41 +0100 Subject: [PATCH] [#4533] Ensure replacement of decoder is delayed after finishHandshake() is called Motivation: If the user calls handshake.finishHandshake() we need to ensure that the user has the chance to setup the pipeline before any WebSocketFrames are read. Because of this we need to delay the removal of the HttpRequestDecoder. Modifications: - Remove the HttpRequestDecoder via the EventLoop and so delay it which gives the user a chance to setup the pipeline after finishHandshake() completes - Add unit test for this. Result: Less surpising and correct behaviour even if the http response and websocket frame are received in one read operation. --- .../websocketx/WebSocketClientHandshaker.java | 34 ++++- .../WebSocketClientHandshakerTest.java | 116 ++++++++++++++++++ 2 files changed, 146 insertions(+), 4 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java index 9931fa619a..4f9daedde5 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -34,6 +34,7 @@ import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.StringUtil; import java.net.URI; @@ -243,7 +244,7 @@ public abstract class WebSocketClientHandshaker { setHandshakeComplete(); - ChannelPipeline p = channel.pipeline(); + final ChannelPipeline p = channel.pipeline(); // Remove decompressor from pipeline if its in use HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class); if (decompressor != null) { @@ -263,13 +264,38 @@ public abstract class WebSocketClientHandshaker { throw new IllegalStateException("ChannelPipeline does not contain " + "a HttpRequestEncoder or HttpClientCodec"); } - p.replace(ctx.name(), "ws-decoder", newWebsocketDecoder()); + final HttpClientCodec codec = (HttpClientCodec) ctx.handler(); + // Remove the encoder part of the codec as the user may start writing frames after this method returns. + codec.removeOutboundHandler(); + + p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder()); + + // Delay the removal of the decoder so the user can setup the pipeline if needed to handle + // WebSocketFrame messages. + // See https://github.com/netty/netty/issues/4533 + channel.eventLoop().execute(new OneTimeTask() { + @Override + public void run() { + p.remove(codec); + } + }); } else { if (p.get(HttpRequestEncoder.class) != null) { + // Remove the encoder part of the codec as the user may start writing frames after this method returns. p.remove(HttpRequestEncoder.class); } - p.replace(ctx.name(), - "ws-decoder", newWebsocketDecoder()); + final ChannelHandlerContext context = ctx; + p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder()); + + // Delay the removal of the decoder so the user can setup the pipeline if needed to handle + // WebSocketFrame messages. + // See https://github.com/netty/netty/issues/4533 + channel.eventLoop().execute(new OneTimeTask() { + @Override + public void run() { + p.remove(context.handler()); + } + }); } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java index e45c92349e..427888e8a2 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -15,12 +15,27 @@ */ package io.netty.handler.codec.http.websocketx; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ThreadLocalRandom; import org.junit.Test; import java.net.URI; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public abstract class WebSocketClientHandshakerTest { protected abstract WebSocketClientHandshaker newHandshaker(URI uri); @@ -36,4 +51,105 @@ public abstract class WebSocketClientHandshakerTest { request.release(); } } + + @Test(timeout = 3000) + public void testHttpResponseAndFrameInSameBuffer() { + testHttpResponseAndFrameInSameBuffer(false); + } + + @Test(timeout = 3000) + public void testHttpResponseAndFrameInSameBufferCodec() { + testHttpResponseAndFrameInSameBuffer(true); + } + + private void testHttpResponseAndFrameInSameBuffer(boolean codec) { + String url = "ws://localhost:9999/ws"; + final WebSocketClientHandshaker shaker = newHandshaker(URI.create(url)); + final WebSocketClientHandshaker handshaker = new WebSocketClientHandshaker( + shaker.uri(), shaker.version(), null, EmptyHttpHeaders.INSTANCE, Integer.MAX_VALUE) { + @Override + protected FullHttpRequest newHandshakeRequest() { + return shaker.newHandshakeRequest(); + } + + @Override + protected void verify(FullHttpResponse response) { + // Not do any verification, so we not need to care sending the correct headers etc in the test, + // which would just make things more complicated. + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return shaker.newWebsocketDecoder(); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return shaker.newWebSocketEncoder(); + } + }; + + byte[] data = new byte[24]; + ThreadLocalRandom.current().nextBytes(data); + + // Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these + // to test the actual handshaker. + WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(url, null, false); + WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(shaker.newHandshakeRequest()); + EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(), + socketServerHandshaker.newWebsocketDecoder()); + assertTrue(websocketChannel.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data)))); + + byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n".getBytes(CharsetUtil.US_ASCII); + + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(Unpooled.wrappedBuffer(bytes)); + compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + bytes.length); + for (;;) { + ByteBuf frameBytes = websocketChannel.readOutbound(); + if (frameBytes == null) { + break; + } + compositeByteBuf.addComponent(frameBytes); + compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + frameBytes.readableBytes()); + } + + EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE), + new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception { + handshaker.finishHandshake(ctx.channel(), msg); + ctx.pipeline().remove(this); + } + }); + if (codec) { + ch.pipeline().addFirst(new HttpClientCodec()); + } else { + ch.pipeline().addFirst(new HttpRequestEncoder(), new HttpResponseDecoder()); + } + // We need to first write the request as HttpClientCodec will fail if we receive a response before a request + // was written. + shaker.handshake(ch).syncUninterruptibly(); + for (;;) { + // Just consume the bytes, we are not interested in these. + ByteBuf buf = ch.readOutbound(); + if (buf == null) { + break; + } + buf.release(); + } + assertTrue(ch.writeInbound(compositeByteBuf)); + assertTrue(ch.finish()); + + BinaryWebSocketFrame frame = ch.readInbound(); + ByteBuf expect = Unpooled.wrappedBuffer(data); + try { + assertEquals(expect, frame.content()); + assertTrue(frame.isFinalFragment()); + assertEquals(0, frame.rsv()); + } finally { + expect.release(); + frame.release(); + } + } }