diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index ef071243b3..cb18d38379 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -353,6 +353,15 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter if (outSize > 0) { fireChannelRead(ctx, out, outSize); out.clear(); + + // Check if this handler was removed before continuing with decoding. + // If it was removed, it is not safe to continue to operate on the buffer. + // + // See: + // - https://github.com/netty/netty/issues/4635 + if (ctx.isRemoved()) { + break; + } outSize = 0; } diff --git a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java index f5fd656976..add337673c 100644 --- a/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -21,13 +21,14 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.ReferenceCountUtil; -import org.junit.Assert; import org.junit.Test; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingDeque; +import static org.junit.Assert.*; + public class ByteToMessageDecoderTest { @Test @@ -37,7 +38,7 @@ public class ByteToMessageDecoderTest { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - Assert.assertFalse(removed); + assertFalse(removed); in.readByte(); ctx.pipeline().remove(this); removed = true; @@ -47,20 +48,20 @@ public class ByteToMessageDecoderTest { ByteBuf buf = Unpooled.wrappedBuffer(new byte[] {'a', 'b', 'c'}); channel.writeInbound(buf.copy()); ByteBuf b = channel.readInbound(); - Assert.assertEquals(b, buf.skipBytes(1)); + assertEquals(b, buf.skipBytes(1)); b.release(); buf.release(); } @Test public void testRemoveItselfWriteBuffer() { - final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[]{'a', 'b', 'c'}); + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b', 'c'}); EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { private boolean removed; @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - Assert.assertFalse(removed); + assertFalse(removed); in.readByte(); ctx.pipeline().remove(this); @@ -71,8 +72,10 @@ public class ByteToMessageDecoderTest { }); channel.writeInbound(buf.copy()); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b', 'c'}); ByteBuf b = channel.readInbound(); - Assert.assertEquals(b, Unpooled.wrappedBuffer(new byte[] { 'b', 'c'})); + assertEquals(expected, b); + expected.release(); buf.release(); b.release(); } @@ -83,20 +86,20 @@ public class ByteToMessageDecoderTest { */ @Test public void testInternalBufferClearReadAll() { - final ByteBuf buf = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeBytes(new byte[]{'a'})); + final ByteBuf buf = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeBytes(new byte[] {'a'})); EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { ByteBuf byteBuf = internalBuffer(); - Assert.assertEquals(1, byteBuf.refCnt()); + assertEquals(1, byteBuf.refCnt()); in.readByte(); // Removal from pipeline should clear internal buffer ctx.pipeline().remove(this); - Assert.assertEquals(0, byteBuf.refCnt()); + assertEquals(0, byteBuf.refCnt()); } }); - Assert.assertFalse(channel.writeInbound(buf)); - Assert.assertFalse(channel.finish()); + assertFalse(channel.writeInbound(buf)); + assertFalse(channel.finish()); } /** @@ -105,28 +108,32 @@ public class ByteToMessageDecoderTest { */ @Test public void testInternalBufferClearReadPartly() { - final ByteBuf buf = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeBytes(new byte[]{'a', 'b'})); + final ByteBuf buf = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeBytes(new byte[] {'a', 'b'})); EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { ByteBuf byteBuf = internalBuffer(); - Assert.assertEquals(1, byteBuf.refCnt()); + assertEquals(1, byteBuf.refCnt()); in.readByte(); // Removal from pipeline should clear internal buffer ctx.pipeline().remove(this); - Assert.assertEquals(0, byteBuf.refCnt()); + assertEquals(0, byteBuf.refCnt()); } }); - Assert.assertTrue(channel.writeInbound(buf)); - Assert.assertTrue(channel.finish()); - Assert.assertEquals(channel.readInbound(), Unpooled.wrappedBuffer(new byte[] {'b'})); - Assert.assertNull(channel.readInbound()); + assertTrue(channel.writeInbound(buf)); + assertTrue(channel.finish()); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b'}); + ByteBuf b = channel.readInbound(); + assertEquals(expected, b); + assertNull(channel.readInbound()); + expected.release(); + b.release(); } @Test public void testFireChannelReadCompleteOnInactive() throws InterruptedException { final BlockingQueue queue = new LinkedBlockingDeque(); - final ByteBuf buf = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeBytes(new byte[]{'a', 'b'})); + final ByteBuf buf = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeBytes(new byte[] {'a', 'b'})); EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { @@ -153,11 +160,43 @@ public class ByteToMessageDecoderTest { } } }); - Assert.assertFalse(channel.writeInbound(buf)); + assertFalse(channel.writeInbound(buf)); channel.finish(); - Assert.assertEquals(1, (int) queue.take()); - Assert.assertEquals(2, (int) queue.take()); - Assert.assertEquals(3, (int) queue.take()); - Assert.assertTrue(queue.isEmpty()); + assertEquals(1, (int) queue.take()); + assertEquals(2, (int) queue.take()); + assertEquals(3, (int) queue.take()); + assertTrue(queue.isEmpty()); + } + + // See https://github.com/netty/netty/issues/4635 + @Test + public void testRemoveWhileInCallDecode() { + final Object upgradeMessage = new Object(); + final ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertEquals('a', in.readByte()); + out.add(upgradeMessage); + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(decoder, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg == upgradeMessage) { + ctx.pipeline().remove(decoder); + return; + } + ctx.fireChannelRead(msg); + } + }); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] { 'a', 'b', 'c' }); + assertTrue(channel.writeInbound(buf.copy())); + ByteBuf b = channel.readInbound(); + assertEquals(b, buf.skipBytes(1)); + assertFalse(channel.finish()); + buf.release(); + b.release(); } }