diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java
index ef1d6ccef9..655dcb4a23 100644
--- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java
+++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java
@@ -79,6 +79,7 @@ public class LocalChannel extends AbstractChannel {
     private volatile ChannelPromise connectPromise;
     private volatile boolean readInProgress;
     private volatile boolean registerInProgress;
+    private volatile boolean writeInProgress;

     public LocalChannel() {
         super(null);
@@ -171,7 +172,7 @@ public class LocalChannel extends AbstractChannel {
             // This ensures that if both channels are on the same event loop, the peer's channelActive
             // event is triggered *after* this channel's channelRegistered event, so that this channel's
             // pipeline is fully initialized by ChannelInitializer before any channelRead events.
-            peer.eventLoop().execute(new Runnable() {
+            peer.eventLoop().execute(new OneTimeTask() {
                 @Override
                 public void run() {
                     registerInProgress = false;
@@ -198,7 +199,12 @@ public class LocalChannel extends AbstractChannel {

     @Override
     protected void doClose() throws Exception {
+        final LocalChannel peer = this.peer;
         if (state <= 2) {
+            // To preserve the ordering of events, we must process any pending reads.
+            if (writeInProgress && peer != null) {
+                finishPeerRead(peer);
+            }
             // Update all internal state before the closeFuture is notified.
             if (localAddress != null) {
                 if (parent() == null) {
@@ -209,23 +215,19 @@ public class LocalChannel extends AbstractChannel {
             state = 3;
         }

-        final LocalChannel peer = this.peer;
         if (peer != null && peer.isActive()) {
-            // Need to execute the close in the correct EventLoop
-            // See https://github.com/netty/netty/issues/1777
-            EventLoop eventLoop = peer.eventLoop();
-
+            // Need to execute the close in the correct EventLoop (see https://github.com/netty/netty/issues/1777).
             // Also check if the registration was not done yet. In this case we submit the close to the EventLoop
-            // to make sure it is run after the registration completes.
-            //
-            // See https://github.com/netty/netty/issues/2144
-            if (eventLoop.inEventLoop() && !registerInProgress) {
-                peer.unsafe().close(unsafe().voidPromise());
+            // to make sure it's run after the registration completes (see https://github.com/netty/netty/issues/2144).
+            if (peer.eventLoop().inEventLoop() && !registerInProgress) {
+                doPeerClose(peer, peer.writeInProgress);
             } else {
-                peer.eventLoop().execute(new Runnable() {
+                // This value may change, and so we should save it before executing the Runnable.
+                final boolean peerWriteInProgress = peer.writeInProgress;
+                peer.eventLoop().execute(new OneTimeTask() {
                     @Override
                     public void run() {
-                        peer.unsafe().close(unsafe().voidPromise());
+                        doPeerClose(peer, peerWriteInProgress);
                     }
                 });
             }
@@ -233,6 +235,13 @@ public class LocalChannel extends AbstractChannel {
         }
     }

+    private void doPeerClose(LocalChannel peer, boolean peerWriteInProgress) {
+        if (peerWriteInProgress) {
+            finishPeerRead0(this);
+        }
+        peer.unsafe().close(peer.unsafe().voidPromise());
+    }
+
     @Override
     protected void doDeregister() throws Exception {
         // Just remove the shutdownHook as this Channel may be closed later or registered to another EventLoop
@@ -283,35 +292,49 @@ public class LocalChannel extends AbstractChannel {
         }

         final LocalChannel peer = this.peer;
-        final ChannelPipeline peerPipeline = peer.pipeline();
-        final EventLoop peerLoop = peer.eventLoop();

-        for (;;) {
-            Object msg = in.current();
-            if (msg == null) {
-                break;
-            }
-            try {
-                peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
-                in.remove();
-            } catch (Throwable cause) {
-                in.remove(cause);
+        writeInProgress = true;
+        try {
+            for (;;) {
+                Object msg = in.current();
+                if (msg == null) {
+                    break;
+                }
+                try {
+                    peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
+                    in.remove();
+                } catch (Throwable cause) {
+                    in.remove(cause);
+                }
             }
+        } finally {
+            // The following situation may cause trouble:
+            // 1. Write (with promise X)
+            // 2. promise X is completed when in.remove() is called, and a listener on this promise calls close()
+            // 3. Then the close event will be executed for the peer before the write events, even though the write
+            //    events actually happened before the close event.
+            writeInProgress = false;
         }

-        if (peerLoop == eventLoop()) {
-            finishPeerRead(peer, peerPipeline);
+        finishPeerRead(peer);
+    }
+
+    private void finishPeerRead(final LocalChannel peer) {
+        // If the peer is also writing, then we must schedule the event on the event loop to preserve read order.
+        if (peer.eventLoop() == eventLoop() && !peer.writeInProgress) {
+            finishPeerRead0(peer);
         } else {
-            peerLoop.execute(new OneTimeTask() {
+            peer.eventLoop().execute(new OneTimeTask() {
                 @Override
                 public void run() {
-                    finishPeerRead(peer, peerPipeline);
+                    finishPeerRead0(peer);
                 }
             });
         }
     }

-    private static void finishPeerRead(LocalChannel peer, ChannelPipeline peerPipeline) {
+    private static void finishPeerRead0(LocalChannel peer) {
+        ChannelPipeline peerPipeline = peer.pipeline();
         if (peer.readInProgress) {
             peer.readInProgress = false;
             for (;;) {
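Reviewer sketch (illustrative, not part of the patch): the hazard described in the finally-block comment above is the write-then-close-in-listener pattern. Assuming ch is a connected LocalChannel and msg a retained message (both names are hypothetical), the listener fires from in.remove() while doWrite() is still draining the outbound buffer, i.e. while writeInProgress is true; without the guard, doClose() could deliver channelInactive to the peer before the peer drains the read that was just queued.

    ChannelPromise promise = ch.newPromise();
    promise.addListener(new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) {
            // Invoked from in.remove() inside doWrite(): writeInProgress is still
            // true here, so doClose() will first flush the peer's pending reads.
            ch.close();
        }
    });
    ch.writeAndFlush(msg, promise);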
diff --git a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java
index d72d9d3197..65d8777819 100644
--- a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java
+++ b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java
@@ -17,27 +17,42 @@ package io.netty.channel.local;

 import io.netty.bootstrap.Bootstrap;
 import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.AbstractChannel;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPromise;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.SimpleChannelInboundHandler;
 import io.netty.channel.SingleThreadEventLoop;
+import io.netty.util.ReferenceCountUtil;
 import io.netty.util.concurrent.EventExecutor;
+import io.netty.util.concurrent.Future;
+import io.netty.util.internal.OneTimeTask;
 import io.netty.util.internal.logging.InternalLogger;
 import io.netty.util.internal.logging.InternalLoggerFactory;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 import org.junit.Test;

 import java.nio.channels.ClosedChannelException;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ThreadFactory;
-import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;

-import static org.hamcrest.CoreMatchers.*;
-import static org.junit.Assert.*;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;

 public class LocalChannelTest {
@@ -45,20 +60,39 @@ public class LocalChannelTest {

     private static final String LOCAL_ADDR_ID = "test.id";

+    private static EventLoopGroup group1;
+    private static EventLoopGroup group2;
+    private static EventLoopGroup sharedGroup;
+
+    @BeforeClass
+    public static void beforeClass() {
+        group1 = new LocalEventLoopGroup(2);
+        group2 = new LocalEventLoopGroup(2);
+        sharedGroup = new LocalEventLoopGroup(1);
+    }
+
+    @AfterClass
+    public static void afterClass() throws InterruptedException {
+        Future<?> group1Future = group1.shutdownGracefully(0, 0, SECONDS);
+        Future<?> group2Future = group2.shutdownGracefully(0, 0, SECONDS);
+        Future<?> sharedGroupFuture = sharedGroup.shutdownGracefully(0, 0, SECONDS);
+        group1Future.await();
+        group2Future.await();
+        sharedGroupFuture.await();
+    }
+
     @Test
     public void testLocalAddressReuse() throws Exception {
         for (int i = 0; i < 2; i ++) {
-            EventLoopGroup clientGroup = new LocalEventLoopGroup();
-            EventLoopGroup serverGroup = new LocalEventLoopGroup();
             LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
             Bootstrap cb = new Bootstrap();
             ServerBootstrap sb = new ServerBootstrap();

-            cb.group(clientGroup)
+            cb.group(group1)
                     .channel(LocalChannel.class)
                     .handler(new TestHandler());

-            sb.group(serverGroup)
+            sb.group(group2)
                     .channel(LocalServerChannel.class)
                     .childHandler(new ChannelInitializer<LocalChannel>() {
                         @Override
@@ -67,52 +101,52 @@ public class LocalChannelTest {
                 }
             });

-            // Start server
-            Channel sc = sb.bind(addr).sync().channel();
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).sync().channel();

-            final CountDownLatch latch = new CountDownLatch(1);
-            // Connect to the server
-            final Channel cc = cb.connect(addr).sync().channel();
-            cc.eventLoop().execute(new Runnable() {
-                @Override
-                public void run() {
-                    // Send a message event up the pipeline.
-                    cc.pipeline().fireChannelRead("Hello, World");
-                    latch.countDown();
-                }
-            });
-            latch.await();
+                final CountDownLatch latch = new CountDownLatch(1);
+                // Connect to the server
+                cc = cb.connect(addr).sync().channel();
+                final Channel ccCpy = cc;
+                cc.eventLoop().execute(new Runnable() {
+                    @Override
+                    public void run() {
+                        // Send a message event up the pipeline.
+                        ccCpy.pipeline().fireChannelRead("Hello, World");
+                        latch.countDown();
+                    }
+                });
+                assertTrue(latch.await(5, SECONDS));

-            // Close the channel
-            cc.close().sync();
+                // Close the channel
+                closeChannel(cc);
+                closeChannel(sc);
+                sc.closeFuture().sync();

-            serverGroup.shutdownGracefully();
-            clientGroup.shutdownGracefully();
-
-            sc.closeFuture().sync();
-
-            assertNull(String.format(
-                    "Expected null, got channel '%s' for local address '%s'",
-                    LocalChannelRegistry.get(addr), addr), LocalChannelRegistry.get(addr));
-
-            serverGroup.terminationFuture().sync();
-            clientGroup.terminationFuture().sync();
+                assertNull(String.format(
+                        "Expected null, got channel '%s' for local address '%s'",
+                        LocalChannelRegistry.get(addr), addr), LocalChannelRegistry.get(addr));
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
         }
     }

     @Test
     public void testWriteFailsFastOnClosedChannel() throws Exception {
-        EventLoopGroup clientGroup = new LocalEventLoopGroup();
-        EventLoopGroup serverGroup = new LocalEventLoopGroup();
         LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
         Bootstrap cb = new Bootstrap();
         ServerBootstrap sb = new ServerBootstrap();

-        cb.group(clientGroup)
+        cb.group(group1)
                 .channel(LocalChannel.class)
                 .handler(new TestHandler());

-        sb.group(serverGroup)
+        sb.group(group2)
                 .channel(LocalServerChannel.class)
                 .childHandler(new ChannelInitializer<LocalChannel>() {
                     @Override
@@ -121,41 +155,43 @@ public class LocalChannelTest {
             }
         });

-        // Start server
-        sb.bind(addr).sync();
-
-        // Connect to the server
-        final Channel cc = cb.connect(addr).sync().channel();
-
-        // Close the channel and write something.
-        cc.close().sync();
+        Channel sc = null;
+        Channel cc = null;
         try {
-            cc.writeAndFlush(new Object()).sync();
-            fail("must raise a ClosedChannelException");
-        } catch (Exception e) {
-            assertThat(e, is(instanceOf(ClosedChannelException.class)));
-            // Ensure that the actual write attempt on a closed channel was never made by asserting that
-            // the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations.
-            if (e.getStackTrace().length > 0) {
-                assertThat(
-                        e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() + "$AbstractUnsafe"));
-                e.printStackTrace();
-            }
-        }
+            // Start server
+            sc = sb.bind(addr).sync().channel();

-        serverGroup.shutdownGracefully();
-        clientGroup.shutdownGracefully();
-        serverGroup.terminationFuture().sync();
-        clientGroup.terminationFuture().sync();
+            // Connect to the server
+            cc = cb.connect(addr).sync().channel();
+
+            // Close the channel and write something.
+            cc.close().sync();
+            try {
+                cc.writeAndFlush(new Object()).sync();
+                fail("must raise a ClosedChannelException");
+            } catch (Exception e) {
+                assertThat(e, is(instanceOf(ClosedChannelException.class)));
+                // Ensure that the actual write attempt on a closed channel was never made by asserting that
+                // the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations.
+                if (e.getStackTrace().length > 0) {
+                    assertThat(
+                            e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() +
+                            "$AbstractUnsafe"));
+                    e.printStackTrace();
+                }
+            }
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+        }
     }

     @Test
     public void testServerCloseChannelSameEventLoop() throws Exception {
         LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
-        LocalEventLoopGroup group = new LocalEventLoopGroup(1);
         final CountDownLatch latch = new CountDownLatch(1);
         ServerBootstrap sb = new ServerBootstrap()
-                .group(group)
+                .group(group2)
                 .channel(LocalServerChannel.class)
                 .childHandler(new SimpleChannelInboundHandler<Object>() {
                     @Override
@@ -164,29 +200,33 @@ public class LocalChannelTest {
                         latch.countDown();
                     }
                 });
-        sb.bind(addr).sync();
+        Channel sc = null;
+        Channel cc = null;
+        try {
+            sc = sb.bind(addr).sync().channel();

-        Bootstrap b = new Bootstrap()
-                .group(group)
-                .channel(LocalChannel.class)
-                .handler(new SimpleChannelInboundHandler<Object>() {
-                    @Override
-                    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
-                        // discard
-                    }
-                });
-        Channel channel = b.connect(addr).sync().channel();
-        channel.writeAndFlush(new Object());
-        latch.await();
-        group.shutdownGracefully();
-        group.terminationFuture().sync();
+            Bootstrap b = new Bootstrap()
+                    .group(group2)
+                    .channel(LocalChannel.class)
+                    .handler(new SimpleChannelInboundHandler<Object>() {
+                        @Override
+                        protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
+                            // discard
+                        }
+                    });
+            cc = b.connect(addr).sync().channel();
+            cc.writeAndFlush(new Object());
+            assertTrue(latch.await(5, SECONDS));
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+        }
     }

     @Test
     public void localChannelRaceCondition() throws Exception {
-        final LocalAddress address = new LocalAddress("test");
+        final LocalAddress address = new LocalAddress(LOCAL_ADDR_ID);
         final CountDownLatch closeLatch = new CountDownLatch(1);
-        final EventLoopGroup serverGroup = new LocalEventLoopGroup(1);
         final EventLoopGroup clientGroup = new LocalEventLoopGroup(1) {
             @Override
             protected EventExecutor newChild(ThreadFactory threadFactory, Object... args)
@@ -217,9 +257,11 @@ public class LocalChannelTest {
                 };
             }
         };
+        Channel sc = null;
+        Channel cc = null;
         try {
             ServerBootstrap sb = new ServerBootstrap();
-            sb.group(serverGroup).
+            sc = sb.group(group2).
                     channel(LocalServerChannel.class).
                     childHandler(new ChannelInitializer<Channel>() {
                         @Override
@@ -229,7 +271,7 @@ public class LocalChannelTest {
                         }
                     }).
                     bind(address).
-                    sync();
+                    sync().channel();
             Bootstrap bootstrap = new Bootstrap();
             bootstrap.group(clientGroup).
                    channel(LocalChannel.class).
@@ -241,16 +283,16 @@ public class LocalChannelTest {
                 });
             ChannelFuture future = bootstrap.connect(address);
             assertTrue("Connection should finish, not time out", future.await(200));
+            cc = future.channel();
         } finally {
-            serverGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await();
-            clientGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await();
+            closeChannel(cc);
+            closeChannel(sc);
+            clientGroup.shutdownGracefully(0, 0, SECONDS).await();
         }
     }

     @Test
     public void testReRegister() {
-        EventLoopGroup group1 = new LocalEventLoopGroup();
-        EventLoopGroup group2 = new LocalEventLoopGroup();
         LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
         Bootstrap cb = new Bootstrap();
         ServerBootstrap sb = new ServerBootstrap();
@@ -268,18 +310,409 @@ public class LocalChannelTest {
             }
         });

-        // Start server
-        final Channel sc = sb.bind(addr).syncUninterruptibly().channel();
+        Channel sc = null;
+        Channel cc = null;
+        try {
+            // Start server
+            sc = sb.bind(addr).syncUninterruptibly().channel();

-        // Connect to the server
-        final Channel cc = cb.connect(addr).syncUninterruptibly().channel();
+            // Connect to the server
+            cc = cb.connect(addr).syncUninterruptibly().channel();

-        cc.deregister().syncUninterruptibly();
-        // Change event loop group.
-        group2.register(cc).syncUninterruptibly();
-        cc.close().syncUninterruptibly();
-        sc.close().syncUninterruptibly();
+            cc.deregister().syncUninterruptibly();
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+        }
     }
+
+    @Test
+    public void testCloseInWritePromiseCompletePreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+
+        try {
+            cb.group(group1)
+                    .channel(LocalChannel.class)
+                    .handler(new TestHandler());
+
+            sb.group(group2)
+                    .channel(LocalServerChannel.class)
+                    .childHandler(new ChannelInboundHandlerAdapter() {
+                        @Override
+                        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                            if (msg.equals(data)) {
+                                messageLatch.countDown();
+                                ReferenceCountUtil.safeRelease(msg);
+                            } else {
+                                super.channelRead(ctx, msg);
+                            }
+                        }
+                        @Override
+                        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+                            messageLatch.countDown();
+                            super.channelInactive(ctx);
+                        }
+                    });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the event loop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                ccCpy.pipeline().lastContext().close();
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+                assertFalse(cc.isOpen());
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+        }
+    }
+
+    @Test
+    public void testWriteInWritePromiseCompletePreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+
+        try {
+            cb.group(group1)
+                    .channel(LocalChannel.class)
+                    .handler(new TestHandler());
+
+            sb.group(group2)
+                    .channel(LocalServerChannel.class)
+                    .childHandler(new ChannelInboundHandlerAdapter() {
+                        @Override
+                        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                            final long count = messageLatch.getCount();
+                            if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
+                                messageLatch.countDown();
+                                ReferenceCountUtil.safeRelease(msg);
+                            } else {
+                                super.channelRead(ctx, msg);
+                            }
+                        }
+                    });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the event loop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                ccCpy.writeAndFlush(data2.duplicate().retain(), ccCpy.newPromise());
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+            data2.release();
+        }
+    }
+
+    @Test
+    public void testPeerWriteInWritePromiseCompleteDifferentEventLoopPreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        cb.group(group1)
+                .channel(LocalChannel.class)
+                .handler(new ChannelInboundHandlerAdapter() {
+                    @Override
+                    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                        if (data2.equals(msg)) {
+                            messageLatch.countDown();
+                            ReferenceCountUtil.safeRelease(msg);
+                        } else {
+                            super.channelRead(ctx, msg);
+                        }
+                    }
+                });
+
+        sb.group(group2)
+                .channel(LocalServerChannel.class)
+                .childHandler(new ChannelInitializer<LocalChannel>() {
+                    @Override
+                    public void initChannel(LocalChannel ch) throws Exception {
+                        ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                            @Override
+                            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                                if (data.equals(msg)) {
+                                    messageLatch.countDown();
+                                    ReferenceCountUtil.safeRelease(msg);
+                                } else {
+                                    super.channelRead(ctx, msg);
+                                }
+                            }
+                        });
+                        serverChannelRef.set(ch);
+                        serverChannelLatch.countDown();
+                    }
+                });
+
+        Channel sc = null;
+        Channel cc = null;
+        try {
+            // Start server
+            sc = sb.bind(addr).syncUninterruptibly().channel();
+
+            // Connect to the server
+            cc = cb.connect(addr).syncUninterruptibly().channel();
+            assertTrue(serverChannelLatch.await(5, SECONDS));
+
+            final Channel ccCpy = cc;
+            // Make sure a write operation is executed in the event loop
+            cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                @Override
+                public void run() {
+                    ChannelPromise promise = ccCpy.newPromise();
+                    promise.addListener(new ChannelFutureListener() {
+                        @Override
+                        public void operationComplete(ChannelFuture future) throws Exception {
+                            Channel serverChannelCpy = serverChannelRef.get();
+                            serverChannelCpy.writeAndFlush(data2.duplicate().retain(), serverChannelCpy.newPromise());
+                        }
+                    });
+                    ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                }
+            });
+
+            assertTrue(messageLatch.await(5, SECONDS));
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+            data.release();
+            data2.release();
+        }
+    }
+
+    @Test
+    public void testPeerWriteInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        try {
+            cb.group(sharedGroup)
+                    .channel(LocalChannel.class)
+                    .handler(new ChannelInboundHandlerAdapter() {
+                        @Override
+                        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                            if (data2.equals(msg) && messageLatch.getCount() == 1) {
+                                messageLatch.countDown();
+                                ReferenceCountUtil.safeRelease(msg);
+                            } else {
+                                super.channelRead(ctx, msg);
+                            }
+                        }
+                    });
+
+            sb.group(sharedGroup)
+                    .channel(LocalServerChannel.class)
+                    .childHandler(new ChannelInitializer<LocalChannel>() {
+                        @Override
+                        public void initChannel(LocalChannel ch) throws Exception {
+                            ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                                @Override
+                                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                                    if (data.equals(msg) && messageLatch.getCount() == 2) {
+                                        messageLatch.countDown();
+                                        ReferenceCountUtil.safeRelease(msg);
+                                    } else {
+                                        super.channelRead(ctx, msg);
+                                    }
+                                }
+                            });
+                            serverChannelRef.set(ch);
+                            serverChannelLatch.countDown();
+                        }
+                    });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+                assertTrue(serverChannelLatch.await(5, SECONDS));
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the event loop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                Channel serverChannelCpy = serverChannelRef.get();
+                                serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
+                                        serverChannelCpy.newPromise());
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+            data2.release();
+        }
+    }
+
+    @Test
+    public void testClosePeerInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        try {
+            cb.group(sharedGroup)
+                    .channel(LocalChannel.class)
+                    .handler(new TestHandler());
+
+            sb.group(sharedGroup)
+                    .channel(LocalServerChannel.class)
+                    .childHandler(new ChannelInitializer<LocalChannel>() {
+                        @Override
+                        public void initChannel(LocalChannel ch) throws Exception {
+                            ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                                @Override
+                                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                                    if (msg.equals(data)) {
+                                        messageLatch.countDown();
+                                        ReferenceCountUtil.safeRelease(msg);
+                                    } else {
+                                        super.channelRead(ctx, msg);
+                                    }
+                                }
+                                @Override
+                                public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+                                    messageLatch.countDown();
+                                    super.channelInactive(ctx);
+                                }
+                            });
+                            serverChannelRef.set(ch);
+                            serverChannelLatch.countDown();
+                        }
+                    });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+
+                assertTrue(serverChannelLatch.await(5, SECONDS));
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the event loop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                serverChannelRef.get().close();
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+                assertFalse(cc.isOpen());
+                assertFalse(serverChannelRef.get().isOpen());
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+        }
+    }
+
+    private static void closeChannel(Channel cc) {
+        if (cc != null) {
+            cc.close().syncUninterruptibly();
+        }
+    }
+
     static class TestHandler extends ChannelInboundHandlerAdapter {
         @Override
         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
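Reviewer sketch (illustrative, not part of the patch): every new *PreservesOrder test above follows the same skeleton, built around the null-safe closeChannel(Channel) helper this patch introduces. A condensed outline, with the handler wiring and latch counts elided and addr, cb, sb and messageLatch as in the tests:

    Channel sc = null;
    Channel cc = null;
    try {
        sc = sb.bind(addr).syncUninterruptibly().channel();     // start server
        cc = cb.connect(addr).syncUninterruptibly().channel();  // connect client
        // Trigger the write/close scenario from inside the event loop, then
        // verify delivery order via the latch instead of sleeping.
        assertTrue(messageLatch.await(5, SECONDS));
    } finally {
        closeChannel(cc);
        closeChannel(sc);
    }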