LocalChannel write when peer closed leak

Motivation:
If LocalChannel's doWrite executes while the peer's state changes from CONNECTED to CLOSED, it is possible that some promises won't be completed and buffers will be leaked.

Modifications:
- Check the peer's state in doWrite to avoid a race condition

Result:
All write operations should release their messages, and the associated promises should be completed.
Scott Mitchell 2015-08-28 16:26:19 -07:00
parent 41ee9148e5
commit 1e763b6504
2 changed files with 198 additions and 27 deletions
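To make the race concrete, below is a minimal, self-contained sketch of the pattern this fix introduces. The names (PeerClosedWriteSketch, Peer, write) are illustrative stand-ins, not Netty API: doWrite must re-check the peer's state for every message, because the peer can move from CONNECTED to CLOSED between messages, and a message enqueued into a closed peer's inbound buffer is never drained, so its promise is never completed and its buffer leaks.

import java.nio.channels.ClosedChannelException;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;

// Illustrative sketch only; Peer, State and write() are stand-ins, not Netty API.
public final class PeerClosedWriteSketch {
    enum State { CONNECTED, CLOSED }

    static final class Peer {
        volatile State state = State.CONNECTED;
        final Queue<Object> inboundBuffer = new ConcurrentLinkedQueue<Object>();
    }

    // Mirrors the fixed doWrite: re-check the peer's state per message. If the peer
    // closed concurrently, fail the promise (in.remove(CLOSED_CHANNEL_EXCEPTION) in
    // the real code) instead of enqueueing into a buffer nobody will ever drain.
    static CompletableFuture<Void> write(Peer peer, Object msg) {
        CompletableFuture<Void> promise = new CompletableFuture<Void>();
        if (peer.state == State.CONNECTED) {
            peer.inboundBuffer.add(msg); // handed off; the peer's read loop will release it
            promise.complete(null);      // corresponds to in.remove()
        } else {
            promise.completeExceptionally(new ClosedChannelException());
        }
        return promise;
    }

    public static void main(String[] args) {
        Peer peer = new Peer();
        peer.state = State.CLOSED; // simulate the peer closing before the write runs
        System.out.println("write failed: " + write(peer, "msg").isCompletedExceptionally());
    }
}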

LocalChannel.java

@@ -27,8 +27,9 @@ import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.ReferenceCountUtil;
-import io.netty.util.concurrent.SingleThreadEventExecutor;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.SingleThreadEventExecutor;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.InternalThreadLocalMap;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent;
@@ -52,9 +53,10 @@ public class LocalChannel extends AbstractChannel {
private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
private static final int MAX_READER_STACK_DEPTH = 8;
+private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
private final ChannelConfig config = new DefaultChannelConfig(this);
-// To futher optimize this we could write our own SPSC queue.
+// To further optimize this we could write our own SPSC queue.
private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
private final Runnable readTask = new Runnable() {
@Override
@@ -96,6 +98,7 @@ public class LocalChannel extends AbstractChannel {
AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
}
FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
+CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
}
public LocalChannel() {
@@ -218,10 +221,6 @@
protected void doClose() throws Exception {
final LocalChannel peer = this.peer;
if (state != State.CLOSED) {
-// To preserve ordering of events we must process any pending reads
-if (writeInProgress && peer != null) {
-finishPeerRead(peer);
-}
// Update all internal state before the closeFuture is notified.
if (localAddress != null) {
if (parent() == null) {
@@ -229,7 +228,22 @@
}
localAddress = null;
}
+// State change must happen before finishPeerRead to ensure writes are released either in doWrite or
+// channelRead.
+state = State.CLOSED;
+ChannelPromise promise = connectPromise;
+if (promise != null) {
+// Use tryFailure() instead of setFailure() to avoid the race against cancel().
+promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
+connectPromise = null;
+}
+// To preserve ordering of events we must process any pending reads
+if (writeInProgress && peer != null) {
+finishPeerRead(peer);
+}
}
if (peer != null && peer.isActive()) {
@@ -241,12 +255,18 @@
} else {
// This value may change, and so we should save it before executing the Runnable.
final boolean peerWriteInProgress = peer.writeInProgress;
-peer.eventLoop().execute(new OneTimeTask() {
-@Override
-public void run() {
-doPeerClose(peer, peerWriteInProgress);
-}
-});
+try {
+peer.eventLoop().execute(new OneTimeTask() {
+@Override
+public void run() {
+doPeerClose(peer, peerWriteInProgress);
+}
+});
+} catch (RuntimeException e) {
+// The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained.
+releaseInboundBuffers();
+throw e;
+}
}
this.peer = null;
}
@@ -295,7 +315,12 @@
threadLocals.setLocalChannelReaderStackDepth(stackDepth);
}
} else {
-eventLoop().execute(readTask);
+try {
+eventLoop().execute(readTask);
+} catch (RuntimeException e) {
+releaseInboundBuffers();
+throw e;
+}
}
}
@@ -306,7 +331,7 @@
case BOUND:
throw new NotYetConnectedException();
case CLOSED:
-throw new ClosedChannelException();
+throw CLOSED_CHANNEL_EXCEPTION;
case CONNECTED:
break;
}
@@ -321,8 +346,14 @@
break;
}
try {
-peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
-in.remove();
+// It is possible the peer could have closed while we are writing, and in this case we should
+// simulate real socket behavior and ensure the write operation is failed.
+if (peer.state == State.CONNECTED) {
+peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
+in.remove();
+} else {
+in.remove(CLOSED_CHANNEL_EXCEPTION);
+}
} catch (Throwable cause) {
in.remove(cause);
}
@@ -357,10 +388,25 @@
finishPeerRead0(peer);
}
};
-if (peer.writeInProgress) {
-peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
-} else {
-peer.eventLoop().execute(finishPeerReadTask);
+try {
+if (peer.writeInProgress) {
+peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
+} else {
+peer.eventLoop().execute(finishPeerReadTask);
+}
+} catch (RuntimeException e) {
+peer.releaseInboundBuffers();
+throw e;
+}
}
+private void releaseInboundBuffers() {
+for (;;) {
+Object o = inboundBuffer.poll();
+if (o == null) {
+break;
+}
+ReferenceCountUtil.release(o);
+}
+}
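The try/catch additions above all follow a single defensive pattern: if scheduling on the event loop fails (typically with a RejectedExecutionException because the loop is shutting down), the task that would have drained the inbound buffer never runs, so the buffer must be drained on the spot before rethrowing. A sketch of that pattern in isolation, with illustrative names (RefCounted and schedule() stand in for Netty's reference-counted messages and LocalChannel's scheduling sites):

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;

// Illustrative sketch of the "drain inbound buffers on rejected execution" pattern.
public final class DrainOnRejectSketch {
    public interface RefCounted { void release(); }

    private final Queue<RefCounted> inboundBuffer = new ConcurrentLinkedQueue<RefCounted>();

    public void schedule(Executor eventLoop, Runnable drainTask) {
        try {
            eventLoop.execute(drainTask);
        } catch (RuntimeException e) {
            // drainTask will never run (e.g. the loop rejected it while shutting
            // down), so release everything here instead of leaking it.
            releaseInboundBuffers();
            throw e;
        }
    }

    // Mirrors LocalChannel.releaseInboundBuffers() in the diff above.
    private void releaseInboundBuffers() {
        for (RefCounted o; (o = inboundBuffer.poll()) != null;) {
            o.release(); // drop the reference so the buffer can be reclaimed
        }
    }
}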

LocalChannelTest.java

@@ -340,8 +340,8 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg.equals(data)) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -409,8 +409,8 @@ public class LocalChannelTest {
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
final long count = messageLatch.getCount();
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -469,8 +469,8 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data2.equals(msg)) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -486,8 +486,8 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg)) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -551,8 +551,8 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data2.equals(msg) && messageLatch.getCount() == 1) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -568,8 +568,8 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg) && messageLatch.getCount() == 2) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -642,8 +642,8 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg.equals(data)) {
-messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
+messageLatch.countDown();
} else {
super.channelRead(ctx, msg);
}
@@ -698,6 +698,130 @@ public class LocalChannelTest {
}
}
+@Test
+public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException {
+Bootstrap cb = new Bootstrap();
+ServerBootstrap sb = new ServerBootstrap();
+final CountDownLatch serverMessageLatch = new CountDownLatch(1);
+final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1);
+final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1);
+final CountDownLatch writeFailLatch = new CountDownLatch(1);
+final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+try {
+cb.group(group1)
+.channel(LocalChannel.class)
+.handler(new TestHandler());
+sb.group(group2)
+.channel(LocalServerChannel.class)
+.childHandler(new ChannelInitializer<LocalChannel>() {
+@Override
+public void initChannel(LocalChannel ch) throws Exception {
+ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+@Override
+public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+if (data.equals(msg)) {
+ReferenceCountUtil.safeRelease(msg);
+serverMessageLatch.countDown();
+} else {
+super.channelRead(ctx, msg);
+}
+}
+});
+serverChannelRef.set(ch);
+serverChannelLatch.countDown();
+}
+});
+Channel sc = null;
+Channel cc = null;
+try {
+// Start server
+sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();
+// Connect to the server
+cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();
+assertTrue(serverChannelLatch.await(5, SECONDS));
+final Channel ccCpy = cc;
+final Channel serverChannelCpy = serverChannelRef.get();
+serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch);
+ccCpy.closeFuture().addListener(clientChannelCloseLatch);
+// Make sure a write operation is executed in the eventloop
+cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+@Override
+public void run() {
+ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise())
+.addListener(new ChannelFutureListener() {
+@Override
+public void operationComplete(ChannelFuture future) throws Exception {
+serverChannelCpy.eventLoop().execute(new OneTimeTask() {
+@Override
+public void run() {
+// The point of this test is to write while the peer is closed, so we should
+// ensure the peer is actually closed before we write.
+int waitCount = 0;
+while (ccCpy.isOpen()) {
+try {
+Thread.sleep(50);
+} catch (InterruptedException ignored) {
+// ignored
+}
+if (++waitCount > 5) {
+fail();
+}
+}
+serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
+serverChannelCpy.newPromise())
+.addListener(new ChannelFutureListener() {
+@Override
+public void operationComplete(ChannelFuture future) throws Exception {
+if (!future.isSuccess() &&
+future.cause() instanceof ClosedChannelException) {
+writeFailLatch.countDown();
+}
+}
+});
+}
+});
+ccCpy.close();
+}
+});
+}
+});
+assertTrue(serverMessageLatch.await(5, SECONDS));
+assertTrue(writeFailLatch.await(5, SECONDS));
+assertTrue(serverChannelCloseLatch.await(5, SECONDS));
+assertTrue(clientChannelCloseLatch.await(5, SECONDS));
+assertFalse(ccCpy.isOpen());
+assertFalse(serverChannelCpy.isOpen());
+} finally {
+closeChannel(cc);
+closeChannel(sc);
+}
+} finally {
+data.release();
+data2.release();
+}
+}
+private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener {
+public LatchChannelFutureListener(int count) {
+super(count);
+}
+@Override
+public void operationComplete(ChannelFuture future) throws Exception {
+countDown();
+}
+}
private static void closeChannel(Channel cc) {
if (cc != null) {
cc.close().syncUninterruptibly();
@@ -708,6 +832,7 @@ public class LocalChannelTest {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
logger.info(String.format("Received mesage: %s", msg));
+ReferenceCountUtil.safeRelease(msg);
}
}
}
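A note on the reordered test handlers above: each one now calls ReferenceCountUtil.safeRelease(msg) before messageLatch.countDown(), not after. The ordering matters because the thread blocked on await() may run its post-conditions the moment the latch opens, and the release must already have happened by then. A minimal sketch of that happens-before guarantee (illustrative, not Netty code):

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

// countDown() happens-before await() returning, so the awaiting thread is
// guaranteed to observe writes made before countDown(); here, the release flag.
public final class LatchOrderingSketch {
    public static void main(String[] args) throws InterruptedException {
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicBoolean released = new AtomicBoolean();
        new Thread(new Runnable() {
            @Override
            public void run() {
                released.set(true); // stands in for ReferenceCountUtil.safeRelease(msg)
                latch.countDown();  // only now may the awaiting thread proceed
            }
        }).start();
        latch.await();
        System.out.println("released before latch opened: " + released.get()); // always true
    }
}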