diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 25ffb6ccad..24eb2d7851 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ -83,6 +83,7 @@ public class LocalChannel extends AbstractChannel { private volatile LocalAddress remoteAddress; private volatile ChannelPromise connectPromise; private volatile boolean readInProgress; + private volatile boolean registerInProgress; public LocalChannel() { super(null); @@ -153,6 +154,14 @@ public class LocalChannel extends AbstractChannel { @Override protected void doRegister() throws Exception { if (peer != null) { + // Store the peer in a local variable as it may be set to null if doClose() is called. + // Because of this we also set registerInProgress to true as we check for this in doClose() and make sure + // we delay the fireChannelInactive() to be fired after the fireChannelActive() and so keep the correct + // order of events. + // + // See https://github.com/netty/netty/issues/2144 + final LocalChannel peer = this.peer; + registerInProgress = true; state = 2; peer.remoteAddress = parent().localAddress(); @@ -165,6 +174,7 @@ public class LocalChannel extends AbstractChannel { peer.eventLoop().execute(new Runnable() { @Override public void run() { + registerInProgress = false; peer.pipeline().fireChannelActive(); peer.connectPromise.setSuccess(); } @@ -204,7 +214,12 @@ public class LocalChannel extends AbstractChannel { // Need to execute the close in the correct EventLoop // See https://github.com/netty/netty/issues/1777 EventLoop eventLoop = peer.eventLoop(); - if (eventLoop.inEventLoop()) { + + // Also check if the registration was not done yet. In this case we submit the close to the EventLoop + // to make sure it is run after the registration completes. + // + // See https://github.com/netty/netty/issues/2144 + if (eventLoop.inEventLoop() && !registerInProgress) { peer.unsafe().close(unsafe().voidPromise()); } else { peer.eventLoop().execute(new Runnable() { diff --git a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java index f109040a2e..adcec5f99d 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java @@ -19,17 +19,23 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.AbstractChannel; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.SingleThreadEventLoop; +import io.netty.util.concurrent.EventExecutor; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import org.junit.Assert; import org.junit.Test; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; @@ -177,6 +183,71 @@ public class LocalChannelTest { group.terminationFuture().sync(); } + @Test + public void localChannelRaceCondition() throws Exception { + final LocalAddress address = new LocalAddress("test"); + final CountDownLatch closeLatch = new CountDownLatch(1); + final EventLoopGroup serverGroup = new LocalEventLoopGroup(1); + final EventLoopGroup clientGroup = new LocalEventLoopGroup(1) { + @Override + protected EventExecutor newChild(ThreadFactory threadFactory, Object... args) + throws Exception { + return new SingleThreadEventLoop(this, threadFactory, true) { + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + /* Only slow down the anonymous class in LocalChannel#doRegister() */ + if (task.getClass().getEnclosingClass() == LocalChannel.class) { + try { + closeLatch.await(); + } catch (InterruptedException e) { + throw new Error(e); + } + } + task.run(); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } + }; + } + }; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(serverGroup). + channel(LocalServerChannel.class). + childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.close(); + closeLatch.countDown(); + } + }). + bind(address). + sync(); + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(clientGroup). + channel(LocalChannel.class). + handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + /* Do nothing */ + } + }); + ChannelFuture future = bootstrap.connect(address); + assertTrue("Connection should finish, not time out", future.await(200)); + } finally { + serverGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await(); + clientGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await(); + } + } + static class TestHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {