Preserve order when using alternate event loops (#10069)

Motivation:

When the HttpContentCompressor is put on an alternate EventExecutor, the order of events should be
preserved so that the compressed content is correctly created.

Modifications:
- checking that the executor in the ChannelHandlerContext is the same executor as the current executor when evaluating if the handler should be skipped.
- Add unit test

Result:

Fixes https://github.com/netty/netty/issues/10067

Co-authored-by: Norman Maurer <norman_maurer@apple.com>
This commit is contained in:
Antony T Curtis 2020-03-03 01:42:41 -08:00 committed by GitHub
parent 0453222015
commit 7fce06a0d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 141 additions and 5 deletions

View File

@ -15,14 +15,32 @@
*/
package io.netty.handler.codec.http;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.compression.ZlibWrapper;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
@ -259,6 +277,104 @@ public class HttpContentCompressorTest {
assertThat(ch.readOutbound(), is(nullValue()));
}
@Test
public void testExecutorPreserveOrdering() throws Exception {
final EventLoopGroup compressorGroup = new DefaultEventLoopGroup(1);
EventLoopGroup localGroup = new DefaultEventLoopGroup(1);
Channel server = null;
Channel client = null;
try {
ServerBootstrap bootstrap = new ServerBootstrap()
.channel(LocalServerChannel.class)
.group(localGroup)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) throws Exception {
ch.pipeline()
.addLast(new HttpServerCodec())
.addLast(new HttpObjectAggregator(1024))
.addLast(compressorGroup, new HttpContentCompressor())
.addLast(new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
super.write(ctx, msg, promise);
}
})
.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
FullHttpResponse res =
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK,
Unpooled.copiedBuffer("Hello, World", CharsetUtil.US_ASCII));
ctx.writeAndFlush(res);
ReferenceCountUtil.release(msg);
return;
}
super.channelRead(ctx, msg);
}
});
}
});
LocalAddress address = new LocalAddress(UUID.randomUUID().toString());
server = bootstrap.bind(address).sync().channel();
final BlockingQueue<HttpObject> responses = new LinkedBlockingQueue<HttpObject>();
client = new Bootstrap()
.channel(LocalChannel.class)
.remoteAddress(address)
.group(localGroup)
.handler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) throws Exception {
ch.pipeline().addLast(new HttpClientCodec()).addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof HttpObject) {
responses.put((HttpObject) msg);
return;
}
super.channelRead(ctx, msg);
}
});
}
}).connect().sync().channel();
client.writeAndFlush(newRequest()).sync();
assertEncodedResponse((HttpResponse) responses.poll(1, TimeUnit.SECONDS));
HttpContent c = (HttpContent) responses.poll(1, TimeUnit.SECONDS);
assertNotNull(c);
assertThat(ByteBufUtil.hexDump(c.content()),
is("1f8b0800000000000000f248cdc9c9d75108cf2fca4901000000ffff"));
c.release();
c = (HttpContent) responses.poll(1, TimeUnit.SECONDS);
assertNotNull(c);
assertThat(ByteBufUtil.hexDump(c.content()), is("0300c6865b260c000000"));
c.release();
LastHttpContent last = (LastHttpContent) responses.poll(1, TimeUnit.SECONDS);
assertNotNull(last);
assertThat(last.content().readableBytes(), is(0));
last.release();
assertNull(responses.poll(1, TimeUnit.SECONDS));
} finally {
if (client != null) {
client.close().sync();
}
if (server != null) {
server.close().sync();
}
compressorGroup.shutdownGracefully();
localGroup.shutdownGracefully();
}
}
/**
* If the length of the content is unknown, {@link HttpContentEncoder} should not skip encoding the content
* even if the actual length is turned out to be 0.
@ -543,7 +659,10 @@ public class HttpContentCompressorTest {
Object o = ch.readOutbound();
assertThat(o, is(instanceOf(HttpResponse.class)));
HttpResponse res = (HttpResponse) o;
assertEncodedResponse((HttpResponse) o);
}
private static void assertEncodedResponse(HttpResponse res) {
assertThat(res, is(not(instanceOf(HttpContent.class))));
assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked"));
assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue()));

View File

@ -51,6 +51,8 @@ import static io.netty.channel.ChannelHandlerMask.MASK_DEREGISTER;
import static io.netty.channel.ChannelHandlerMask.MASK_DISCONNECT;
import static io.netty.channel.ChannelHandlerMask.MASK_EXCEPTION_CAUGHT;
import static io.netty.channel.ChannelHandlerMask.MASK_FLUSH;
import static io.netty.channel.ChannelHandlerMask.MASK_ONLY_INBOUND;
import static io.netty.channel.ChannelHandlerMask.MASK_ONLY_OUTBOUND;
import static io.netty.channel.ChannelHandlerMask.MASK_READ;
import static io.netty.channel.ChannelHandlerMask.MASK_USER_EVENT_TRIGGERED;
import static io.netty.channel.ChannelHandlerMask.MASK_WRITE;
@ -906,20 +908,33 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
private AbstractChannelHandlerContext findContextInbound(int mask) {
AbstractChannelHandlerContext ctx = this;
EventExecutor currentExecutor = executor();
do {
ctx = ctx.next;
} while ((ctx.executionMask & mask) == 0);
} while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_INBOUND));
return ctx;
}
private AbstractChannelHandlerContext findContextOutbound(int mask) {
AbstractChannelHandlerContext ctx = this;
EventExecutor currentExecutor = executor();
do {
ctx = ctx.prev;
} while ((ctx.executionMask & mask) == 0);
} while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_OUTBOUND));
return ctx;
}
private static boolean skipContext(
AbstractChannelHandlerContext ctx, EventExecutor currentExecutor, int mask, int onlyMask) {
// Ensure we correctly handle MASK_EXCEPTION_CAUGHT which is not included in the MASK_EXCEPTION_CAUGHT
return (ctx.executionMask & (onlyMask | mask)) == 0 ||
// We can only skip if the EventExecutor is the same as otherwise we need to ensure we offload
// everything to preserve ordering.
//
// See https://github.com/netty/netty/issues/10067
(ctx.executor() == currentExecutor && (ctx.executionMask & mask) == 0);
}
@Override
public ChannelPromise voidPromise() {
return channel().voidPromise();

View File

@ -54,11 +54,13 @@ final class ChannelHandlerMask {
static final int MASK_WRITE = 1 << 15;
static final int MASK_FLUSH = 1 << 16;
private static final int MASK_ALL_INBOUND = MASK_EXCEPTION_CAUGHT | MASK_CHANNEL_REGISTERED |
static final int MASK_ONLY_INBOUND = MASK_CHANNEL_REGISTERED |
MASK_CHANNEL_UNREGISTERED | MASK_CHANNEL_ACTIVE | MASK_CHANNEL_INACTIVE | MASK_CHANNEL_READ |
MASK_CHANNEL_READ_COMPLETE | MASK_USER_EVENT_TRIGGERED | MASK_CHANNEL_WRITABILITY_CHANGED;
private static final int MASK_ALL_OUTBOUND = MASK_EXCEPTION_CAUGHT | MASK_BIND | MASK_CONNECT | MASK_DISCONNECT |
private static final int MASK_ALL_INBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_INBOUND;
static final int MASK_ONLY_OUTBOUND = MASK_BIND | MASK_CONNECT | MASK_DISCONNECT |
MASK_CLOSE | MASK_DEREGISTER | MASK_READ | MASK_WRITE | MASK_FLUSH;
private static final int MASK_ALL_OUTBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_OUTBOUND;
private static final FastThreadLocal<Map<Class<? extends ChannelHandler>, Integer>> MASKS =
new FastThreadLocal<Map<Class<? extends ChannelHandler>, Integer>>() {