diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 9d01abe985..c2f5885e49 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -133,7 +133,6 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; - private boolean decodeWasNull; private boolean first; protected ByteToMessageDecoder() { @@ -238,7 +237,6 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter cumulation = null; } int size = out.size(); - decodeWasNull = size == 0; for (int i = 0; i < size; i ++) { ctx.fireChannelRead(out.get(i)); @@ -262,12 +260,6 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter // - https://github.com/netty/netty/issues/1764 cumulation.discardSomeReadBytes(); } - if (decodeWasNull) { - decodeWasNull = false; - if (!ctx.channel().config().isAutoRead()) { - ctx.read(); - } - } ctx.fireChannelReadComplete(); } diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index f5fd656976..b4e3cab0a5 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.Assert; import org.junit.Test; @@ -27,6 +28,7 @@ import org.junit.Test; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicInteger; public class ByteToMessageDecoderTest { @@ -160,4 +162,61 @@ public class ByteToMessageDecoderTest { Assert.assertEquals(3, (int) queue.take()); Assert.assertTrue(queue.isEmpty()); } + + // See https://github.com/netty/netty/pull/3263 + @Test + public void testFireChannelReadCompleteOnlyWhenDecoded() { + final AtomicInteger readComplete = new AtomicInteger(); + EmbeddedChannel ch = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + // Do nothing + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete.incrementAndGet(); + } + }); + Assert.assertFalse(ch.writeInbound(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); + Assert.assertFalse(ch.finish()); + Assert.assertEquals(0, readComplete.get()); + } + + // See https://github.com/netty/netty/pull/3263 + @Test + public void testFireChannelReadCompleteWhenDecodeOnce() { + final AtomicInteger readComplete = new AtomicInteger(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.fireChannelRead(msg); + ctx.fireChannelRead(Unpooled.EMPTY_BUFFER); + } + }, new ByteToMessageDecoder() { + private boolean first = true; + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (first) { + first = false; + out.add(in.readSlice(in.readableBytes()).retain()); + } + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete.incrementAndGet(); + } + }); + Assert.assertTrue(ch.writeInbound(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); + Assert.assertTrue(ch.finish()); + Assert.assertEquals(1, readComplete.get()); + for (;;) { + ByteBuf buf = ch.readInbound(); + if (buf == null) { + break; + } + buf.release(); + } + } } diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index f371c66502..cd7d286284 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -35,6 +35,26 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R private final AbstractChannel channel; private final DefaultChannelPipeline pipeline; private final String name; + + /** + * Set when a user calls {@link #fireChannelRead(Object)} on this context. + * Cleared when a user calls {@link #fireChannelReadComplete()} on this context. + * + * See {@link #fireChannelReadComplete()} to understand how this flag is used. + */ + private volatile boolean firedChannelRead; + + /** + * Set when a user calls {@link #read()} on this context. + * Cleared when a user calls {@link #fireChannelReadComplete()} on this context. + * + * See {@link #fireChannelReadComplete()} to understand how this flag is used. + */ + private volatile boolean invokedRead; + + /** + * {@code true} if and only if this context has been removed from the pipeline. + */ private boolean removed; final ChannelHandlerInvoker invoker; @@ -154,14 +174,51 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R public ChannelHandlerContext fireChannelRead(Object msg) { AbstractChannelHandlerContext next = findContextInbound(); ReferenceCountUtil.touch(msg, next); + firedChannelRead = true; next.invoker().invokeChannelRead(next, msg); return this; } @Override public ChannelHandlerContext fireChannelReadComplete() { - AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelReadComplete(next); + /** + * If the handler of this context did not produce any messages via {@link #fireChannelRead(Object)}, + * there's no reason to trigger {@code channelReadComplete()} even if the handler called this method. + * + * This is pretty common for the handlers that transform multiple messages into one message, + * such as byte-to-message decoder and message aggregators. + */ + if (firedChannelRead) { + // The handler of this context produced a message, so we are OK to trigger this event. + firedChannelRead = false; + invokedRead = false; + AbstractChannelHandlerContext next = findContextInbound(); + next.invoker().invokeChannelReadComplete(next); + return this; + } + + /** + * At this point, we are sure the handler of this context did not produce anything and + * we suppressed the {@code channelReadComplete()} event. + * + * If the next handler invoked {@link #read()} to read something but nothing was produced + * by the handler of this context, someone has to issue another {@link #read()} operation + * until the handler of this context produces something. + * + * Why? Because otherwise the next handler will not receive {@code channelRead()} nor + * {@code channelReadComplete()} event at all for the {@link #read()} operation it issued. + */ + if (invokedRead && !channel().config().isAutoRead()) { + /** + * The next (or upstream) handler invoked {@link #read()}, but it didn't get any + * {@code channelRead()} event. We should read once more, so that the handler of the current + * context have a chance to produce something. + */ + read(); + } else { + invokedRead = false; + } + return this; } @@ -249,6 +306,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelHandlerContext read() { AbstractChannelHandlerContext next = findContextOutbound(); + invokedRead = true; next.invoker().invokeRead(next); return this; } diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 734fa471b6..e36d171fd4 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -21,12 +21,15 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.EventExecutorGroup; import org.junit.After; import org.junit.AfterClass; import org.junit.Test; @@ -38,6 +41,7 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.*; @@ -548,6 +552,249 @@ public class DefaultChannelPipelineTest { assertNull(pipeline.last()); } + @Test + public void testSupressChannelReadComplete() throws Exception { + testSupressChannelReadComplete0(false); + } + + @Test + public void testSupressChannelReadCompleteDifferentExecutors() throws Exception { + testSupressChannelReadComplete0(true); + } + + // See: + // https://github.com/netty/netty/pull/3263 + // https://github.com/netty/netty/pull/3272 + private static void testSupressChannelReadComplete0(boolean executors) throws Exception { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger readComplete1 = new AtomicInteger(); + final AtomicInteger readComplete2 = new AtomicInteger(); + + final CountDownLatch latch = new CountDownLatch(1); + + final EventExecutorGroup group = executors ? new DefaultEventExecutorGroup(2) : null; + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(group, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (read1.incrementAndGet() == 1) { + return; + } + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + ch.pipeline().addLast(group, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + read2.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + } + }); + } + }); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + latch.await(); + assertEquals(3, read1.get()); + assertEquals(3, readComplete1.get()); + + assertEquals(2, read2.get()); + assertEquals(2, readComplete2.get()); + + assertEquals(2, ch.readInbound()); + assertEquals(3, ch.readInbound()); + assertNull(ch.readInbound()); + + if (group != null) { + group.shutdownGracefully(); + } + } + + @Test + public void testChannelReadTriggered() { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger channelRead1 = new AtomicInteger(); + final AtomicInteger channelReadComplete1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger channelRead2 = new AtomicInteger(); + final AtomicInteger channelReadComplete2 = new AtomicInteger(); + final AtomicInteger read3 = new AtomicInteger(); + final AtomicInteger channelRead3 = new AtomicInteger(); + final AtomicInteger channelReadComplete3 = new AtomicInteger(); + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read1.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead1.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read2.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Consume + channelRead2.incrementAndGet(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read3.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead3.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete3.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + + ch.config().setAutoRead(false); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + assertEquals(4, read1.get()); + assertEquals(3, channelRead1.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(1, read2.get()); + assertEquals(3, channelRead2.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(1, read3.get()); + assertEquals(0, channelRead3.get()); + assertEquals(0, channelReadComplete3.get()); + + assertNull(ch.readInbound()); + } + + @Test + public void testChannelReadNotTriggeredWhenLast() throws Exception { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger channelRead1 = new AtomicInteger(); + final AtomicInteger channelReadComplete1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger channelRead2 = new AtomicInteger(); + final AtomicInteger channelReadComplete2 = new AtomicInteger(); + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + + // Ensure pipeline is really empty + ChannelPipeline pipeline = ch.pipeline(); + while (pipeline.first() != null) { + pipeline.removeFirst(); + } + + pipeline.addLast(new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read1.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead1.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read2.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Consume + channelRead2.incrementAndGet(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + + ch.config().setAutoRead(false); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + assertEquals(0, read1.get()); + assertEquals(3, channelRead1.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(0, read2.get()); + assertEquals(3, channelRead2.get()); + assertEquals(3, channelReadComplete1.get()); + + assertNull(ch.readInbound()); + } + private static int next(AbstractChannelHandlerContext ctx) { AbstractChannelHandlerContext next = ctx.next; if (next == null) {