From 375b9e1307c83a648329711c02237b360d8e3cd5 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Mon, 15 Dec 2014 11:07:03 +0100 Subject: [PATCH] Ensure ctx.fireChannelReadComplete() is only called when something was produced Motivation: ctx.fireChannelReadComplete() should only be called if something is produced during a channelRead(...) operation. Also we must ensure that it will be called if channelRead(...) produced something at some point as channelRead(...) maybe called multiple times by the transport before channelReadComplete(...) is called. Modifications: - Ensure channelReadComplete(...) only triggers ctx.fireChannelReadComplete() when a previous channelRead(...) call produced a message - Ensure read() is called of more data is needed Result: Correct semantic with channelReadComplete(...) events and also ensure no stales --- .../handler/codec/ByteToMessageDecoder.java | 8 - .../codec/ByteToMessageDecoderTest.java | 59 +++++ .../AbstractChannelHandlerContext.java | 38 ++- .../channel/DefaultChannelPipelineTest.java | 247 ++++++++++++++++++ 4 files changed, 340 insertions(+), 12 deletions(-) diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 7bbc2bb113..8de0378df7 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -133,7 +133,6 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; - private boolean decodeWasNull; private boolean first; protected ByteToMessageDecoder() { @@ -238,7 +237,6 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter cumulation = null; } int size = out.size(); - decodeWasNull = size == 0; for (int i = 0; i < size; i ++) { ctx.fireChannelRead(out.get(i)); @@ -262,12 +260,6 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter // - https://github.com/netty/netty/issues/1764 cumulation.discardSomeReadBytes(); } - if (decodeWasNull) { - decodeWasNull = false; - if (!ctx.channel().config().isAutoRead()) { - ctx.read(); - } - } ctx.fireChannelReadComplete(); } diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index aad5e13128..47ed85ab33 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import org.junit.Assert; import org.junit.Test; @@ -27,6 +28,7 @@ import org.junit.Test; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicInteger; public class ByteToMessageDecoderTest { @@ -160,4 +162,61 @@ public class ByteToMessageDecoderTest { Assert.assertEquals(3, (int) queue.take()); Assert.assertTrue(queue.isEmpty()); } + + // See https://github.com/netty/netty/pull/3263 + @Test + public void testFireChannelReadCompleteOnlyWhenDecoded() { + final AtomicInteger readComplete = new AtomicInteger(); + EmbeddedChannel ch = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + // Do nothing + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete.incrementAndGet(); + } + }); + Assert.assertFalse(ch.writeInbound(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); + Assert.assertFalse(ch.finish()); + Assert.assertEquals(0, readComplete.get()); + } + + // See https://github.com/netty/netty/pull/3263 + @Test + public void testFireChannelReadCompleteWhenDecodeOnce() { + final AtomicInteger readComplete = new AtomicInteger(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.fireChannelRead(msg); + ctx.fireChannelRead(Unpooled.EMPTY_BUFFER); + } + }, new ByteToMessageDecoder() { + private boolean first = true; + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (first) { + first = false; + out.add(in.readSlice(in.readableBytes()).retain()); + } + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete.incrementAndGet(); + } + }); + Assert.assertTrue(ch.writeInbound(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); + Assert.assertTrue(ch.finish()); + Assert.assertEquals(1, readComplete.get()); + for (;;) { + ByteBuf buf = (ByteBuf) ch.readInbound(); + if (buf == null) { + break; + } + buf.release(); + } + } } diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 514f614e92..8b63e318be 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -40,6 +40,12 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap impleme private final DefaultChannelPipeline pipeline; private final String name; private boolean removed; + // This does not need to be volatile as we always check and set this flag from EventExecutor thread. This means + // that at worse we will submit a task for channelReadComplete() that may do nothing if nextChannelReadInvoked + // is false. This is prefered to introduce another volatile flag because often fireChannelRead(...) and + // fireChannelReadComplete() are triggered from the EventExecutor thread anyway. + private boolean nextChannelReadInvoked; + private boolean readInvoked; // Will be set to null if no child executor should be used, otherwise it will be set to the // child executor. @@ -291,12 +297,12 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap impleme final AbstractChannelHandlerContext next = findContextInbound(); EventExecutor executor = next.executor(); if (executor.inEventLoop()) { - next.invokeChannelRead(msg); + invokeNextChannelRead(next, msg); } else { executor.execute(new OneTimeTask() { @Override public void run() { - next.invokeChannelRead(msg); + invokeNextChannelRead(next, msg); } }); } @@ -311,19 +317,24 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap impleme } } + private void invokeNextChannelRead(AbstractChannelHandlerContext next, Object msg) { + nextChannelReadInvoked = true; + next.invokeChannelRead(msg); + } + @Override public ChannelHandlerContext fireChannelReadComplete() { final AbstractChannelHandlerContext next = findContextInbound(); EventExecutor executor = next.executor(); if (executor.inEventLoop()) { - next.invokeChannelReadComplete(); + invokeNextChannelReadComplete(next); } else { Runnable task = next.invokeChannelReadCompleteTask; if (task == null) { next.invokeChannelReadCompleteTask = task = new Runnable() { @Override public void run() { - next.invokeChannelReadComplete(); + invokeNextChannelReadComplete(next); } }; } @@ -340,6 +351,24 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap impleme } } + private void invokeNextChannelReadComplete(AbstractChannelHandlerContext next) { + if (nextChannelReadInvoked) { + nextChannelReadInvoked = false; + readInvoked = false; + + next.invokeChannelReadComplete(); + } else if (readInvoked && !channel().config().isAutoRead()) { + // As this context not belongs to the last handler in the pipeline and autoRead is false we need to + // trigger read again as otherwise we may stop reading before a full message was passed on to the + // pipeline. This is especially true for all kind of decoders that usually buffer bytes/messages until + // they are able to compose a full message that is passed via fireChannelRead(...) and so be consumed + // be the rest of the handlers in the pipeline. + read(); + } else { + readInvoked = false; + } + } + @Override public ChannelHandlerContext fireChannelWritabilityChanged() { final AbstractChannelHandlerContext next = findContextInbound(); @@ -601,6 +630,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap impleme private void invokeRead() { try { + readInvoked = true; ((ChannelOutboundHandler) handler()).read(this); } catch (Throwable t) { notifyHandlerException(t); diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 40749b59db..18032e3a67 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -21,6 +21,7 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalEventLoopGroup; @@ -28,6 +29,8 @@ import io.netty.channel.local.LocalServerChannel; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.EventExecutorGroup; import org.junit.After; import org.junit.AfterClass; import org.junit.Test; @@ -39,6 +42,7 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.*; @@ -535,6 +539,249 @@ public class DefaultChannelPipelineTest { assertNull(pipeline.last()); } + @Test + public void testSurpressChannelReadComplete() throws Exception { + testSurpressChannelReadComplete0(false); + } + + @Test + public void testSurpressChannelReadCompleteDifferentExecutors() throws Exception { + testSurpressChannelReadComplete0(true); + } + + // See: + // https://github.com/netty/netty/pull/3263 + // https://github.com/netty/netty/pull/3272 + private static void testSurpressChannelReadComplete0(boolean executors) throws Exception { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger readComplete1 = new AtomicInteger(); + final AtomicInteger readComplete2 = new AtomicInteger(); + + final CountDownLatch latch = new CountDownLatch(1); + + final EventExecutorGroup group = executors ? new DefaultEventExecutorGroup(2) : null; + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(group, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (read1.incrementAndGet() == 1) { + return; + } + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + ch.pipeline().addLast(group, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + read2.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + readComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + } + }); + } + }); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + latch.await(); + assertEquals(3, read1.get()); + assertEquals(3, readComplete1.get()); + + assertEquals(2, read2.get()); + assertEquals(2, readComplete2.get()); + + assertEquals(2, ch.readInbound()); + assertEquals(3, ch.readInbound()); + assertNull(ch.readInbound()); + + if (group != null) { + group.shutdownGracefully(); + } + } + + @Test + public void testChannelReadTriggered() { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger channelRead1 = new AtomicInteger(); + final AtomicInteger channelReadComplete1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger channelRead2 = new AtomicInteger(); + final AtomicInteger channelReadComplete2 = new AtomicInteger(); + final AtomicInteger read3 = new AtomicInteger(); + final AtomicInteger channelRead3 = new AtomicInteger(); + final AtomicInteger channelReadComplete3 = new AtomicInteger(); + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read1.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead1.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read2.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Consume + channelRead2.incrementAndGet(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read3.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead3.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete3.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + + ch.config().setAutoRead(false); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + assertEquals(4, read1.get()); + assertEquals(3, channelRead1.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(1, read2.get()); + assertEquals(3, channelRead2.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(1, read3.get()); + assertEquals(0, channelRead3.get()); + assertEquals(0, channelReadComplete3.get()); + + assertNull(ch.readInbound()); + } + + @Test + public void testChannelReadNotTriggeredWhenLast() throws Exception { + final AtomicInteger read1 = new AtomicInteger(); + final AtomicInteger channelRead1 = new AtomicInteger(); + final AtomicInteger channelReadComplete1 = new AtomicInteger(); + final AtomicInteger read2 = new AtomicInteger(); + final AtomicInteger channelRead2 = new AtomicInteger(); + final AtomicInteger channelReadComplete2 = new AtomicInteger(); + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + + // Ensure pipeline is really empty + ChannelPipeline pipeline = ch.pipeline(); + while (pipeline.first() != null) { + pipeline.removeFirst(); + } + + pipeline.addLast(new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read1.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + channelRead1.incrementAndGet(); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete1.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }, new ChannelDuplexHandler() { + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + read2.incrementAndGet(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Consume + channelRead2.incrementAndGet(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete2.incrementAndGet(); + ctx.fireChannelReadComplete(); + } + }); + + ch.config().setAutoRead(false); + + ch.writeInbound(1); + ch.writeInbound(2); + ch.writeInbound(3); + ch.finish(); + assertEquals(0, read1.get()); + assertEquals(3, channelRead1.get()); + assertEquals(3, channelReadComplete1.get()); + + assertEquals(0, read2.get()); + assertEquals(3, channelRead2.get()); + assertEquals(3, channelReadComplete1.get()); + + assertNull(ch.readInbound()); + } + private static int next(AbstractChannelHandlerContext ctx) { AbstractChannelHandlerContext next = ctx.next; if (next == null) {