diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java index 016d1e8d0e..754b665289 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java @@ -191,7 +191,10 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { if (localAddress == null) { throw new NullPointerException("localAddress"); } - validatePromise(ctx, promise, false); + if (!validatePromise(ctx, promise, false)) { + // promise cancelled + return; + } if (executor.inEventLoop()) { invokeBindNow(ctx, localAddress, promise); @@ -212,7 +215,10 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { if (remoteAddress == null) { throw new NullPointerException("remoteAddress"); } - validatePromise(ctx, promise, false); + if (!validatePromise(ctx, promise, false)) { + // promise cancelled + return; + } if (executor.inEventLoop()) { invokeConnectNow(ctx, remoteAddress, localAddress, promise); @@ -228,7 +234,10 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { @Override public void invokeDisconnect(final ChannelHandlerContext ctx, final ChannelPromise promise) { - validatePromise(ctx, promise, false); + if (!validatePromise(ctx, promise, false)) { + // promise cancelled + return; + } if (executor.inEventLoop()) { invokeDisconnectNow(ctx, promise); @@ -244,7 +253,10 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { @Override public void invokeClose(final ChannelHandlerContext ctx, final ChannelPromise promise) { - validatePromise(ctx, promise, false); + if (!validatePromise(ctx, promise, false)) { + // promise cancelled + return; + } if (executor.inEventLoop()) { invokeCloseNow(ctx, promise); @@ -282,8 +294,11 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { if (msg == null) { throw new NullPointerException("msg"); } - - validatePromise(ctx, promise, true); + if (!validatePromise(ctx, promise, true)) { + // promise cancelled + ReferenceCountUtil.release(msg); + return; + } if (executor.inEventLoop()) { invokeWriteNow(ctx, msg, promise); @@ -320,7 +335,8 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { } } - private static void validatePromise(ChannelHandlerContext ctx, ChannelPromise promise, boolean allowVoidPromise) { + private static boolean validatePromise( + ChannelHandlerContext ctx, ChannelPromise promise, boolean allowVoidPromise) { if (ctx == null) { throw new NullPointerException("ctx"); } @@ -330,6 +346,9 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { } if (promise.isDone()) { + if (promise.isCancelled()) { + return false; + } throw new IllegalArgumentException("promise already done: " + promise); } @@ -339,7 +358,7 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { } if (promise.getClass() == DefaultChannelPromise.class) { - return; + return true; } if (!allowVoidPromise && promise instanceof VoidChannelPromise) { @@ -351,6 +370,7 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { throw new IllegalArgumentException( StringUtil.simpleClassName(AbstractChannel.CloseFuture.class) + " not allowed in a pipeline"); } + return true; } private void safeExecuteInbound(Runnable task, Object msg) { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index c1bb346e90..891b5b11f9 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -18,6 +18,8 @@ package io.netty.channel; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; @@ -439,6 +441,73 @@ public class DefaultChannelPipelineTest { }).sync(); } + // Tests for https://github.com/netty/netty/issues/2349 + @Test + public void testCancelBind() throws Exception { + ChannelPipeline pipeline = new LocalChannel(group.next()).pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.bind(new LocalAddress("test"), promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelConnect() throws Exception { + ChannelPipeline pipeline = new LocalChannel(group.next()).pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.connect(new LocalAddress("test"), promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelDisconnect() throws Exception { + ChannelPipeline pipeline = new LocalChannel(group.next()).pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.disconnect(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelClose() throws Exception { + ChannelPipeline pipeline = new LocalChannel(group.next()).pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.close(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelWrite() throws Exception { + ChannelPipeline pipeline = new LocalChannel(group.next()).pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ByteBuf buffer = Unpooled.buffer(); + assertEquals(1, buffer.refCnt()); + ChannelFuture future = pipeline.write(buffer, promise); + assertTrue(future.isCancelled()); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testCancelWriteAndFlush() throws Exception { + ChannelPipeline pipeline = new LocalChannel(group.next()).pipeline(); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ByteBuf buffer = Unpooled.buffer(); + assertEquals(1, buffer.refCnt()); + ChannelFuture future = pipeline.writeAndFlush(buffer, promise); + assertTrue(future.isCancelled()); + assertEquals(0, buffer.refCnt()); + } + private static int next(DefaultChannelHandlerContext ctx) { DefaultChannelHandlerContext next = ctx.next; if (next == null) {