From ce2ce9d7a4c375511e6f0f33a5892492f6297702 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Tue, 9 May 2017 12:58:29 -0700 Subject: [PATCH] ByteToMessageDecoder#handlerRemoved may release cumulation buffer prematurely Motivation: ByteToMessageDecoder#handlerRemoved will immediately release the cumulation buffer, but it is possible that a child class may still be using this buffer, and therefore use a dereferenced buffer. Modifications: - ByteToMessageDecoder#handlerRemoved and ByteToMessageDecoder#decode should coordinate to avoid the case where a child class is using the cumulation buffer but ByteToMessageDecoder releases that buffer. Result: Child classes of ByteToMessageDecoder are less likely to reference a released buffer. --- .../handler/codec/ByteToMessageDecoder.java | 47 +++++++++++- .../netty/handler/codec/ReplayingDecoder.java | 2 +- .../codec/ByteToMessageDecoderTest.java | 73 +++++++++++++------ .../handler/codec/ReplayingDecoderTest.java | 27 +++++++ 4 files changed, 122 insertions(+), 27 deletions(-) diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index de4b4743ba..28c3dc12f4 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -129,11 +129,24 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter } }; + private static final byte STATE_INIT = 0; + private static final byte STATE_CALLING_CHILD_DECODE = 1; + private static final byte STATE_HANDLER_REMOVED_PENDING = 2; + ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; private boolean decodeWasNull; private boolean first; + /** + * A bitmask where the bits are defined as + * + */ + private byte decodeState = STATE_INIT; private int discardAfterReads = 16; private int numReads; @@ -207,6 +220,10 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter @Override public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + if (decodeState == STATE_CALLING_CHILD_DECODE) { + decodeState = STATE_HANDLER_REMOVED_PENDING; + return; + } ByteBuf buf = cumulation; if (buf != null) { // Directly set this to null so we are sure we not access it in any other method here anymore. @@ -408,7 +425,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter } int oldInputLength = in.readableBytes(); - decode(ctx, in, out); + decodeRemovalReentryProtection(ctx, in, out); // Check if this handler was removed before continuing the loop. // If it was removed, it is not safe to continue to operate on the buffer. @@ -429,7 +446,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter if (oldInputLength == in.readableBytes()) { throw new DecoderException( StringUtil.simpleClassName(getClass()) + - ".decode() did not read anything but decoded a message."); + ".decode() did not read anything but decoded a message."); } if (isSingleDecode()) { @@ -455,6 +472,30 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter */ protected abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception; + /** + * Decode the from one {@link ByteBuf} to an other. This method will be called till either the input + * {@link ByteBuf} has nothing to read when return from this method or till nothing was read from the input + * {@link ByteBuf}. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to + * @param in the {@link ByteBuf} from which to read data + * @param out the {@link List} to which decoded messages should be added + * @throws Exception is thrown if an error occurs + */ + final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in, List out) + throws Exception { + decodeState = STATE_CALLING_CHILD_DECODE; + try { + decode(ctx, in, out); + } finally { + boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; + decodeState = STATE_INIT; + if (removePending) { + handlerRemoved(ctx); + } + } + } + /** * Is called one last time when the {@link ChannelHandlerContext} goes in-active. Which means the * {@link #channelInactive(ChannelHandlerContext)} was triggered. @@ -466,7 +507,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter if (in.isReadable()) { // Only call decode() if there is something left in the buffer to decode. // See https://github.com/netty/netty/issues/4386 - decode(ctx, in, out); + decodeRemovalReentryProtection(ctx, in, out); } } diff --git a/codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java b/codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java index e3116b38af..2e10d1ecdf 100644 --- a/codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java @@ -364,7 +364,7 @@ public abstract class ReplayingDecoder extends ByteToMessageDecoder { S oldState = state; int oldInputLength = in.readableBytes(); try { - decode(ctx, replayable, out); + decodeRemovalReentryProtection(ctx, replayable, out); // Check if this handler was removed before continuing the loop. // If it was removed, it is not safe to continue to operate on the buffer. diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index 70255a3b37..8462ee4a02 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -27,7 +27,10 @@ import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingDeque; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; public class ByteToMessageDecoderTest { @@ -87,17 +90,7 @@ public class ByteToMessageDecoderTest { @Test public void testInternalBufferClearReadAll() { final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a'}); - EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - ByteBuf byteBuf = internalBuffer(); - assertEquals(1, byteBuf.refCnt()); - in.readByte(); - // Removal from pipeline should clear internal buffer - ctx.pipeline().remove(this); - assertEquals(0, byteBuf.refCnt()); - } - }); + EmbeddedChannel channel = newInternalBufferTestChannel(); assertFalse(channel.writeInbound(buf)); assertFalse(channel.finish()); } @@ -109,17 +102,7 @@ public class ByteToMessageDecoderTest { @Test public void testInternalBufferClearReadPartly() { final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b'}); - EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - ByteBuf byteBuf = internalBuffer(); - assertEquals(1, byteBuf.refCnt()); - in.readByte(); - // Removal from pipeline should clear internal buffer - ctx.pipeline().remove(this); - assertEquals(0, byteBuf.refCnt()); - } - }); + EmbeddedChannel channel = newInternalBufferTestChannel(); assertTrue(channel.writeInbound(buf)); assertTrue(channel.finish()); ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b'}); @@ -130,6 +113,50 @@ public class ByteToMessageDecoderTest { b.release(); } + private EmbeddedChannel newInternalBufferTestChannel() { + return new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ByteBuf byteBuf = internalBuffer(); + assertEquals(1, byteBuf.refCnt()); + in.readByte(); + // Removal from pipeline should clear internal buffer + ctx.pipeline().remove(this); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + assertCumulationReleased(internalBuffer()); + } + }); + } + + @Test + public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress() { + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ctx.pipeline().remove(this); + assertTrue(in.refCnt() != 0); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + assertCumulationReleased(internalBuffer()); + } + }); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes))); + assertTrue(channel.finishAndReleaseAll()); + } + + private static void assertCumulationReleased(ByteBuf byteBuf) { + assertTrue("unexpected value: " + byteBuf, + byteBuf == null || byteBuf == Unpooled.EMPTY_BUFFER || byteBuf.refCnt() == 0); + } + @Test public void testFireChannelReadCompleteOnInactive() throws InterruptedException { final BlockingQueue queue = new LinkedBlockingDeque(); diff --git a/codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java index 030277835b..5a7e878805 100644 --- a/codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java @@ -21,6 +21,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.util.internal.PlatformDependent; import org.junit.Test; import java.util.List; @@ -286,4 +287,30 @@ public class ReplayingDecoderTest { throw err; } } + + @Test + public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress() { + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ctx.pipeline().remove(this); + assertTrue(in.refCnt() != 0); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + assertCumulationReleased(internalBuffer()); + } + }); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes))); + assertTrue(channel.finishAndReleaseAll()); + } + + private static void assertCumulationReleased(ByteBuf byteBuf) { + assertTrue("unexpected value: " + byteBuf, + byteBuf == null || byteBuf == Unpooled.EMPTY_BUFFER || byteBuf.refCnt() == 0); + } }