From 652650b0dbbf515d8ddb6322fcd2bade91452869 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 11 Oct 2018 18:46:10 +0200 Subject: [PATCH] Correctly decrement pending bytes when submitting AbstractWriteTask fails. (#8349) Motivation: Currently we may end up in the situation that we incremented the pending bytes before submitting the AbstractWriteTask but never decrement these again if the submitting of the task fails. This may result in incorrect watermark handling. Modifications: - Correctly decrement pending bytes if subimitting of task fails and also ensure we recycle it correctly. - Add unit test. Result: Fixes https://github.com/netty/netty/issues/8343. --- .../AbstractChannelHandlerContext.java | 51 ++++++++--- .../channel/ChannelOutboundBufferTest.java | 85 +++++++++++++++++++ 2 files changed, 122 insertions(+), 14 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 3b155a9e23..0a81bd6b0f 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -816,13 +816,19 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap next.invokeWrite(m, promise); } } else { - AbstractWriteTask task; + final AbstractWriteTask task; if (flush) { task = WriteAndFlushTask.newInstance(next, m, promise); } else { task = WriteTask.newInstance(next, m, promise); } - safeExecute(executor, task, promise, m); + if (!safeExecute(executor, task, promise, m)) { + // We failed to submit the AbstractWriteTask. We need to cancel it so we decrement the pending bytes + // and put it back in the Recycler for re-use later. + // + // See https://github.com/netty/netty/issues/8343. + task.cancel(); + } } } @@ -1002,9 +1008,10 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap return channel().hasAttr(key); } - private static void safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) { + private static boolean safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) { try { executor.execute(runnable); + return true; } catch (Throwable cause) { try { promise.setFailure(cause); @@ -1013,6 +1020,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap ReferenceCountUtil.release(msg); } } + return false; } } @@ -1063,20 +1071,35 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap @Override public final void run() { try { - // Check for null as it may be set to null if the channel is closed already - if (ESTIMATE_TASK_SIZE_ON_SUBMIT) { - ctx.pipeline.decrementPendingOutboundBytes(size); - } + decrementPendingOutboundBytes(); write(ctx, msg, promise); } finally { - // Set to null so the GC can collect them directly - ctx = null; - msg = null; - promise = null; - handle.recycle(this); + recycle(); } } + void cancel() { + try { + decrementPendingOutboundBytes(); + } finally { + recycle(); + } + } + + private void decrementPendingOutboundBytes() { + if (ESTIMATE_TASK_SIZE_ON_SUBMIT) { + ctx.pipeline.decrementPendingOutboundBytes(size); + } + } + + private void recycle() { + // Set to null so the GC can collect them directly + ctx = null; + msg = null; + promise = null; + handle.recycle(this); + } + protected void write(AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { ctx.invokeWrite(msg, promise); } @@ -1091,7 +1114,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap } }; - private static WriteTask newInstance( + static WriteTask newInstance( AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteTask task = RECYCLER.get(); init(task, ctx, msg, promise); @@ -1112,7 +1135,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap } }; - private static WriteAndFlushTask newInstance( + static WriteAndFlushTask newInstance( AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteAndFlushTask task = RECYCLER.get(); init(task, ctx, msg, promise); diff --git a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java index 4702490916..955de75395 100644 --- a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java @@ -19,10 +19,16 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.SingleThreadEventExecutor; import org.junit.Test; import java.net.SocketAddress; import java.nio.ByteBuffer; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; import static io.netty.buffer.Unpooled.*; import static org.hamcrest.Matchers.*; @@ -355,6 +361,85 @@ public class ChannelOutboundBufferTest { safeClose(ch); } + @Test(timeout = 5000) + public void testWriteTaskRejected() throws Exception { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor( + null, new DefaultThreadFactory("executorPool"), + true, 1, RejectedExecutionHandlers.reject()) { + @Override + protected void run() { + do { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + } while (!confirmShutdown()); + } + + @Override + protected Queue newTaskQueue(int maxPendingTasks) { + return super.newTaskQueue(1); + } + }; + final CountDownLatch handlerAddedLatch = new CountDownLatch(1); + EmbeddedChannel ch = new EmbeddedChannel(); + ch.pipeline().addLast(executor, new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + promise.setFailure(new AssertionError("Should not be called")); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + handlerAddedLatch.countDown(); + } + }); + + // Lets wait until we are sure the handler was added. + handlerAddedLatch.await(); + + final CountDownLatch executeLatch = new CountDownLatch(1); + final CountDownLatch runLatch = new CountDownLatch(1); + executor.execute(new Runnable() { + @Override + public void run() { + try { + runLatch.countDown(); + executeLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + + runLatch.await(); + + executor.execute(new Runnable() { + @Override + public void run() { + // Will not be executed but ensure the pending count is 1. + } + }); + + assertEquals(1, executor.pendingTasks()); + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + + ByteBuf buffer = buffer(128).writeZero(128); + ChannelFuture future = ch.write(buffer); + ch.runPendingTasks(); + + assertTrue(future.cause() instanceof RejectedExecutionException); + assertEquals(0, buffer.refCnt()); + + // In case of rejected task we should not have anything pending. + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + executeLatch.countDown(); + + safeClose(ch); + executor.shutdownGracefully(); + } + private static void safeClose(EmbeddedChannel ch) { ch.finish(); for (;;) {