From 71308376caf1d4a5b83e4c9bd7e4f247c85ab9ce Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Fri, 28 Aug 2015 16:26:19 -0700 Subject: [PATCH] LocalChannel write when peer closed leak Motivation: If LocalChannel doWrite executes while the peer's state changes from CONNECTED to CLOSED it is possible that some promise's won't be completed and buffers will be leaked. Modifications: - Check the peer's state in doWrite to avoid a race condition Result: All write operations should release, and the associated promise should be completed. --- .../io/netty/channel/local/LocalChannel.java | 86 ++++++++--- .../netty/channel/local/LocalChannelTest.java | 139 +++++++++++++++++- 2 files changed, 198 insertions(+), 27 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 6d782d9c54..aa5b926ebb 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ -27,8 +27,9 @@ import io.netty.channel.DefaultChannelConfig; import io.netty.channel.EventLoop; import io.netty.channel.SingleThreadEventLoop; import io.netty.util.ReferenceCountUtil; -import io.netty.util.concurrent.SingleThreadEventExecutor; import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.SingleThreadEventExecutor; +import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.InternalThreadLocalMap; import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.PlatformDependent; @@ -50,9 +51,10 @@ public class LocalChannel extends AbstractChannel { private static final AtomicReferenceFieldUpdater FINISH_READ_FUTURE_UPDATER; private static final ChannelMetadata METADATA = new ChannelMetadata(false); private static final int MAX_READER_STACK_DEPTH = 8; + private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); private final ChannelConfig config = new DefaultChannelConfig(this); - // To futher optimize this we could write our own SPSC queue. + // To further optimize this we could write our own SPSC queue. private final Queue inboundBuffer = PlatformDependent.newMpscQueue(); private final Runnable readTask = new Runnable() { @Override @@ -94,6 +96,7 @@ public class LocalChannel extends AbstractChannel { AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture"); } FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater; + CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); } public LocalChannel() { @@ -216,10 +219,6 @@ public class LocalChannel extends AbstractChannel { protected void doClose() throws Exception { final LocalChannel peer = this.peer; if (state <= 2) { - // To preserve ordering of events we must process any pending reads - if (writeInProgress && peer != null) { - finishPeerRead(peer); - } // Update all internal state before the closeFuture is notified. if (localAddress != null) { if (parent() == null) { @@ -227,7 +226,22 @@ public class LocalChannel extends AbstractChannel { } localAddress = null; } + + // State change must happen before finishPeerRead to ensure writes are released either in doWrite or + // channelRead. state = 3; + + ChannelPromise promise = connectPromise; + if (promise != null) { + // Use tryFailure() instead of setFailure() to avoid the race against cancel(). + promise.tryFailure(CLOSED_CHANNEL_EXCEPTION); + connectPromise = null; + } + + // To preserve ordering of events we must process any pending reads + if (writeInProgress && peer != null) { + finishPeerRead(peer); + } } if (peer != null && peer.isActive()) { @@ -239,12 +253,18 @@ public class LocalChannel extends AbstractChannel { } else { // This value may change, and so we should save it before executing the Runnable. final boolean peerWriteInProgress = peer.writeInProgress; - peer.eventLoop().execute(new OneTimeTask() { - @Override - public void run() { - doPeerClose(peer, peerWriteInProgress); - } - }); + try { + peer.eventLoop().execute(new OneTimeTask() { + @Override + public void run() { + doPeerClose(peer, peerWriteInProgress); + } + }); + } catch (RuntimeException e) { + // The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained. + releaseInboundBuffers(); + throw e; + } } this.peer = null; } @@ -293,7 +313,12 @@ public class LocalChannel extends AbstractChannel { threadLocals.setLocalChannelReaderStackDepth(stackDepth); } } else { - eventLoop().execute(readTask); + try { + eventLoop().execute(readTask); + } catch (RuntimeException e) { + releaseInboundBuffers(); + throw e; + } } } @@ -303,7 +328,7 @@ public class LocalChannel extends AbstractChannel { throw new NotYetConnectedException(); } if (state > 2) { - throw new ClosedChannelException(); + throw CLOSED_CHANNEL_EXCEPTION; } final LocalChannel peer = this.peer; @@ -316,8 +341,14 @@ public class LocalChannel extends AbstractChannel { break; } try { - peer.inboundBuffer.add(ReferenceCountUtil.retain(msg)); - in.remove(); + // It is possible the peer could have closed while we are writing, and in this case we should + // simulate real socket behavior and ensure the write operation is failed. + if (peer.state == 2) { + peer.inboundBuffer.add(ReferenceCountUtil.retain(msg)); + in.remove(); + } else { + in.remove(CLOSED_CHANNEL_EXCEPTION); + } } catch (Throwable cause) { in.remove(cause); } @@ -352,10 +383,25 @@ public class LocalChannel extends AbstractChannel { finishPeerRead0(peer); } }; - if (peer.writeInProgress) { - peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask); - } else { - peer.eventLoop().execute(finishPeerReadTask); + try { + if (peer.writeInProgress) { + peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask); + } else { + peer.eventLoop().execute(finishPeerReadTask); + } + } catch (RuntimeException e) { + peer.releaseInboundBuffers(); + throw e; + } + } + + private void releaseInboundBuffers() { + for (;;) { + Object o = inboundBuffer.poll(); + if (o == null) { + break; + } + ReferenceCountUtil.release(o); } } diff --git a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java index ca8558af40..3d7d4a0156 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java @@ -339,8 +339,8 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg.equals(data)) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -408,8 +408,8 @@ public class LocalChannelTest { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { final long count = messageLatch.getCount(); if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -468,8 +468,8 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (data2.equals(msg)) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -485,8 +485,8 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (data.equals(msg)) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -550,8 +550,8 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (data2.equals(msg) && messageLatch.getCount() == 1) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -567,8 +567,8 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (data.equals(msg) && messageLatch.getCount() == 2) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -641,8 +641,8 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg.equals(data)) { - messageLatch.countDown(); ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); } else { super.channelRead(ctx, msg); } @@ -697,6 +697,130 @@ public class LocalChannelTest { } } + @Test + public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final CountDownLatch serverMessageLatch = new CountDownLatch(1); + final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1); + final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1); + final CountDownLatch writeFailLatch = new CountDownLatch(1); + final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]); + final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]); + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + final AtomicReference serverChannelRef = new AtomicReference(); + + try { + cb.group(group1) + .channel(LocalChannel.class) + .handler(new TestHandler()); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data.equals(msg)) { + ReferenceCountUtil.safeRelease(msg); + serverMessageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + serverChannelRef.set(ch); + serverChannelLatch.countDown(); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + assertTrue(serverChannelLatch.await(5, SECONDS)); + + final Channel ccCpy = cc; + final Channel serverChannelCpy = serverChannelRef.get(); + serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch); + ccCpy.closeFuture().addListener(clientChannelCloseLatch); + + // Make sure a write operation is executed in the eventloop + cc.pipeline().lastContext().executor().execute(new OneTimeTask() { + @Override + public void run() { + ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + serverChannelCpy.eventLoop().execute(new OneTimeTask() { + @Override + public void run() { + // The point of this test is to write while the peer is closed, so we should + // ensure the peer is actually closed before we write. + int waitCount = 0; + while (ccCpy.isOpen()) { + try { + Thread.sleep(50); + } catch (InterruptedException ignored) { + // ignored + } + if (++waitCount > 5) { + fail(); + } + } + serverChannelCpy.writeAndFlush(data2.duplicate().retain(), + serverChannelCpy.newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess() && + future.cause() instanceof ClosedChannelException) { + writeFailLatch.countDown(); + } + } + }); + } + }); + ccCpy.close(); + } + }); + } + }); + + assertTrue(serverMessageLatch.await(5, SECONDS)); + assertTrue(writeFailLatch.await(5, SECONDS)); + assertTrue(serverChannelCloseLatch.await(5, SECONDS)); + assertTrue(clientChannelCloseLatch.await(5, SECONDS)); + assertFalse(ccCpy.isOpen()); + assertFalse(serverChannelCpy.isOpen()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } finally { + data.release(); + data2.release(); + } + } + + private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener { + public LatchChannelFutureListener(int count) { + super(count); + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + countDown(); + } + } + private static void closeChannel(Channel cc) { if (cc != null) { cc.close().syncUninterruptibly(); @@ -707,6 +831,7 @@ public class LocalChannelTest { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { logger.info(String.format("Received mesage: %s", msg)); + ReferenceCountUtil.safeRelease(msg); } } }