LocalChannel write when peer closed leak

Motivation:
If LocalChannel doWrite executes while the peer's state changes from CONNECTED to CLOSED it is possible that some promise's won't be completed and buffers will be leaked.

Modifications:
- Check the peer's state in doWrite to avoid a race condition

Result:
All write operations should release, and the associated promise should be completed.
This commit is contained in:
Scott Mitchell 2015-08-28 16:26:19 -07:00
parent fd70dd658e
commit 71308376ca
2 changed files with 198 additions and 27 deletions

View File

@ -27,8 +27,9 @@ import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop; import io.netty.channel.EventLoop;
import io.netty.channel.SingleThreadEventLoop; import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.SingleThreadEventExecutor;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.SingleThreadEventExecutor;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.InternalThreadLocalMap; import io.netty.util.internal.InternalThreadLocalMap;
import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
@ -50,9 +51,10 @@ public class LocalChannel extends AbstractChannel {
private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER; private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
private static final ChannelMetadata METADATA = new ChannelMetadata(false); private static final ChannelMetadata METADATA = new ChannelMetadata(false);
private static final int MAX_READER_STACK_DEPTH = 8; private static final int MAX_READER_STACK_DEPTH = 8;
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
private final ChannelConfig config = new DefaultChannelConfig(this); private final ChannelConfig config = new DefaultChannelConfig(this);
// To futher optimize this we could write our own SPSC queue. // To further optimize this we could write our own SPSC queue.
private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue(); private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
private final Runnable readTask = new Runnable() { private final Runnable readTask = new Runnable() {
@Override @Override
@ -94,6 +96,7 @@ public class LocalChannel extends AbstractChannel {
AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture"); AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
} }
FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater; FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
} }
public LocalChannel() { public LocalChannel() {
@ -216,10 +219,6 @@ public class LocalChannel extends AbstractChannel {
protected void doClose() throws Exception { protected void doClose() throws Exception {
final LocalChannel peer = this.peer; final LocalChannel peer = this.peer;
if (state <= 2) { if (state <= 2) {
// To preserve ordering of events we must process any pending reads
if (writeInProgress && peer != null) {
finishPeerRead(peer);
}
// Update all internal state before the closeFuture is notified. // Update all internal state before the closeFuture is notified.
if (localAddress != null) { if (localAddress != null) {
if (parent() == null) { if (parent() == null) {
@ -227,7 +226,22 @@ public class LocalChannel extends AbstractChannel {
} }
localAddress = null; localAddress = null;
} }
// State change must happen before finishPeerRead to ensure writes are released either in doWrite or
// channelRead.
state = 3; state = 3;
ChannelPromise promise = connectPromise;
if (promise != null) {
// Use tryFailure() instead of setFailure() to avoid the race against cancel().
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
connectPromise = null;
}
// To preserve ordering of events we must process any pending reads
if (writeInProgress && peer != null) {
finishPeerRead(peer);
}
} }
if (peer != null && peer.isActive()) { if (peer != null && peer.isActive()) {
@ -239,12 +253,18 @@ public class LocalChannel extends AbstractChannel {
} else { } else {
// This value may change, and so we should save it before executing the Runnable. // This value may change, and so we should save it before executing the Runnable.
final boolean peerWriteInProgress = peer.writeInProgress; final boolean peerWriteInProgress = peer.writeInProgress;
peer.eventLoop().execute(new OneTimeTask() { try {
@Override peer.eventLoop().execute(new OneTimeTask() {
public void run() { @Override
doPeerClose(peer, peerWriteInProgress); public void run() {
} doPeerClose(peer, peerWriteInProgress);
}); }
});
} catch (RuntimeException e) {
// The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained.
releaseInboundBuffers();
throw e;
}
} }
this.peer = null; this.peer = null;
} }
@ -293,7 +313,12 @@ public class LocalChannel extends AbstractChannel {
threadLocals.setLocalChannelReaderStackDepth(stackDepth); threadLocals.setLocalChannelReaderStackDepth(stackDepth);
} }
} else { } else {
eventLoop().execute(readTask); try {
eventLoop().execute(readTask);
} catch (RuntimeException e) {
releaseInboundBuffers();
throw e;
}
} }
} }
@ -303,7 +328,7 @@ public class LocalChannel extends AbstractChannel {
throw new NotYetConnectedException(); throw new NotYetConnectedException();
} }
if (state > 2) { if (state > 2) {
throw new ClosedChannelException(); throw CLOSED_CHANNEL_EXCEPTION;
} }
final LocalChannel peer = this.peer; final LocalChannel peer = this.peer;
@ -316,8 +341,14 @@ public class LocalChannel extends AbstractChannel {
break; break;
} }
try { try {
peer.inboundBuffer.add(ReferenceCountUtil.retain(msg)); // It is possible the peer could have closed while we are writing, and in this case we should
in.remove(); // simulate real socket behavior and ensure the write operation is failed.
if (peer.state == 2) {
peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
in.remove();
} else {
in.remove(CLOSED_CHANNEL_EXCEPTION);
}
} catch (Throwable cause) { } catch (Throwable cause) {
in.remove(cause); in.remove(cause);
} }
@ -352,10 +383,25 @@ public class LocalChannel extends AbstractChannel {
finishPeerRead0(peer); finishPeerRead0(peer);
} }
}; };
if (peer.writeInProgress) { try {
peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask); if (peer.writeInProgress) {
} else { peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
peer.eventLoop().execute(finishPeerReadTask); } else {
peer.eventLoop().execute(finishPeerReadTask);
}
} catch (RuntimeException e) {
peer.releaseInboundBuffers();
throw e;
}
}
private void releaseInboundBuffers() {
for (;;) {
Object o = inboundBuffer.poll();
if (o == null) {
break;
}
ReferenceCountUtil.release(o);
} }
} }

View File

@ -339,8 +339,8 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg.equals(data)) { if (msg.equals(data)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -408,8 +408,8 @@ public class LocalChannelTest {
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
final long count = messageLatch.getCount(); final long count = messageLatch.getCount();
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) { if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -468,8 +468,8 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data2.equals(msg)) { if (data2.equals(msg)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -485,8 +485,8 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg)) { if (data.equals(msg)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -550,8 +550,8 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data2.equals(msg) && messageLatch.getCount() == 1) { if (data2.equals(msg) && messageLatch.getCount() == 1) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -567,8 +567,8 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg) && messageLatch.getCount() == 2) { if (data.equals(msg) && messageLatch.getCount() == 2) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -641,8 +641,8 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg.equals(data)) { if (msg.equals(data)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
messageLatch.countDown();
} else { } else {
super.channelRead(ctx, msg); super.channelRead(ctx, msg);
} }
@ -697,6 +697,130 @@ public class LocalChannelTest {
} }
} }
@Test
public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException {
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
final CountDownLatch serverMessageLatch = new CountDownLatch(1);
final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1);
final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1);
final CountDownLatch writeFailLatch = new CountDownLatch(1);
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
try {
cb.group(group1)
.channel(LocalChannel.class)
.handler(new TestHandler());
sb.group(group2)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg)) {
ReferenceCountUtil.safeRelease(msg);
serverMessageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
}
});
serverChannelRef.set(ch);
serverChannelLatch.countDown();
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();
// Connect to the server
cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();
assertTrue(serverChannelLatch.await(5, SECONDS));
final Channel ccCpy = cc;
final Channel serverChannelCpy = serverChannelRef.get();
serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch);
ccCpy.closeFuture().addListener(clientChannelCloseLatch);
// Make sure a write operation is executed in the eventloop
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
@Override
public void run() {
ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
serverChannelCpy.eventLoop().execute(new OneTimeTask() {
@Override
public void run() {
// The point of this test is to write while the peer is closed, so we should
// ensure the peer is actually closed before we write.
int waitCount = 0;
while (ccCpy.isOpen()) {
try {
Thread.sleep(50);
} catch (InterruptedException ignored) {
// ignored
}
if (++waitCount > 5) {
fail();
}
}
serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
serverChannelCpy.newPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess() &&
future.cause() instanceof ClosedChannelException) {
writeFailLatch.countDown();
}
}
});
}
});
ccCpy.close();
}
});
}
});
assertTrue(serverMessageLatch.await(5, SECONDS));
assertTrue(writeFailLatch.await(5, SECONDS));
assertTrue(serverChannelCloseLatch.await(5, SECONDS));
assertTrue(clientChannelCloseLatch.await(5, SECONDS));
assertFalse(ccCpy.isOpen());
assertFalse(serverChannelCpy.isOpen());
} finally {
closeChannel(cc);
closeChannel(sc);
}
} finally {
data.release();
data2.release();
}
}
private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener {
public LatchChannelFutureListener(int count) {
super(count);
}
@Override
public void operationComplete(ChannelFuture future) throws Exception {
countDown();
}
}
private static void closeChannel(Channel cc) { private static void closeChannel(Channel cc) {
if (cc != null) { if (cc != null) {
cc.close().syncUninterruptibly(); cc.close().syncUninterruptibly();
@ -707,6 +831,7 @@ public class LocalChannelTest {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
logger.info(String.format("Received mesage: %s", msg)); logger.info(String.format("Received mesage: %s", msg));
ReferenceCountUtil.safeRelease(msg);
} }
} }
} }