diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java index 624b405b4b..61a9f13f0f 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java @@ -328,10 +328,14 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { if (msg == null) { throw new NullPointerException("msg"); } - if (!validatePromise(ctx, promise, true)) { - // promise cancelled + try { + if (!validatePromise(ctx, promise, true)) { + ReferenceCountUtil.release(msg); + return; + } + } catch (RuntimeException e) { ReferenceCountUtil.release(msg); - return; + throw e; } if (executor.inEventLoop()) { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelHandlerInvokerTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelHandlerInvokerTest.java new file mode 100644 index 0000000000..ac2761a9fd --- /dev/null +++ b/transport/src/test/java/io/netty/channel/DefaultChannelHandlerInvokerTest.java @@ -0,0 +1,56 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DefaultChannelHandlerInvokerTest { + @Mock + private ReferenceCounted msg; + @Mock + private ChannelHandlerContext ctx; + @Mock + private ChannelPromise promise; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @Test(expected = IllegalArgumentException.class) + public void writeWithInvalidPromiseStillReleasesMessage() { + when(promise.isDone()).thenReturn(true); + DefaultChannelHandlerInvoker invoker = new DefaultChannelHandlerInvoker(ImmediateEventExecutor.INSTANCE); + invoker.invokeWrite(ctx, msg, promise); + verify(msg).release(); + } + + @Test(expected = NullPointerException.class) + public void writeWithNullPromiseStillReleasesMessage() { + when(promise.isDone()).thenReturn(true); + DefaultChannelHandlerInvoker invoker = new DefaultChannelHandlerInvoker(ImmediateEventExecutor.INSTANCE); + invoker.invokeWrite(ctx, msg, null); + verify(msg).release(); + } +}