ByteToMessageDecoder#handlerRemoved may release cumulation buffer prematurely

Motivation:
ByteToMessageDecoder#handlerRemoved will immediately release the cumulation buffer, but it is possible that a child class may still be using this buffer, and therefore use a dereferenced buffer.

Modifications:
- ByteToMessageDecoder#handlerRemoved and ByteToMessageDecoder#decode should coordinate to avoid the case where a child class is using the cumulation buffer but ByteToMessageDecoder releases that buffer.

Result:
Child classes of ByteToMessageDecoder are less likely to reference a released buffer.
This commit is contained in:
Scott Mitchell 2017-05-09 12:58:29 -07:00
parent c053c5144d
commit ce2ce9d7a4
4 changed files with 122 additions and 27 deletions

View File

@ -129,11 +129,24 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
} }
}; };
private static final byte STATE_INIT = 0;
private static final byte STATE_CALLING_CHILD_DECODE = 1;
private static final byte STATE_HANDLER_REMOVED_PENDING = 2;
ByteBuf cumulation; ByteBuf cumulation;
private Cumulator cumulator = MERGE_CUMULATOR; private Cumulator cumulator = MERGE_CUMULATOR;
private boolean singleDecode; private boolean singleDecode;
private boolean decodeWasNull; private boolean decodeWasNull;
private boolean first; private boolean first;
/**
* A bitmask where the bits are defined as
* <ul>
* <li>{@link #STATE_INIT}</li>
* <li>{@link #STATE_CALLING_CHILD_DECODE}</li>
* <li>{@link #STATE_HANDLER_REMOVED_PENDING}</li>
* </ul>
*/
private byte decodeState = STATE_INIT;
private int discardAfterReads = 16; private int discardAfterReads = 16;
private int numReads; private int numReads;
@ -207,6 +220,10 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
@Override @Override
public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
if (decodeState == STATE_CALLING_CHILD_DECODE) {
decodeState = STATE_HANDLER_REMOVED_PENDING;
return;
}
ByteBuf buf = cumulation; ByteBuf buf = cumulation;
if (buf != null) { if (buf != null) {
// Directly set this to null so we are sure we not access it in any other method here anymore. // Directly set this to null so we are sure we not access it in any other method here anymore.
@ -408,7 +425,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
} }
int oldInputLength = in.readableBytes(); int oldInputLength = in.readableBytes();
decode(ctx, in, out); decodeRemovalReentryProtection(ctx, in, out);
// Check if this handler was removed before continuing the loop. // Check if this handler was removed before continuing the loop.
// If it was removed, it is not safe to continue to operate on the buffer. // If it was removed, it is not safe to continue to operate on the buffer.
@ -429,7 +446,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
if (oldInputLength == in.readableBytes()) { if (oldInputLength == in.readableBytes()) {
throw new DecoderException( throw new DecoderException(
StringUtil.simpleClassName(getClass()) + StringUtil.simpleClassName(getClass()) +
".decode() did not read anything but decoded a message."); ".decode() did not read anything but decoded a message.");
} }
if (isSingleDecode()) { if (isSingleDecode()) {
@ -455,6 +472,30 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
*/ */
protected abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception; protected abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception;
/**
* Decode the from one {@link ByteBuf} to an other. This method will be called till either the input
* {@link ByteBuf} has nothing to read when return from this method or till nothing was read from the input
* {@link ByteBuf}.
*
* @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to
* @param in the {@link ByteBuf} from which to read data
* @param out the {@link List} to which decoded messages should be added
* @throws Exception is thrown if an error occurs
*/
final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
throws Exception {
decodeState = STATE_CALLING_CHILD_DECODE;
try {
decode(ctx, in, out);
} finally {
boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING;
decodeState = STATE_INIT;
if (removePending) {
handlerRemoved(ctx);
}
}
}
/** /**
* Is called one last time when the {@link ChannelHandlerContext} goes in-active. Which means the * Is called one last time when the {@link ChannelHandlerContext} goes in-active. Which means the
* {@link #channelInactive(ChannelHandlerContext)} was triggered. * {@link #channelInactive(ChannelHandlerContext)} was triggered.
@ -466,7 +507,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
if (in.isReadable()) { if (in.isReadable()) {
// Only call decode() if there is something left in the buffer to decode. // Only call decode() if there is something left in the buffer to decode.
// See https://github.com/netty/netty/issues/4386 // See https://github.com/netty/netty/issues/4386
decode(ctx, in, out); decodeRemovalReentryProtection(ctx, in, out);
} }
} }

View File

@ -364,7 +364,7 @@ public abstract class ReplayingDecoder<S> extends ByteToMessageDecoder {
S oldState = state; S oldState = state;
int oldInputLength = in.readableBytes(); int oldInputLength = in.readableBytes();
try { try {
decode(ctx, replayable, out); decodeRemovalReentryProtection(ctx, replayable, out);
// Check if this handler was removed before continuing the loop. // Check if this handler was removed before continuing the loop.
// If it was removed, it is not safe to continue to operate on the buffer. // If it was removed, it is not safe to continue to operate on the buffer.

View File

@ -27,7 +27,10 @@ import java.util.List;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingDeque;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
public class ByteToMessageDecoderTest { public class ByteToMessageDecoderTest {
@ -87,17 +90,7 @@ public class ByteToMessageDecoderTest {
@Test @Test
public void testInternalBufferClearReadAll() { public void testInternalBufferClearReadAll() {
final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a'}); final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a'});
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { EmbeddedChannel channel = newInternalBufferTestChannel();
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
ByteBuf byteBuf = internalBuffer();
assertEquals(1, byteBuf.refCnt());
in.readByte();
// Removal from pipeline should clear internal buffer
ctx.pipeline().remove(this);
assertEquals(0, byteBuf.refCnt());
}
});
assertFalse(channel.writeInbound(buf)); assertFalse(channel.writeInbound(buf));
assertFalse(channel.finish()); assertFalse(channel.finish());
} }
@ -109,17 +102,7 @@ public class ByteToMessageDecoderTest {
@Test @Test
public void testInternalBufferClearReadPartly() { public void testInternalBufferClearReadPartly() {
final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b'}); final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b'});
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { EmbeddedChannel channel = newInternalBufferTestChannel();
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
ByteBuf byteBuf = internalBuffer();
assertEquals(1, byteBuf.refCnt());
in.readByte();
// Removal from pipeline should clear internal buffer
ctx.pipeline().remove(this);
assertEquals(0, byteBuf.refCnt());
}
});
assertTrue(channel.writeInbound(buf)); assertTrue(channel.writeInbound(buf));
assertTrue(channel.finish()); assertTrue(channel.finish());
ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b'}); ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b'});
@ -130,6 +113,50 @@ public class ByteToMessageDecoderTest {
b.release(); b.release();
} }
private EmbeddedChannel newInternalBufferTestChannel() {
return new EmbeddedChannel(new ByteToMessageDecoder() {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
ByteBuf byteBuf = internalBuffer();
assertEquals(1, byteBuf.refCnt());
in.readByte();
// Removal from pipeline should clear internal buffer
ctx.pipeline().remove(this);
}
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
assertCumulationReleased(internalBuffer());
}
});
}
@Test
public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress() {
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
ctx.pipeline().remove(this);
assertTrue(in.refCnt() != 0);
}
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
assertCumulationReleased(internalBuffer());
}
});
byte[] bytes = new byte[1024];
PlatformDependent.threadLocalRandom().nextBytes(bytes);
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes)));
assertTrue(channel.finishAndReleaseAll());
}
private static void assertCumulationReleased(ByteBuf byteBuf) {
assertTrue("unexpected value: " + byteBuf,
byteBuf == null || byteBuf == Unpooled.EMPTY_BUFFER || byteBuf.refCnt() == 0);
}
@Test @Test
public void testFireChannelReadCompleteOnInactive() throws InterruptedException { public void testFireChannelReadCompleteOnInactive() throws InterruptedException {
final BlockingQueue<Integer> queue = new LinkedBlockingDeque<Integer>(); final BlockingQueue<Integer> queue = new LinkedBlockingDeque<Integer>();

View File

@ -21,6 +21,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.ChannelInputShutdownEvent; import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.util.internal.PlatformDependent;
import org.junit.Test; import org.junit.Test;
import java.util.List; import java.util.List;
@ -286,4 +287,30 @@ public class ReplayingDecoderTest {
throw err; throw err;
} }
} }
@Test
public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress() {
EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder<Integer>() {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
ctx.pipeline().remove(this);
assertTrue(in.refCnt() != 0);
}
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
assertCumulationReleased(internalBuffer());
}
});
byte[] bytes = new byte[1024];
PlatformDependent.threadLocalRandom().nextBytes(bytes);
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes)));
assertTrue(channel.finishAndReleaseAll());
}
private static void assertCumulationReleased(ByteBuf byteBuf) {
assertTrue("unexpected value: " + byteBuf,
byteBuf == null || byteBuf == Unpooled.EMPTY_BUFFER || byteBuf.refCnt() == 0);
}
} }