diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 9e599c335d..6abf529afc 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -706,20 +706,6 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap @Override public ChannelFuture write(final Object msg, final ChannelPromise promise) { - if (msg == null) { - throw new NullPointerException("msg"); - } - - try { - if (isNotValidPromise(promise, true)) { - ReferenceCountUtil.release(msg); - // cancelled - return promise; - } - } catch (RuntimeException e) { - ReferenceCountUtil.release(msg); - throw e; - } write(msg, false, promise); return promise; @@ -781,18 +767,7 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap @Override public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { - if (msg == null) { - throw new NullPointerException("msg"); - } - - if (isNotValidPromise(promise, true)) { - ReferenceCountUtil.release(msg); - // cancelled - return promise; - } - write(msg, true, promise); - return promise; } @@ -806,6 +781,18 @@ abstract class AbstractChannelHandlerContext extends DefaultAttributeMap } private void write(Object msg, boolean flush, ChannelPromise promise) { + ObjectUtil.checkNotNull(msg, "msg"); + try { + if (isNotValidPromise(promise, true)) { + ReferenceCountUtil.release(msg); + // cancelled + return; + } + } catch (RuntimeException e) { + ReferenceCountUtil.release(msg); + throw e; + } + AbstractChannelHandlerContext next = findContextOutbound(); final Object m = pipeline.touch(msg, next); EventExecutor executor = next.executor(); diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 31de253b04..65209d50cf 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -1223,6 +1223,51 @@ public class DefaultChannelPipelineTest { } } + @Test + public void testWriteThrowsReleaseMessage() { + testWriteThrowsReleaseMessage0(false); + } + + @Test + public void testWriteAndFlushThrowsReleaseMessage() { + testWriteThrowsReleaseMessage0(true); + } + + private void testWriteThrowsReleaseMessage0(boolean flush) { + ReferenceCounted referenceCounted = new AbstractReferenceCounted() { + @Override + protected void deallocate() { + // NOOP + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + }; + assertEquals(1, referenceCounted.refCnt()); + + Channel channel = new LocalChannel(); + Channel channel2 = new LocalChannel(); + group.register(channel).syncUninterruptibly(); + group.register(channel2).syncUninterruptibly(); + + try { + if (flush) { + channel.writeAndFlush(referenceCounted, channel2.newPromise()); + } else { + channel.write(referenceCounted, channel2.newPromise()); + } + fail(); + } catch (IllegalArgumentException expected) { + // expected + } + assertEquals(0, referenceCounted.refCnt()); + + channel.close().syncUninterruptibly(); + channel2.close().syncUninterruptibly(); + } + @Test(timeout = 5000) public void handlerAddedStateUpdatedBeforeHandlerAddedDoneForceEventLoop() throws InterruptedException { handlerAddedStateUpdatedBeforeHandlerAddedDone(true);