From e609b5eeb7daba819f530fea500991c9cfc18412 Mon Sep 17 00:00:00 2001 From: Konstantin Lutovich Date: Thu, 28 Feb 2019 21:13:56 +0100 Subject: [PATCH] Close consumed inputs in ChunkedWriteHandler (#8876) Motivation: ChunkedWriteHandler needs to close both successful and failed ChunkInputs. It used to never close successful ones. Modifications: * ChunkedWriteHandler always closes ChunkInput before completing the write promise. * Ensure only ChunkInput#close() is invoked on a failed input. * Ensure no methods are invoked on a closed input. Result: Fixes https://github.com/netty/netty/issues/8875. --- .../handler/stream/ChunkedWriteHandler.java | 40 +-- .../stream/ChunkedWriteHandlerTest.java | 232 +++++++++++++++++- 2 files changed, 256 insertions(+), 16 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java b/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java index f39328dff7..1a1822b597 100644 --- a/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java +++ b/handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java @@ -166,22 +166,28 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler { Object message = currentWrite.msg; if (message instanceof ChunkedInput) { ChunkedInput in = (ChunkedInput) message; + boolean endOfInput; + long inputLength; try { - if (!in.isEndOfInput()) { - if (cause == null) { - cause = new ClosedChannelException(); - } - currentWrite.fail(cause); - } else { - currentWrite.success(in.length()); - } + endOfInput = in.isEndOfInput(); + inputLength = in.length(); closeInput(in); } catch (Exception e) { + closeInput(in); currentWrite.fail(e); if (logger.isWarnEnabled()) { - logger.warn(ChunkedInput.class.getSimpleName() + ".isEndOfInput() failed", e); + logger.warn(ChunkedInput.class.getSimpleName() + " failed", e); } - closeInput(in); + continue; + } + + if (!endOfInput) { + if (cause == null) { + cause = new ClosedChannelException(); + } + currentWrite.fail(cause); + } else { + currentWrite.success(inputLength); } } else { if (cause == null) { @@ -249,8 +255,8 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler { ReferenceCountUtil.release(message); } - currentWrite.fail(t); closeInput(chunks); + currentWrite.fail(t); break; } @@ -283,8 +289,12 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler { closeInput(chunks); currentWrite.fail(future.cause()); } else { - currentWrite.progress(chunks.progress(), chunks.length()); - currentWrite.success(chunks.length()); + // read state of the input in local variables before closing it + long inputProgress = chunks.progress(); + long inputLength = chunks.length(); + closeInput(chunks); + currentWrite.progress(inputProgress, inputLength); + currentWrite.success(inputLength); } } }); @@ -293,7 +303,7 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { - closeInput((ChunkedInput) pendingMessage); + closeInput(chunks); currentWrite.fail(future.cause()); } else { currentWrite.progress(chunks.progress(), chunks.length()); @@ -305,7 +315,7 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { - closeInput((ChunkedInput) pendingMessage); + closeInput(chunks); currentWrite.fail(future.cause()); } else { currentWrite.progress(chunks.progress(), chunks.length()); diff --git a/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java b/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java index 5b03048ba6..be6951d88b 100644 --- a/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java @@ -21,8 +21,8 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; @@ -33,9 +33,11 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.nio.channels.Channels; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import static java.util.concurrent.TimeUnit.*; import static org.junit.Assert.*; public class ChunkedWriteHandlerTest { @@ -433,6 +435,142 @@ public class ChunkedWriteHandlerTest { assertEquals(1, chunks.get()); } + @Test + public void testCloseSuccessfulChunkedInput() { + int chunks = 10; + TestChunkedInput input = new TestChunkedInput(chunks); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + assertTrue(ch.writeOutbound(input)); + + for (int i = 0; i < chunks; i++) { + ByteBuf buf = ch.readOutbound(); + assertEquals(i, buf.readInt()); + buf.release(); + } + + assertTrue(input.isClosed()); + assertFalse(ch.finish()); + } + + @Test + public void testCloseFailedChunkedInput() { + Exception error = new Exception("Unable to produce a chunk"); + ThrowingChunkedInput input = new ThrowingChunkedInput(error); + + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + try { + ch.writeOutbound(input); + fail("Exception expected"); + } catch (Exception e) { + assertEquals(error, e); + } + + assertTrue(input.isClosed()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterSuccessfulChunkedInputClosed() throws Exception { + final TestChunkedInput input = new TestChunkedInput(2); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertTrue(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testWriteListenerInvokedAfterFailedChunkedInputClosed() throws Exception { + final ThrowingChunkedInput input = new ThrowingChunkedInput(new RuntimeException()); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertFalse(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterChannelClosedAndInputFullyConsumed() throws Exception { + // use empty input which has endOfInput = true + final TestChunkedInput input = new TestChunkedInput(0); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.close(); // close channel to make handler discard the input on subsequent flush + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertTrue(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterChannelClosedAndInputNotFullyConsumed() throws Exception { + // use non-empty input which has endOfInput = false + final TestChunkedInput input = new TestChunkedInput(42); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.close(); // close channel to make handler discard the input on subsequent flush + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertFalse(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + private static void check(Object... inputs) { EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); @@ -524,4 +662,96 @@ public class ChunkedWriteHandlerTest { assertEquals(BYTES.length, read); } + + private static final class TestChunkedInput implements ChunkedInput { + private final int chunksToProduce; + + private int chunksProduced; + private volatile boolean closed; + + TestChunkedInput(int chunksToProduce) { + this.chunksToProduce = chunksToProduce; + } + + @Override + public boolean isEndOfInput() { + return chunksProduced >= chunksToProduce; + } + + @Override + public void close() { + closed = true; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) { + ByteBuf buf = allocator.buffer(); + buf.writeInt(chunksProduced); + chunksProduced++; + return buf; + } + + @Override + public long length() { + return chunksToProduce; + } + + @Override + public long progress() { + return chunksProduced; + } + + boolean isClosed() { + return closed; + } + } + + private static final class ThrowingChunkedInput implements ChunkedInput { + private final Exception error; + + private volatile boolean closed; + + ThrowingChunkedInput(Exception error) { + this.error = error; + } + + @Override + public boolean isEndOfInput() { + return false; + } + + @Override + public void close() { + closed = true; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + throw error; + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return -1; + } + + boolean isClosed() { + return closed; + } + } }