[#4533] Ensure replacement of decoder is delayed after finishHandshake() is called

Motivation:

If the user calls handshake.finishHandshake() we need to ensure that the user has the chance to setup the pipeline before any WebSocketFrames are read. Because of this we need
to delay the removal of the HttpRequestDecoder.

Modifications:

- Remove the HttpRequestDecoder via the EventLoop and so delay it which gives the user a chance to setup the pipeline after finishHandshake() completes
- Add unit test for this.

Result:

Less surpising and correct behaviour even if the http response and websocket frame are received in one read operation.
This commit is contained in:
Norman Maurer 2016-02-02 10:39:41 +01:00
parent 7e057de98b
commit d0b0e028d0
2 changed files with 146 additions and 4 deletions

View File

@ -33,6 +33,7 @@ import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import java.net.URI; import java.net.URI;
@ -242,7 +243,7 @@ public abstract class WebSocketClientHandshaker {
setHandshakeComplete(); setHandshakeComplete();
ChannelPipeline p = channel.pipeline(); final ChannelPipeline p = channel.pipeline();
// Remove decompressor from pipeline if its in use // Remove decompressor from pipeline if its in use
HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class); HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
if (decompressor != null) { if (decompressor != null) {
@ -262,13 +263,38 @@ public abstract class WebSocketClientHandshaker {
throw new IllegalStateException("ChannelPipeline does not contain " + throw new IllegalStateException("ChannelPipeline does not contain " +
"a HttpRequestEncoder or HttpClientCodec"); "a HttpRequestEncoder or HttpClientCodec");
} }
p.replace(ctx.name(), "ws-decoder", newWebsocketDecoder()); final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
// Remove the encoder part of the codec as the user may start writing frames after this method returns.
codec.removeOutboundHandler();
p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
// Delay the removal of the decoder so the user can setup the pipeline if needed to handle
// WebSocketFrame messages.
// See https://github.com/netty/netty/issues/4533
channel.eventLoop().execute(new OneTimeTask() {
@Override
public void run() {
p.remove(codec);
}
});
} else { } else {
if (p.get(HttpRequestEncoder.class) != null) { if (p.get(HttpRequestEncoder.class) != null) {
// Remove the encoder part of the codec as the user may start writing frames after this method returns.
p.remove(HttpRequestEncoder.class); p.remove(HttpRequestEncoder.class);
} }
p.replace(ctx.name(), final ChannelHandlerContext context = ctx;
"ws-decoder", newWebsocketDecoder()); p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
// Delay the removal of the decoder so the user can setup the pipeline if needed to handle
// WebSocketFrame messages.
// See https://github.com/netty/netty/issues/4533
channel.eventLoop().execute(new OneTimeTask() {
@Override
public void run() {
p.remove(context.handler());
}
});
} }
} }

View File

@ -15,12 +15,27 @@
*/ */
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test; import org.junit.Test;
import java.net.URI; import java.net.URI;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public abstract class WebSocketClientHandshakerTest { public abstract class WebSocketClientHandshakerTest {
protected abstract WebSocketClientHandshaker newHandshaker(URI uri); protected abstract WebSocketClientHandshaker newHandshaker(URI uri);
@ -36,4 +51,105 @@ public abstract class WebSocketClientHandshakerTest {
request.release(); request.release();
} }
} }
@Test(timeout = 3000)
public void testHttpResponseAndFrameInSameBuffer() {
testHttpResponseAndFrameInSameBuffer(false);
}
@Test(timeout = 3000)
public void testHttpResponseAndFrameInSameBufferCodec() {
testHttpResponseAndFrameInSameBuffer(true);
}
private void testHttpResponseAndFrameInSameBuffer(boolean codec) {
String url = "ws://localhost:9999/ws";
final WebSocketClientHandshaker shaker = newHandshaker(URI.create(url));
final WebSocketClientHandshaker handshaker = new WebSocketClientHandshaker(
shaker.uri(), shaker.version(), null, HttpHeaders.EMPTY_HEADERS, Integer.MAX_VALUE) {
@Override
protected FullHttpRequest newHandshakeRequest() {
return shaker.newHandshakeRequest();
}
@Override
protected void verify(FullHttpResponse response) {
// Not do any verification, so we not need to care sending the correct headers etc in the test,
// which would just make things more complicated.
}
@Override
protected WebSocketFrameDecoder newWebsocketDecoder() {
return shaker.newWebsocketDecoder();
}
@Override
protected WebSocketFrameEncoder newWebSocketEncoder() {
return shaker.newWebSocketEncoder();
}
};
byte[] data = new byte[24];
ThreadLocalRandom.current().nextBytes(data);
// Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these
// to test the actual handshaker.
WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(url, null, false);
WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(shaker.newHandshakeRequest());
EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(),
socketServerHandshaker.newWebsocketDecoder());
assertTrue(websocketChannel.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data))));
byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n".getBytes(CharsetUtil.US_ASCII);
CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
compositeByteBuf.addComponent(Unpooled.wrappedBuffer(bytes));
compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + bytes.length);
for (;;) {
ByteBuf frameBytes = (ByteBuf) websocketChannel.readOutbound();
if (frameBytes == null) {
break;
}
compositeByteBuf.addComponent(frameBytes);
compositeByteBuf.writerIndex(compositeByteBuf.writerIndex() + frameBytes.readableBytes());
}
EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE),
new SimpleChannelInboundHandler<FullHttpResponse>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
handshaker.finishHandshake(ctx.channel(), msg);
ctx.pipeline().remove(this);
}
});
if (codec) {
ch.pipeline().addFirst(new HttpClientCodec());
} else {
ch.pipeline().addFirst(new HttpRequestEncoder(), new HttpResponseDecoder());
}
// We need to first write the request as HttpClientCodec will fail if we receive a response before a request
// was written.
shaker.handshake(ch).syncUninterruptibly();
for (;;) {
// Just consume the bytes, we are not interested in these.
ByteBuf buf = (ByteBuf) ch.readOutbound();
if (buf == null) {
break;
}
buf.release();
}
assertTrue(ch.writeInbound(compositeByteBuf));
assertTrue(ch.finish());
BinaryWebSocketFrame frame = (BinaryWebSocketFrame) ch.readInbound();
ByteBuf expect = Unpooled.wrappedBuffer(data);
try {
assertEquals(expect, frame.content());
assertTrue(frame.isFinalFragment());
assertEquals(0, frame.rsv());
} finally {
expect.release();
frame.release();
}
}
} }