Correctly propagate channelInactive even if cleanup throws

Motivation:

Its possible that cleanup() will throw if invalid data is passed into the wrapped EmbeddedChannel. We need to ensure we still call channelInactive(...) in this case.

Modifications:

- Correctly forward Exceptions caused by cleanup()
- Ensure all content is released when cleanup() throws
- Add unit tests

Result:

Correctly handle the case when cleanup() throws.
This commit is contained in:
Norman Maurer 2017-11-21 09:59:51 +01:00
parent e5e4c18c1b
commit 7f4ade7e7d
4 changed files with 112 additions and 36 deletions

View File

@ -195,13 +195,13 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObj
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
cleanup();
cleanupSafely(ctx);
super.handlerRemoved(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
cleanup();
cleanupSafely(ctx);
super.channelInactive(ctx);
}
@ -214,20 +214,21 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObj
private void cleanup() {
if (decoder != null) {
// Clean-up the previous decoder if not cleaned up correctly.
if (decoder.finish()) {
for (;;) {
ByteBuf buf = decoder.readInbound();
if (buf == null) {
break;
}
// Release the buffer
buf.release();
}
}
decoder.finishAndReleaseAll();
decoder = null;
}
}
private void cleanupSafely(ChannelHandlerContext ctx) {
try {
cleanup();
} catch (Throwable cause) {
// If cleanup throws any error we need to propagate it through the pipeline
// so we don't fail to propagate pipeline events.
ctx.fireExceptionCaught(cause);
}
}
private void decode(ByteBuf in, List<Object> out) {
// call retain here as it will call release after its written to the channel
decoder.writeInbound(in.retain());

View File

@ -289,34 +289,34 @@ public abstract class HttpContentEncoder extends MessageToMessageCodec<HttpReque
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
cleanup();
cleanupSafely(ctx);
super.handlerRemoved(ctx);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
cleanup();
cleanupSafely(ctx);
super.channelInactive(ctx);
}
private void cleanup() {
if (encoder != null) {
// Clean-up the previous encoder if not cleaned up correctly.
if (encoder.finish()) {
for (;;) {
ByteBuf buf = encoder.readOutbound();
if (buf == null) {
break;
}
// Release the buffer
// https://github.com/netty/netty/issues/1524
buf.release();
}
}
encoder.finishAndReleaseAll();
encoder = null;
}
}
private void cleanupSafely(ChannelHandlerContext ctx) {
try {
cleanup();
} catch (Throwable cause) {
// If cleanup throws any error we need to propagate it through the pipeline
// so we don't fail to propagate pipeline events.
ctx.fireExceptionCaught(cause);
}
}
private void encode(ByteBuf in, List<Object> out) {
// call retain here as it will call release after its written to the channel
encoder.writeOutbound(in.retain());

View File

@ -21,6 +21,8 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibDecoder;
import io.netty.handler.codec.compression.ZlibEncoder;
@ -32,16 +34,12 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
public class HttpContentDecoderTest {
private static final String HELLO_WORLD = "hello, world";
@ -486,6 +484,43 @@ public class HttpContentDecoderTest {
assertFalse(channel.finish());
}
@Test
public void testCleanupThrows() {
HttpContentDecoder decoder = new HttpContentDecoder() {
@Override
protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception {
return new EmbeddedChannel(new ChannelInboundHandlerAdapter() {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
ctx.fireExceptionCaught(new DecoderException());
ctx.fireChannelInactive();
}
});
}
};
final AtomicBoolean channelInactiveCalled = new AtomicBoolean();
EmbeddedChannel channel = new EmbeddedChannel(decoder, new ChannelInboundHandlerAdapter() {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
assertTrue(channelInactiveCalled.compareAndSet(false, true));
super.channelInactive(ctx);
}
});
assertTrue(channel.writeInbound(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")));
HttpContent content = new DefaultHttpContent(Unpooled.buffer().writeZero(10));
assertTrue(channel.writeInbound(content));
assertEquals(1, content.refCnt());
try {
channel.finishAndReleaseAll();
fail();
} catch (CodecException expected) {
// expected
}
assertTrue(channelInactiveCalled.get());
assertEquals(0, content.refCnt());
}
private static byte[] gzDecompress(byte[] input) {
ZlibDecoder decoder = ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP);
EmbeddedChannel channel = new EmbeddedChannel(decoder);

View File

@ -19,21 +19,22 @@ package io.netty.handler.codec.http;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.util.CharsetUtil;
import org.junit.Test;
import java.util.concurrent.atomic.AtomicBoolean;
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
public class HttpContentEncoderTest {
@ -387,6 +388,45 @@ public class HttpContentEncoderTest {
assertNull(ch.readOutbound());
}
@Test
public void testCleanupThrows() {
HttpContentEncoder encoder = new HttpContentEncoder() {
@Override
protected Result beginEncode(HttpResponse headers, String acceptEncoding) throws Exception {
return new Result("myencoding", new EmbeddedChannel(
new ChannelInboundHandlerAdapter() {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
ctx.fireExceptionCaught(new EncoderException());
ctx.fireChannelInactive();
}
}));
}
};
final AtomicBoolean channelInactiveCalled = new AtomicBoolean();
EmbeddedChannel channel = new EmbeddedChannel(encoder, new ChannelInboundHandlerAdapter() {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
assertTrue(channelInactiveCalled.compareAndSet(false, true));
super.channelInactive(ctx);
}
});
assertTrue(channel.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")));
assertTrue(channel.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)));
HttpContent content = new DefaultHttpContent(Unpooled.buffer().writeZero(10));
assertTrue(channel.writeOutbound(content));
assertEquals(1, content.refCnt());
try {
channel.finishAndReleaseAll();
fail();
} catch (CodecException expected) {
// expected
}
assertTrue(channelInactiveCalled.get());
assertEquals(0, content.refCnt());
}
private static void assertEmptyResponse(EmbeddedChannel ch) {
Object o = ch.readOutbound();
assertThat(o, is(instanceOf(HttpResponse.class)));