[#2144] Fix NPE in Local transport caused by a race

Motivation:
At the moment it is possible to see a NPE when the LocalSocketChannels doRegister() method is called and the LocalSocketChannels doClose() method is called before the registration was completed.

Modifications:
Make sure we delay the actual close until the registration task was executed.

Result:
No more NPE
This commit is contained in:
Norman Maurer 2014-04-17 14:24:36 +02:00
parent 5babc1a498
commit dfd6b9009c
2 changed files with 87 additions and 1 deletions

View File

@ -85,6 +85,7 @@ public class LocalChannel extends AbstractChannel {
private volatile LocalAddress remoteAddress;
private volatile ChannelPromise connectPromise;
private volatile boolean readInProgress;
private volatile boolean registerInProgress;
public LocalChannel(EventLoop eventLoop) {
super(null, eventLoop);
@ -155,6 +156,14 @@ public class LocalChannel extends AbstractChannel {
@Override
protected void doRegister() throws Exception {
if (peer != null) {
// Store the peer in a local variable as it may be set to null if doClose() is called.
// Because of this we also set registerInProgress to true as we check for this in doClose() and make sure
// we delay the fireChannelInactive() to be fired after the fireChannelActive() and so keep the correct
// order of events.
//
// See https://github.com/netty/netty/issues/2144
final LocalChannel peer = this.peer;
registerInProgress = true;
state = State.CONNECTED;
peer.remoteAddress = parent() == null ? null : parent().localAddress();
@ -167,6 +176,7 @@ public class LocalChannel extends AbstractChannel {
peer.eventLoop().execute(new Runnable() {
@Override
public void run() {
registerInProgress = false;
peer.pipeline().fireChannelActive();
peer.connectPromise.setSuccess();
}
@ -206,7 +216,12 @@ public class LocalChannel extends AbstractChannel {
// Need to execute the close in the correct EventLoop
// See https://github.com/netty/netty/issues/1777
EventLoop eventLoop = peer.eventLoop();
if (eventLoop.inEventLoop()) {
// Also check if the registration was not done yet. In this case we submit the close to the EventLoop
// to make sure it is run after the registration completes.
//
// See https://github.com/netty/netty/issues/2144
if (eventLoop.inEventLoop() && !registerInProgress) {
peer.unsafe().close(unsafe().voidPromise());
} else {
peer.eventLoop().execute(new Runnable() {

View File

@ -20,17 +20,22 @@ import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.AbstractChannel;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
@ -178,7 +183,73 @@ public class LocalChannelTest {
group.terminationFuture().sync();
}
@Test
public void localChannelRaceCondition() throws Exception {
final LocalAddress address = new LocalAddress("test");
final CountDownLatch closeLatch = new CountDownLatch(1);
final EventLoopGroup serverGroup = new DefaultEventLoopGroup(1);
final EventLoopGroup clientGroup = new DefaultEventLoopGroup(1) {
@Override
protected EventLoop newChild(Executor threadFactory, Object... args)
throws Exception {
return new SingleThreadEventLoop(this, threadFactory, true) {
@Override
protected void run() {
for (;;) {
Runnable task = takeTask();
if (task != null) {
/* Only slow down the anonymous class in LocalChannel#doRegister() */
if (task.getClass().getEnclosingClass() == LocalChannel.class) {
try {
closeLatch.await();
} catch (InterruptedException e) {
throw new Error(e);
}
}
task.run();
updateLastExecutionTime();
}
if (confirmShutdown()) {
break;
}
}
}
};
}
};
try {
ServerBootstrap sb = new ServerBootstrap();
sb.group(serverGroup).
channel(LocalServerChannel.class).
childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.close();
closeLatch.countDown();
}
}).
bind(address).
sync();
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(clientGroup).
channel(LocalChannel.class).
handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
/* Do nothing */
}
});
ChannelFuture future = bootstrap.connect(address);
assertTrue("Connection should finish, not time out", future.await(200));
} finally {
serverGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await();
clientGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await();
}
}
static class TestHandler extends ChannelHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
logger.info(String.format("Received mesage: %s", msg));