diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index 2d78ccf1cc..fb462f6085 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -290,9 +290,66 @@ public class EmbeddedChannel extends AbstractChannel { * @return bufferReadable returns {@code true} if any of the used buffers has something left to read */ public boolean finish() { + return finish(false); + } + + /** + * Mark this {@link Channel} as finished and release all pending message in the inbound and outbound buffer. + * Any futher try to write data to it will fail. + * + * @return bufferReadable returns {@code true} if any of the used buffers has something left to read + */ + public boolean finishAndReleaseAll() { + return finish(true); + } + + /** + * Mark this {@link Channel} as finished. Any futher try to write data to it will fail. + * + * @param releaseAll if {@code true} all pending message in the inbound and outbound buffer are released. + * @return bufferReadable returns {@code true} if any of the used buffers has something left to read + */ + private boolean finish(boolean releaseAll) { close(); - checkException(); - return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages); + try { + checkException(); + return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages); + } finally { + if (releaseAll) { + releaseAll(inboundMessages); + releaseAll(outboundMessages); + } + } + } + + /** + * Release all buffered inbound messages and return {@code true} if any were in the inbound buffer, {@code false} + * otherwise. + */ + public boolean releaseInbound() { + return releaseAll(inboundMessages); + } + + /** + * Release all buffered outbound messages and return {@code true} if any were in the outbound buffer, {@code false} + * otherwise. + */ + public boolean releaseOutbound() { + return releaseAll(outboundMessages); + } + + private static boolean releaseAll(Queue queue) { + if (isNotEmpty(queue)) { + for (;;) { + Object msg = queue.poll(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + return true; + } + return false; } private void finishPendingTasks(boolean cancel) { diff --git a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java index 7c59676278..33b15cfcc0 100644 --- a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java +++ b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java @@ -15,6 +15,8 @@ */ package io.netty.channel.embedded; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; @@ -233,6 +235,95 @@ public class EmbeddedChannelTest { assertNull(handler.pollEvent()); } + @Test + public void testFinishAndReleaseAll() { + ByteBuf in = Unpooled.buffer(); + ByteBuf out = Unpooled.buffer(); + try { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(in)); + assertEquals(1, in.refCnt()); + + assertTrue(channel.writeOutbound(out)); + assertEquals(1, out.refCnt()); + + assertTrue(channel.finishAndReleaseAll()); + assertEquals(0, in.refCnt()); + assertEquals(0, out.refCnt()); + + assertNull(channel.readInbound()); + assertNull(channel.readOutbound()); + } finally { + release(in, out); + } + } + + @Test + public void testReleaseInbound() { + ByteBuf in = Unpooled.buffer(); + ByteBuf out = Unpooled.buffer(); + try { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(in)); + assertEquals(1, in.refCnt()); + + assertTrue(channel.writeOutbound(out)); + assertEquals(1, out.refCnt()); + + assertTrue(channel.releaseInbound()); + assertEquals(0, in.refCnt()); + assertEquals(1, out.refCnt()); + + assertTrue(channel.finish()); + assertNull(channel.readInbound()); + + ByteBuf buffer = channel.readOutbound(); + assertSame(out, buffer); + buffer.release(); + + assertNull(channel.readOutbound()); + } finally { + release(in, out); + } + } + + @Test + public void testReleaseOutbound() { + ByteBuf in = Unpooled.buffer(); + ByteBuf out = Unpooled.buffer(); + try { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(in)); + assertEquals(1, in.refCnt()); + + assertTrue(channel.writeOutbound(out)); + assertEquals(1, out.refCnt()); + + assertTrue(channel.releaseOutbound()); + assertEquals(1, in.refCnt()); + assertEquals(0, out.refCnt()); + + assertTrue(channel.finish()); + assertNull(channel.readOutbound()); + + ByteBuf buffer = channel.readInbound(); + assertSame(in, buffer); + buffer.release(); + + assertNull(channel.readInbound()); + } finally { + release(in, out); + } + } + + private static void release(ByteBuf... buffers) { + for (ByteBuf buffer : buffers) { + if (buffer.refCnt() > 0) { + buffer.release(); + } + } + } + private static final class EventOutboundHandler extends ChannelOutboundHandlerAdapter { static final Integer DISCONNECT = 0; static final Integer CLOSE = 1;