Close consumed inputs in ChunkedWriteHandler (#8876)

Motivation:

ChunkedWriteHandler needs to close both successful and failed
ChunkInputs. It used to never close successful ones.

Modifications:

* ChunkedWriteHandler always closes ChunkInput before completing
the write promise. 
* Ensure only ChunkInput#close() is invoked
on a failed input.
* Ensure no methods are invoked on a closed input.

Result:

Fixes https://github.com/netty/netty/issues/8875.
This commit is contained in:
Konstantin Lutovich 2019-02-28 21:13:56 +01:00 committed by Norman Maurer
parent 0811409ca3
commit e609b5eeb7
2 changed files with 256 additions and 16 deletions

View File

@ -166,22 +166,28 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler {
Object message = currentWrite.msg;
if (message instanceof ChunkedInput) {
ChunkedInput<?> in = (ChunkedInput<?>) message;
boolean endOfInput;
long inputLength;
try {
if (!in.isEndOfInput()) {
if (cause == null) {
cause = new ClosedChannelException();
}
currentWrite.fail(cause);
} else {
currentWrite.success(in.length());
}
endOfInput = in.isEndOfInput();
inputLength = in.length();
closeInput(in);
} catch (Exception e) {
closeInput(in);
currentWrite.fail(e);
if (logger.isWarnEnabled()) {
logger.warn(ChunkedInput.class.getSimpleName() + ".isEndOfInput() failed", e);
logger.warn(ChunkedInput.class.getSimpleName() + " failed", e);
}
closeInput(in);
continue;
}
if (!endOfInput) {
if (cause == null) {
cause = new ClosedChannelException();
}
currentWrite.fail(cause);
} else {
currentWrite.success(inputLength);
}
} else {
if (cause == null) {
@ -249,8 +255,8 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler {
ReferenceCountUtil.release(message);
}
currentWrite.fail(t);
closeInput(chunks);
currentWrite.fail(t);
break;
}
@ -283,8 +289,12 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler {
closeInput(chunks);
currentWrite.fail(future.cause());
} else {
currentWrite.progress(chunks.progress(), chunks.length());
currentWrite.success(chunks.length());
// read state of the input in local variables before closing it
long inputProgress = chunks.progress();
long inputLength = chunks.length();
closeInput(chunks);
currentWrite.progress(inputProgress, inputLength);
currentWrite.success(inputLength);
}
}
});
@ -293,7 +303,7 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
closeInput((ChunkedInput<?>) pendingMessage);
closeInput(chunks);
currentWrite.fail(future.cause());
} else {
currentWrite.progress(chunks.progress(), chunks.length());
@ -305,7 +315,7 @@ public class ChunkedWriteHandler extends ChannelDuplexHandler {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
closeInput((ChunkedInput<?>) pendingMessage);
closeInput(chunks);
currentWrite.fail(future.cause());
} else {
currentWrite.progress(chunks.progress(), chunks.length());

View File

@ -21,8 +21,8 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
@ -33,9 +33,11 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static java.util.concurrent.TimeUnit.*;
import static org.junit.Assert.*;
public class ChunkedWriteHandlerTest {
@ -433,6 +435,142 @@ public class ChunkedWriteHandlerTest {
assertEquals(1, chunks.get());
}
@Test
public void testCloseSuccessfulChunkedInput() {
int chunks = 10;
TestChunkedInput input = new TestChunkedInput(chunks);
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
assertTrue(ch.writeOutbound(input));
for (int i = 0; i < chunks; i++) {
ByteBuf buf = ch.readOutbound();
assertEquals(i, buf.readInt());
buf.release();
}
assertTrue(input.isClosed());
assertFalse(ch.finish());
}
@Test
public void testCloseFailedChunkedInput() {
Exception error = new Exception("Unable to produce a chunk");
ThrowingChunkedInput input = new ThrowingChunkedInput(error);
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
try {
ch.writeOutbound(input);
fail("Exception expected");
} catch (Exception e) {
assertEquals(error, e);
}
assertTrue(input.isClosed());
assertFalse(ch.finish());
}
@Test
public void testWriteListenerInvokedAfterSuccessfulChunkedInputClosed() throws Exception {
final TestChunkedInput input = new TestChunkedInput(2);
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean();
final CountDownLatch listenerInvoked = new CountDownLatch(1);
ChannelFuture writeFuture = ch.write(input);
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
inputClosedWhenListenerInvoked.set(input.isClosed());
listenerInvoked.countDown();
}
});
ch.flush();
assertTrue(listenerInvoked.await(10, SECONDS));
assertTrue(writeFuture.isSuccess());
assertTrue(inputClosedWhenListenerInvoked.get());
assertTrue(ch.finishAndReleaseAll());
}
@Test
public void testWriteListenerInvokedAfterFailedChunkedInputClosed() throws Exception {
final ThrowingChunkedInput input = new ThrowingChunkedInput(new RuntimeException());
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean();
final CountDownLatch listenerInvoked = new CountDownLatch(1);
ChannelFuture writeFuture = ch.write(input);
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
inputClosedWhenListenerInvoked.set(input.isClosed());
listenerInvoked.countDown();
}
});
ch.flush();
assertTrue(listenerInvoked.await(10, SECONDS));
assertFalse(writeFuture.isSuccess());
assertTrue(inputClosedWhenListenerInvoked.get());
assertFalse(ch.finish());
}
@Test
public void testWriteListenerInvokedAfterChannelClosedAndInputFullyConsumed() throws Exception {
// use empty input which has endOfInput = true
final TestChunkedInput input = new TestChunkedInput(0);
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean();
final CountDownLatch listenerInvoked = new CountDownLatch(1);
ChannelFuture writeFuture = ch.write(input);
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
inputClosedWhenListenerInvoked.set(input.isClosed());
listenerInvoked.countDown();
}
});
ch.close(); // close channel to make handler discard the input on subsequent flush
ch.flush();
assertTrue(listenerInvoked.await(10, SECONDS));
assertTrue(writeFuture.isSuccess());
assertTrue(inputClosedWhenListenerInvoked.get());
assertFalse(ch.finish());
}
@Test
public void testWriteListenerInvokedAfterChannelClosedAndInputNotFullyConsumed() throws Exception {
// use non-empty input which has endOfInput = false
final TestChunkedInput input = new TestChunkedInput(42);
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean();
final CountDownLatch listenerInvoked = new CountDownLatch(1);
ChannelFuture writeFuture = ch.write(input);
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
inputClosedWhenListenerInvoked.set(input.isClosed());
listenerInvoked.countDown();
}
});
ch.close(); // close channel to make handler discard the input on subsequent flush
ch.flush();
assertTrue(listenerInvoked.await(10, SECONDS));
assertFalse(writeFuture.isSuccess());
assertTrue(inputClosedWhenListenerInvoked.get());
assertFalse(ch.finish());
}
private static void check(Object... inputs) {
EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler());
@ -524,4 +662,96 @@ public class ChunkedWriteHandlerTest {
assertEquals(BYTES.length, read);
}
private static final class TestChunkedInput implements ChunkedInput<ByteBuf> {
private final int chunksToProduce;
private int chunksProduced;
private volatile boolean closed;
TestChunkedInput(int chunksToProduce) {
this.chunksToProduce = chunksToProduce;
}
@Override
public boolean isEndOfInput() {
return chunksProduced >= chunksToProduce;
}
@Override
public void close() {
closed = true;
}
@Override
public ByteBuf readChunk(ChannelHandlerContext ctx) {
return readChunk(ctx.alloc());
}
@Override
public ByteBuf readChunk(ByteBufAllocator allocator) {
ByteBuf buf = allocator.buffer();
buf.writeInt(chunksProduced);
chunksProduced++;
return buf;
}
@Override
public long length() {
return chunksToProduce;
}
@Override
public long progress() {
return chunksProduced;
}
boolean isClosed() {
return closed;
}
}
private static final class ThrowingChunkedInput implements ChunkedInput<ByteBuf> {
private final Exception error;
private volatile boolean closed;
ThrowingChunkedInput(Exception error) {
this.error = error;
}
@Override
public boolean isEndOfInput() {
return false;
}
@Override
public void close() {
closed = true;
}
@Override
public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
return readChunk(ctx.alloc());
}
@Override
public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
throw error;
}
@Override
public long length() {
return -1;
}
@Override
public long progress() {
return -1;
}
boolean isClosed() {
return closed;
}
}
}