diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 3b155a9e23..0a81bd6b0f 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -816,13 +816,19 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap next.invokeWrite(m, promise); } } else { - AbstractWriteTask task; + final AbstractWriteTask task; if (flush) { task = WriteAndFlushTask.newInstance(next, m, promise); } else { task = WriteTask.newInstance(next, m, promise); } - safeExecute(executor, task, promise, m); + if (!safeExecute(executor, task, promise, m)) { + // We failed to submit the AbstractWriteTask. We need to cancel it so we decrement the pending bytes + // and put it back in the Recycler for re-use later. + // + // See https://github.com/netty/netty/issues/8343. + task.cancel(); + } } } @@ -1002,9 +1008,10 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap return channel().hasAttr(key); } - private static void safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) { + private static boolean safeExecute(EventExecutor executor, Runnable runnable, ChannelPromise promise, Object msg) { try { executor.execute(runnable); + return true; } catch (Throwable cause) { try { promise.setFailure(cause); @@ -1013,6 +1020,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap ReferenceCountUtil.release(msg); } } + return false; } } @@ -1063,20 +1071,35 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap @Override public final void run() { try { - // Check for null as it may be set to null if the channel is closed already - if (ESTIMATE_TASK_SIZE_ON_SUBMIT) { - ctx.pipeline.decrementPendingOutboundBytes(size); - } + decrementPendingOutboundBytes(); write(ctx, msg, promise); } finally { - // Set to null so the GC can collect them directly - ctx = null; - msg = null; - promise = null; - handle.recycle(this); + recycle(); } } + void cancel() { + try { + decrementPendingOutboundBytes(); + } finally { + recycle(); + } + } + + private void decrementPendingOutboundBytes() { + if (ESTIMATE_TASK_SIZE_ON_SUBMIT) { + ctx.pipeline.decrementPendingOutboundBytes(size); + } + } + + private void recycle() { + // Set to null so the GC can collect them directly + ctx = null; + msg = null; + promise = null; + handle.recycle(this); + } + protected void write(AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { ctx.invokeWrite(msg, promise); } @@ -1091,7 +1114,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap } }; - private static WriteTask newInstance( + static WriteTask newInstance( AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteTask task = RECYCLER.get(); init(task, ctx, msg, promise); @@ -1112,7 +1135,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap } }; - private static WriteAndFlushTask newInstance( + static WriteAndFlushTask newInstance( AbstractChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteAndFlushTask task = RECYCLER.get(); init(task, ctx, msg, promise); diff --git a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java index 4702490916..955de75395 100644 --- a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java @@ -19,10 +19,16 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.SingleThreadEventExecutor; import org.junit.Test; import java.net.SocketAddress; import java.nio.ByteBuffer; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; import static io.netty.buffer.Unpooled.*; import static org.hamcrest.Matchers.*; @@ -355,6 +361,85 @@ public class ChannelOutboundBufferTest { safeClose(ch); } + @Test(timeout = 5000) + public void testWriteTaskRejected() throws Exception { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor( + null, new DefaultThreadFactory("executorPool"), + true, 1, RejectedExecutionHandlers.reject()) { + @Override + protected void run() { + do { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + } while (!confirmShutdown()); + } + + @Override + protected Queue newTaskQueue(int maxPendingTasks) { + return super.newTaskQueue(1); + } + }; + final CountDownLatch handlerAddedLatch = new CountDownLatch(1); + EmbeddedChannel ch = new EmbeddedChannel(); + ch.pipeline().addLast(executor, new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + promise.setFailure(new AssertionError("Should not be called")); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + handlerAddedLatch.countDown(); + } + }); + + // Lets wait until we are sure the handler was added. + handlerAddedLatch.await(); + + final CountDownLatch executeLatch = new CountDownLatch(1); + final CountDownLatch runLatch = new CountDownLatch(1); + executor.execute(new Runnable() { + @Override + public void run() { + try { + runLatch.countDown(); + executeLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + + runLatch.await(); + + executor.execute(new Runnable() { + @Override + public void run() { + // Will not be executed but ensure the pending count is 1. + } + }); + + assertEquals(1, executor.pendingTasks()); + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + + ByteBuf buffer = buffer(128).writeZero(128); + ChannelFuture future = ch.write(buffer); + ch.runPendingTasks(); + + assertTrue(future.cause() instanceof RejectedExecutionException); + assertEquals(0, buffer.refCnt()); + + // In case of rejected task we should not have anything pending. + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + executeLatch.countDown(); + + safeClose(ch); + executor.shutdownGracefully(); + } + private static void safeClose(EmbeddedChannel ch) { ch.finish(); for (;;) {