From 14d64d09669e95f015148b7b1d71d303bd9e8b77 Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Sat, 7 Feb 2015 15:54:42 +0900 Subject: [PATCH] Ensure channelReadComplete() is called only when necessary Motivation: Even if a handler called ctx.fireChannelReadComplete(), the next handler should not get its channelReadComplete() invoked if fireChannelRead() was not invoked before. Modifications: - Ensure channelReadComplete() is invoked only when the handler of the current context actually produced a message, because otherwise there's no point of triggering channelReadComplete(). i.e. channelReadComplete() must follow channelRead(). - Fix a bug where ctx.read() was not called if the handler of the current context did not produce any message, making the connection stall. Read the new comment for more information. Result: - channelReadComplete() is invoked only when it makes sense. - No stale connection --- .../handler/codec/ByteToMessageDecoder.java | 8 - .../codec/ByteToMessageDecoderTest.java | 59 +++++ .../AbstractChannelHandlerContext.java | 62 ++++- .../channel/DefaultChannelPipelineTest.java | 247 ++++++++++++++++++ 4 files changed, 366 insertions(+), 10 deletions(-) diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 586b29ba7c..4fa4b16daf 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -134,7 +134,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; - private boolean decodeWasNull; private boolean first; protected ByteToMessageDecoder() { @@ -239,7 +238,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { cumulation = null; } int size = out.size(); - decodeWasNull = size == 0; for (int i = 0; i < size; i ++) { ctx.fireChannelRead(out.get(i)); @@ -263,12 +261,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { // - https://github.com/netty/netty/issues/1764 cumulation.discardSomeReadBytes(); } - if (decodeWasNull) { - decodeWasNull = false; - if (!ctx.channel().config().isAutoRead()) { - ctx.read(); - } - } ctx.fireChannelReadComplete(); } diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index f5fd656976..b4e3cab0a5 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.Assert; import org.junit.Test; @@ -27,6 +28,7 @@ import org.junit.Test; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicInteger; public class ByteToMessageDecoderTest { @@ -160,4 +162,61 @@ public class ByteToMessageDecoderTest { Assert.assertEquals(3, (int) queue.take()); Assert.assertTrue(queue.isEmpty()); } + + // See https://github.com/netty/netty/pull/3263 + @Test + public void testFireChannelReadCompleteOnlyWhenDecoded() { + final AtomicInteger readComplete = new AtomicInteger(); + EmbeddedChannel ch = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + // Do nothing + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete.incrementAndGet(); + } + }); + Assert.assertFalse(ch.writeInbound(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); + Assert.assertFalse(ch.finish()); + Assert.assertEquals(0, readComplete.get()); + } + + // See https://github.com/netty/netty/pull/3263 + @Test + public void testFireChannelReadCompleteWhenDecodeOnce() { + final AtomicInteger readComplete = new AtomicInteger(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.fireChannelRead(msg); + ctx.fireChannelRead(Unpooled.EMPTY_BUFFER); + } + }, new ByteToMessageDecoder() { + private boolean first = true; + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (first) { + first = false; + out.add(in.readSlice(in.readableBytes()).retain()); + } + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete.incrementAndGet(); + } + }); + Assert.assertTrue(ch.writeInbound(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); + Assert.assertTrue(ch.finish()); + Assert.assertEquals(1, readComplete.get()); + for (;;) { + ByteBuf buf = ch.readInbound(); + if (buf == null) { + break; + } + buf.release(); + } + } } diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 69cf681e04..c4f364a4c8 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -213,6 +213,26 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R private final AbstractChannel channel; private final DefaultChannelPipeline pipeline; private final String name; + + /** + * Set when a user calls {@link #fireChannelRead(Object)} on this context. + * Cleared when a user calls {@link #fireChannelReadComplete()} on this context. + * + * See {@link #fireChannelReadComplete()} to understand how this flag is used. + */ + private volatile boolean firedChannelRead; + + /** + * Set when a user calls {@link #read()} on this context. + * Cleared when a user calls {@link #fireChannelReadComplete()} on this context. + * + * See {@link #fireChannelReadComplete()} to understand how this flag is used. + */ + private volatile boolean invokedRead; + + /** + * {@code true} if and only if this context has been removed from the pipeline. + */ private boolean removed; final int skipFlags; @@ -356,14 +376,51 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R public ChannelHandlerContext fireChannelRead(Object msg) { AbstractChannelHandlerContext next = findContextInbound(); ReferenceCountUtil.touch(msg, next); + firedChannelRead = true; next.invoker().invokeChannelRead(next, msg); return this; } @Override public ChannelHandlerContext fireChannelReadComplete() { - AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelReadComplete(next); + /** + * If the handler of this context did not produce any messages via {@link #fireChannelRead(Object)}, + * there's no reason to trigger {@code channelReadComplete()} even if the handler called this method. + * + * This is pretty common for the handlers that transform multiple messages into one message, + * such as byte-to-message decoder and message aggregators. + */ + if (firedChannelRead) { + // The handler of this context produced a message, so we are OK to trigger this event. + firedChannelRead = false; + invokedRead = false; + AbstractChannelHandlerContext next = findContextInbound(); + next.invoker().invokeChannelReadComplete(next); + return this; + } + + /** + * At this point, we are sure the handler of this context did not produce anything and + * we suppressed the {@code channelReadComplete()} event. + * + * If the next handler invoked {@link #read()} to read something but nothing was produced + * by the handler of this context, someone has to issue another {@link #read()} operation + * until the handler of this context produces something. + * + * Why? Because otherwise the next handler will not receive {@code channelRead()} nor + * {@code channelReadComplete()} event at all for the {@link #read()} operation it issued. + */ + if (invokedRead && !channel().config().isAutoRead()) { + /** + * The next (or upstream) handler invoked {@link #read()}, but it didn't get any + * {@code channelRead()} event. We should read once more, so that the handler of the current + * context have a chance to produce something. + */ + read(); + } else { + invokedRead = false; + } + return this; } @@ -451,6 +508,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelHandlerContext read() { AbstractChannelHandlerContext next = findContextOutbound(); + invokedRead = true; next.invoker().invokeRead(next); return this; } diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 4785fb9018..9190e4b065 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -21,12 +21,15 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.EventExecutorGroup; import org.junit.After; import org.junit.AfterClass; import org.junit.Test; @@ -38,6 +41,7 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.*; @@ -548,6 +552,249 @@ public class DefaultChannelPipelineTest { assertNull(pipeline.last()); } + @Test + public void testSupressChannelReadComplete() throws Exception { + testSupressChannelReadComplete0(false); + } + + @Test + public void testSupressChannelReadCompleteDifferentExecutors() throws Exception { + testSupressChannelReadComplete0(true); + } + + // See: + // https://github.com/netty/netty/pull/3263 + // https://github.com/netty/netty/pull/3272 + private static void testSupressChannelReadComplete0(boolean executors) throws Exception { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger readComplete1 = new AtomicInteger(); + final AtomicInteger readComplete2 = new AtomicInteger(); + + final CountDownLatch latch = new CountDownLatch(1); + + final EventExecutorGroup group = executors ? new DefaultEventExecutorGroup(2) : null; + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(group, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (read1.incrementAndGet() == 1) { + return; + } + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + ch.pipeline().addLast(group, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + read2.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + } + }); + } + }); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + latch.await(); + assertEquals(3, read1.get()); + assertEquals(3, readComplete1.get()); + + assertEquals(2, read2.get()); + assertEquals(2, readComplete2.get()); + + assertEquals(2, ch.readInbound()); + assertEquals(3, ch.readInbound()); + assertNull(ch.readInbound()); + + if (group != null) { + group.shutdownGracefully(); + } + } + + @Test + public void testChannelReadTriggered() { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger channelRead1 = new AtomicInteger(); + final AtomicInteger channelReadComplete1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger channelRead2 = new AtomicInteger(); + final AtomicInteger channelReadComplete2 = new AtomicInteger(); + final AtomicInteger read3 = new AtomicInteger(); + final AtomicInteger channelRead3 = new AtomicInteger(); + final AtomicInteger channelReadComplete3 = new AtomicInteger(); + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read1.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead1.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read2.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Consume + channelRead2.incrementAndGet(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read3.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead3.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete3.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + + ch.config().setAutoRead(false); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + assertEquals(4, read1.get()); + assertEquals(3, channelRead1.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(1, read2.get()); + assertEquals(3, channelRead2.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(1, read3.get()); + assertEquals(0, channelRead3.get()); + assertEquals(0, channelReadComplete3.get()); + + assertNull(ch.readInbound()); + } + + @Test + public void testChannelReadNotTriggeredWhenLast() throws Exception { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger channelRead1 = new AtomicInteger(); + final AtomicInteger channelReadComplete1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger channelRead2 = new AtomicInteger(); + final AtomicInteger channelReadComplete2 = new AtomicInteger(); + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + + // Ensure pipeline is really empty + ChannelPipeline pipeline = ch.pipeline(); + while (pipeline.first() != null) { + pipeline.removeFirst(); + } + + pipeline.addLast(new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read1.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead1.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read2.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Consume + channelRead2.incrementAndGet(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + + ch.config().setAutoRead(false); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + assertEquals(0, read1.get()); + assertEquals(3, channelRead1.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(0, read2.get()); + assertEquals(3, channelRead2.get()); + assertEquals(3, channelReadComplete1.get()); + + assertNull(ch.readInbound()); + } + private static int next(AbstractChannelHandlerContext ctx) { AbstractChannelHandlerContext next = ctx.next; if (next == null) {