From 7f4ade7e7d6f907a30926c54e16abefe88e81a8e Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Tue, 21 Nov 2017 09:59:51 +0100 Subject: [PATCH] Correctly propagate channelInactive even if cleanup throws Motivation: Its possible that cleanup() will throw if invalid data is passed into the wrapped EmbeddedChannel. We need to ensure we still call channelInactive(...) in this case. Modifications: - Correctly forward Exceptions caused by cleanup() - Ensure all content is released when cleanup() throws - Add unit tests Result: Correctly handle the case when cleanup() throws. --- .../codec/http/HttpContentDecoder.java | 25 +++++----- .../codec/http/HttpContentEncoder.java | 26 +++++----- .../codec/http/HttpContentDecoderTest.java | 47 ++++++++++++++--- .../codec/http/HttpContentEncoderTest.java | 50 +++++++++++++++++-- 4 files changed, 112 insertions(+), 36 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java index a39c94e32c..e85adaaa37 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java @@ -195,13 +195,13 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder out) { // call retain here as it will call release after its written to the channel decoder.writeInbound(in.retain()); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java index c04d538681..bd59bc2e86 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java @@ -289,34 +289,34 @@ public abstract class HttpContentEncoder extends MessageToMessageCodec out) { // call retain here as it will call release after its written to the channel encoder.writeOutbound(in.retain()); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java index 8aa692f724..7a27a4c08a 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java @@ -21,6 +21,8 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.compression.ZlibCodecFactory; import io.netty.handler.codec.compression.ZlibDecoder; import io.netty.handler.codec.compression.ZlibEncoder; @@ -32,16 +34,12 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class HttpContentDecoderTest { private static final String HELLO_WORLD = "hello, world"; @@ -486,6 +484,43 @@ public class HttpContentDecoderTest { assertFalse(channel.finish()); } + @Test + public void testCleanupThrows() { + HttpContentDecoder decoder = new HttpContentDecoder() { + @Override + protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { + return new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + ctx.fireExceptionCaught(new DecoderException()); + ctx.fireChannelInactive(); + } + }); + } + }; + + final AtomicBoolean channelInactiveCalled = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + assertTrue(channelInactiveCalled.compareAndSet(false, true)); + super.channelInactive(ctx); + } + }); + assertTrue(channel.writeInbound(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); + HttpContent content = new DefaultHttpContent(Unpooled.buffer().writeZero(10)); + assertTrue(channel.writeInbound(content)); + assertEquals(1, content.refCnt()); + try { + channel.finishAndReleaseAll(); + fail(); + } catch (CodecException expected) { + // expected + } + assertTrue(channelInactiveCalled.get()); + assertEquals(0, content.refCnt()); + } + private static byte[] gzDecompress(byte[] input) { ZlibDecoder decoder = ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP); EmbeddedChannel channel = new EmbeddedChannel(decoder); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java index d4bb1af78b..9242cf8728 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java @@ -19,21 +19,22 @@ package io.netty.handler.codec.http; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.MessageToByteEncoder; import io.netty.util.CharsetUtil; import org.junit.Test; +import java.util.concurrent.atomic.AtomicBoolean; + import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class HttpContentEncoderTest { @@ -387,6 +388,45 @@ public class HttpContentEncoderTest { assertNull(ch.readOutbound()); } + @Test + public void testCleanupThrows() { + HttpContentEncoder encoder = new HttpContentEncoder() { + @Override + protected Result beginEncode(HttpResponse headers, String acceptEncoding) throws Exception { + return new Result("myencoding", new EmbeddedChannel( + new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + ctx.fireExceptionCaught(new EncoderException()); + ctx.fireChannelInactive(); + } + })); + } + }; + + final AtomicBoolean channelInactiveCalled = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(encoder, new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + assertTrue(channelInactiveCalled.compareAndSet(false, true)); + super.channelInactive(ctx); + } + }); + assertTrue(channel.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); + assertTrue(channel.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))); + HttpContent content = new DefaultHttpContent(Unpooled.buffer().writeZero(10)); + assertTrue(channel.writeOutbound(content)); + assertEquals(1, content.refCnt()); + try { + channel.finishAndReleaseAll(); + fail(); + } catch (CodecException expected) { + // expected + } + assertTrue(channelInactiveCalled.get()); + assertEquals(0, content.refCnt()); + } + private static void assertEmptyResponse(EmbeddedChannel ch) { Object o = ch.readOutbound(); assertThat(o, is(instanceOf(HttpResponse.class)));