From dfd6b9009c5e9929308ff2e9c0181dff55c60e26 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 17 Apr 2014 14:24:36 +0200 Subject: [PATCH] [#2144] Fix NPE in Local transport caused by a race Motivation: At the moment it is possible to see a NPE when the LocalSocketChannels doRegister() method is called and the LocalSocketChannels doClose() method is called before the registration was completed. Modifications: Make sure we delay the actual close until the registration task was executed. Result: No more NPE --- .../io/netty/channel/local/LocalChannel.java | 17 ++++- .../netty/channel/local/LocalChannelTest.java | 71 +++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 302ab6f036..3b56d95ab8 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ -85,6 +85,7 @@ public class LocalChannel extends AbstractChannel { private volatile LocalAddress remoteAddress; private volatile ChannelPromise connectPromise; private volatile boolean readInProgress; + private volatile boolean registerInProgress; public LocalChannel(EventLoop eventLoop) { super(null, eventLoop); @@ -155,6 +156,14 @@ public class LocalChannel extends AbstractChannel { @Override protected void doRegister() throws Exception { if (peer != null) { + // Store the peer in a local variable as it may be set to null if doClose() is called. + // Because of this we also set registerInProgress to true as we check for this in doClose() and make sure + // we delay the fireChannelInactive() to be fired after the fireChannelActive() and so keep the correct + // order of events. + // + // See https://github.com/netty/netty/issues/2144 + final LocalChannel peer = this.peer; + registerInProgress = true; state = State.CONNECTED; peer.remoteAddress = parent() == null ? null : parent().localAddress(); @@ -167,6 +176,7 @@ public class LocalChannel extends AbstractChannel { peer.eventLoop().execute(new Runnable() { @Override public void run() { + registerInProgress = false; peer.pipeline().fireChannelActive(); peer.connectPromise.setSuccess(); } @@ -206,7 +216,12 @@ public class LocalChannel extends AbstractChannel { // Need to execute the close in the correct EventLoop // See https://github.com/netty/netty/issues/1777 EventLoop eventLoop = peer.eventLoop(); - if (eventLoop.inEventLoop()) { + + // Also check if the registration was not done yet. In this case we submit the close to the EventLoop + // to make sure it is run after the registration completes. + // + // See https://github.com/netty/netty/issues/2144 + if (eventLoop.inEventLoop() && !registerInProgress) { peer.unsafe().close(unsafe().voidPromise()); } else { peer.eventLoop().execute(new Runnable() { diff --git a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java index 763b04e3da..b481d43e06 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalChannelTest.java @@ -20,17 +20,22 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.AbstractChannel; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.SingleThreadEventLoop; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.junit.Test; import java.nio.channels.ClosedChannelException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; @@ -178,7 +183,73 @@ public class LocalChannelTest { group.terminationFuture().sync(); } + @Test + public void localChannelRaceCondition() throws Exception { + final LocalAddress address = new LocalAddress("test"); + final CountDownLatch closeLatch = new CountDownLatch(1); + final EventLoopGroup serverGroup = new DefaultEventLoopGroup(1); + final EventLoopGroup clientGroup = new DefaultEventLoopGroup(1) { + @Override + protected EventLoop newChild(Executor threadFactory, Object... args) + throws Exception { + return new SingleThreadEventLoop(this, threadFactory, true) { + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + /* Only slow down the anonymous class in LocalChannel#doRegister() */ + if (task.getClass().getEnclosingClass() == LocalChannel.class) { + try { + closeLatch.await(); + } catch (InterruptedException e) { + throw new Error(e); + } + } + task.run(); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } + }; + } + }; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(serverGroup). + channel(LocalServerChannel.class). + childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.close(); + closeLatch.countDown(); + } + }). + bind(address). + sync(); + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(clientGroup). + channel(LocalChannel.class). + handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + /* Do nothing */ + } + }); + ChannelFuture future = bootstrap.connect(address); + assertTrue("Connection should finish, not time out", future.await(200)); + } finally { + serverGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await(); + clientGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await(); + } + } + static class TestHandler extends ChannelHandlerAdapter { + @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { logger.info(String.format("Received mesage: %s", msg));