LocalChannel write when peer closed leak
Motivation: If LocalChannel doWrite executes while the peer's state changes from CONNECTED to CLOSED it is possible that some promise's won't be completed and buffers will be leaked. Modifications: - Check the peer's state in doWrite to avoid a race condition Result: All write operations should release, and the associated promise should be completed.
This commit is contained in:
parent
41ee9148e5
commit
1e763b6504
@ -27,8 +27,9 @@ import io.netty.channel.DefaultChannelConfig;
|
||||
import io.netty.channel.EventLoop;
|
||||
import io.netty.channel.SingleThreadEventLoop;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.concurrent.SingleThreadEventExecutor;
|
||||
import io.netty.util.concurrent.Future;
|
||||
import io.netty.util.concurrent.SingleThreadEventExecutor;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import io.netty.util.internal.InternalThreadLocalMap;
|
||||
import io.netty.util.internal.OneTimeTask;
|
||||
import io.netty.util.internal.PlatformDependent;
|
||||
@ -52,9 +53,10 @@ public class LocalChannel extends AbstractChannel {
|
||||
private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
|
||||
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
|
||||
private static final int MAX_READER_STACK_DEPTH = 8;
|
||||
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
|
||||
|
||||
private final ChannelConfig config = new DefaultChannelConfig(this);
|
||||
// To futher optimize this we could write our own SPSC queue.
|
||||
// To further optimize this we could write our own SPSC queue.
|
||||
private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
|
||||
private final Runnable readTask = new Runnable() {
|
||||
@Override
|
||||
@ -96,6 +98,7 @@ public class LocalChannel extends AbstractChannel {
|
||||
AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
|
||||
}
|
||||
FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
|
||||
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
|
||||
}
|
||||
|
||||
public LocalChannel() {
|
||||
@ -218,10 +221,6 @@ public class LocalChannel extends AbstractChannel {
|
||||
protected void doClose() throws Exception {
|
||||
final LocalChannel peer = this.peer;
|
||||
if (state != State.CLOSED) {
|
||||
// To preserve ordering of events we must process any pending reads
|
||||
if (writeInProgress && peer != null) {
|
||||
finishPeerRead(peer);
|
||||
}
|
||||
// Update all internal state before the closeFuture is notified.
|
||||
if (localAddress != null) {
|
||||
if (parent() == null) {
|
||||
@ -229,7 +228,22 @@ public class LocalChannel extends AbstractChannel {
|
||||
}
|
||||
localAddress = null;
|
||||
}
|
||||
|
||||
// State change must happen before finishPeerRead to ensure writes are released either in doWrite or
|
||||
// channelRead.
|
||||
state = State.CLOSED;
|
||||
|
||||
ChannelPromise promise = connectPromise;
|
||||
if (promise != null) {
|
||||
// Use tryFailure() instead of setFailure() to avoid the race against cancel().
|
||||
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
|
||||
connectPromise = null;
|
||||
}
|
||||
|
||||
// To preserve ordering of events we must process any pending reads
|
||||
if (writeInProgress && peer != null) {
|
||||
finishPeerRead(peer);
|
||||
}
|
||||
}
|
||||
|
||||
if (peer != null && peer.isActive()) {
|
||||
@ -241,12 +255,18 @@ public class LocalChannel extends AbstractChannel {
|
||||
} else {
|
||||
// This value may change, and so we should save it before executing the Runnable.
|
||||
final boolean peerWriteInProgress = peer.writeInProgress;
|
||||
try {
|
||||
peer.eventLoop().execute(new OneTimeTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
doPeerClose(peer, peerWriteInProgress);
|
||||
}
|
||||
});
|
||||
} catch (RuntimeException e) {
|
||||
// The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained.
|
||||
releaseInboundBuffers();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
this.peer = null;
|
||||
}
|
||||
@ -295,7 +315,12 @@ public class LocalChannel extends AbstractChannel {
|
||||
threadLocals.setLocalChannelReaderStackDepth(stackDepth);
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
eventLoop().execute(readTask);
|
||||
} catch (RuntimeException e) {
|
||||
releaseInboundBuffers();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -306,7 +331,7 @@ public class LocalChannel extends AbstractChannel {
|
||||
case BOUND:
|
||||
throw new NotYetConnectedException();
|
||||
case CLOSED:
|
||||
throw new ClosedChannelException();
|
||||
throw CLOSED_CHANNEL_EXCEPTION;
|
||||
case CONNECTED:
|
||||
break;
|
||||
}
|
||||
@ -321,8 +346,14 @@ public class LocalChannel extends AbstractChannel {
|
||||
break;
|
||||
}
|
||||
try {
|
||||
// It is possible the peer could have closed while we are writing, and in this case we should
|
||||
// simulate real socket behavior and ensure the write operation is failed.
|
||||
if (peer.state == State.CONNECTED) {
|
||||
peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
|
||||
in.remove();
|
||||
} else {
|
||||
in.remove(CLOSED_CHANNEL_EXCEPTION);
|
||||
}
|
||||
} catch (Throwable cause) {
|
||||
in.remove(cause);
|
||||
}
|
||||
@ -357,11 +388,26 @@ public class LocalChannel extends AbstractChannel {
|
||||
finishPeerRead0(peer);
|
||||
}
|
||||
};
|
||||
try {
|
||||
if (peer.writeInProgress) {
|
||||
peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
|
||||
} else {
|
||||
peer.eventLoop().execute(finishPeerReadTask);
|
||||
}
|
||||
} catch (RuntimeException e) {
|
||||
peer.releaseInboundBuffers();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private void releaseInboundBuffers() {
|
||||
for (;;) {
|
||||
Object o = inboundBuffer.poll();
|
||||
if (o == null) {
|
||||
break;
|
||||
}
|
||||
ReferenceCountUtil.release(o);
|
||||
}
|
||||
}
|
||||
|
||||
private void finishPeerRead0(LocalChannel peer) {
|
||||
|
@ -340,8 +340,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg.equals(data)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -409,8 +409,8 @@ public class LocalChannelTest {
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
final long count = messageLatch.getCount();
|
||||
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -469,8 +469,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data2.equals(msg)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -486,8 +486,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data.equals(msg)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -551,8 +551,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data2.equals(msg) && messageLatch.getCount() == 1) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -568,8 +568,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data.equals(msg) && messageLatch.getCount() == 2) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -642,8 +642,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg.equals(data)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -698,6 +698,130 @@ public class LocalChannelTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException {
|
||||
Bootstrap cb = new Bootstrap();
|
||||
ServerBootstrap sb = new ServerBootstrap();
|
||||
final CountDownLatch serverMessageLatch = new CountDownLatch(1);
|
||||
final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1);
|
||||
final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1);
|
||||
final CountDownLatch writeFailLatch = new CountDownLatch(1);
|
||||
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
|
||||
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
|
||||
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
|
||||
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
|
||||
|
||||
try {
|
||||
cb.group(group1)
|
||||
.channel(LocalChannel.class)
|
||||
.handler(new TestHandler());
|
||||
|
||||
sb.group(group2)
|
||||
.channel(LocalServerChannel.class)
|
||||
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||
@Override
|
||||
public void initChannel(LocalChannel ch) throws Exception {
|
||||
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data.equals(msg)) {
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
serverMessageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
}
|
||||
});
|
||||
serverChannelRef.set(ch);
|
||||
serverChannelLatch.countDown();
|
||||
}
|
||||
});
|
||||
|
||||
Channel sc = null;
|
||||
Channel cc = null;
|
||||
try {
|
||||
// Start server
|
||||
sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();
|
||||
|
||||
// Connect to the server
|
||||
cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();
|
||||
assertTrue(serverChannelLatch.await(5, SECONDS));
|
||||
|
||||
final Channel ccCpy = cc;
|
||||
final Channel serverChannelCpy = serverChannelRef.get();
|
||||
serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch);
|
||||
ccCpy.closeFuture().addListener(clientChannelCloseLatch);
|
||||
|
||||
// Make sure a write operation is executed in the eventloop
|
||||
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise())
|
||||
.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
serverChannelCpy.eventLoop().execute(new OneTimeTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
// The point of this test is to write while the peer is closed, so we should
|
||||
// ensure the peer is actually closed before we write.
|
||||
int waitCount = 0;
|
||||
while (ccCpy.isOpen()) {
|
||||
try {
|
||||
Thread.sleep(50);
|
||||
} catch (InterruptedException ignored) {
|
||||
// ignored
|
||||
}
|
||||
if (++waitCount > 5) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
|
||||
serverChannelCpy.newPromise())
|
||||
.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
if (!future.isSuccess() &&
|
||||
future.cause() instanceof ClosedChannelException) {
|
||||
writeFailLatch.countDown();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
ccCpy.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
assertTrue(serverMessageLatch.await(5, SECONDS));
|
||||
assertTrue(writeFailLatch.await(5, SECONDS));
|
||||
assertTrue(serverChannelCloseLatch.await(5, SECONDS));
|
||||
assertTrue(clientChannelCloseLatch.await(5, SECONDS));
|
||||
assertFalse(ccCpy.isOpen());
|
||||
assertFalse(serverChannelCpy.isOpen());
|
||||
} finally {
|
||||
closeChannel(cc);
|
||||
closeChannel(sc);
|
||||
}
|
||||
} finally {
|
||||
data.release();
|
||||
data2.release();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener {
|
||||
public LatchChannelFutureListener(int count) {
|
||||
super(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
countDown();
|
||||
}
|
||||
}
|
||||
|
||||
private static void closeChannel(Channel cc) {
|
||||
if (cc != null) {
|
||||
cc.close().syncUninterruptibly();
|
||||
@ -708,6 +832,7 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
logger.info(String.format("Received mesage: %s", msg));
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user