diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java index 2ee1282455..31a20c4f02 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java @@ -437,7 +437,10 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements if (localAddress == null) { throw new NullPointerException("localAddress"); } - validatePromise(promise, false); + if (!validatePromise(promise, false)) { + // cancelled + return promise; + } final DefaultChannelHandlerContext next = findContextOutbound(); EventExecutor executor = next.executor(); @@ -475,7 +478,10 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements if (remoteAddress == null) { throw new NullPointerException("remoteAddress"); } - validatePromise(promise, false); + if (!validatePromise(promise, false)) { + // cancelled + return promise; + } final DefaultChannelHandlerContext next = findContextOutbound(); EventExecutor executor = next.executor(); @@ -503,7 +509,10 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements @Override public ChannelFuture disconnect(final ChannelPromise promise) { - validatePromise(promise, false); + if (!validatePromise(promise, false)) { + // cancelled + return promise; + } final DefaultChannelHandlerContext next = findContextOutbound(); EventExecutor executor = next.executor(); @@ -541,7 +550,10 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements @Override public ChannelFuture close(final ChannelPromise promise) { - validatePromise(promise, false); + if (!validatePromise(promise, false)) { + // cancelled + return promise; + } final DefaultChannelHandlerContext next = findContextOutbound(); EventExecutor executor = next.executor(); @@ -569,7 +581,10 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements @Override public ChannelFuture deregister(final ChannelPromise promise) { - validatePromise(promise, false); + if (!validatePromise(promise, false)) { + // cancelled + return promise; + } final DefaultChannelHandlerContext next = findContextOutbound(); EventExecutor executor = next.executor(); @@ -636,8 +651,11 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements throw new NullPointerException("msg"); } - validatePromise(promise, true); - + if (!validatePromise(promise, true)) { + ReferenceCountUtil.release(msg); + // cancelled + return promise; + } write(msg, false, promise); return promise; @@ -687,7 +705,11 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements throw new NullPointerException("msg"); } - validatePromise(promise, true); + if (!validatePromise(promise, true)) { + ReferenceCountUtil.release(msg); + // cancelled + return promise; + } write(msg, true, promise); @@ -798,12 +820,19 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements return new FailedChannelFuture(channel(), executor(), cause); } - private void validatePromise(ChannelPromise promise, boolean allowVoidPromise) { + private boolean validatePromise(ChannelPromise promise, boolean allowVoidPromise) { if (promise == null) { throw new NullPointerException("promise"); } if (promise.isDone()) { + // Check if the promise was cancelled and if so signal that the processing of the operation + // should not be performed. + // + // See https://github.com/netty/netty/issues/2349 + if (promise.isCancelled()) { + return false; + } throw new IllegalArgumentException("promise already done: " + promise); } @@ -813,7 +842,7 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements } if (promise.getClass() == DefaultChannelPromise.class) { - return; + return true; } if (!allowVoidPromise && promise instanceof VoidChannelPromise) { @@ -825,6 +854,7 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements throw new IllegalArgumentException( StringUtil.simpleClassName(AbstractChannel.CloseFuture.class) + " not allowed in a pipeline"); } + return true; } private DefaultChannelHandlerContext findContextInbound() { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 37ccddb567..35bd97d1d7 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -18,6 +18,8 @@ package io.netty.channel; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; @@ -434,6 +436,81 @@ public class DefaultChannelPipelineTest { }).sync(); } + // Tests for https://github.com/netty/netty/issues/2349 + @Test + public void testCancelBind() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.bind(new LocalAddress("test"), promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelConnect() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.connect(new LocalAddress("test"), promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelDisconnect() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.disconnect(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelClose() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.close(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelDeregister() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.deregister(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelWrite() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ByteBuf buffer = Unpooled.buffer(); + assertEquals(1, buffer.refCnt()); + ChannelFuture future = pipeline.write(buffer, promise); + assertTrue(future.isCancelled()); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testCancelWriteAndFlush() throws Exception { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ByteBuf buffer = Unpooled.buffer(); + assertEquals(1, buffer.refCnt()); + ChannelFuture future = pipeline.writeAndFlush(buffer, promise); + assertTrue(future.isCancelled()); + assertEquals(0, buffer.refCnt()); + } + private static int next(DefaultChannelHandlerContext ctx) { DefaultChannelHandlerContext next = ctx.next; if (next == null) {