LocalChannel write when peer closed leak
Motivation: If LocalChannel doWrite executes while the peer's state changes from CONNECTED to CLOSED it is possible that some promise's won't be completed and buffers will be leaked. Modifications: - Check the peer's state in doWrite to avoid a race condition Result: All write operations should release, and the associated promise should be completed.
This commit is contained in:
parent
fd70dd658e
commit
71308376ca
@ -27,8 +27,9 @@ import io.netty.channel.DefaultChannelConfig;
|
||||
import io.netty.channel.EventLoop;
|
||||
import io.netty.channel.SingleThreadEventLoop;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.concurrent.SingleThreadEventExecutor;
|
||||
import io.netty.util.concurrent.Future;
|
||||
import io.netty.util.concurrent.SingleThreadEventExecutor;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import io.netty.util.internal.InternalThreadLocalMap;
|
||||
import io.netty.util.internal.OneTimeTask;
|
||||
import io.netty.util.internal.PlatformDependent;
|
||||
@ -50,9 +51,10 @@ public class LocalChannel extends AbstractChannel {
|
||||
private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
|
||||
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
|
||||
private static final int MAX_READER_STACK_DEPTH = 8;
|
||||
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
|
||||
|
||||
private final ChannelConfig config = new DefaultChannelConfig(this);
|
||||
// To futher optimize this we could write our own SPSC queue.
|
||||
// To further optimize this we could write our own SPSC queue.
|
||||
private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
|
||||
private final Runnable readTask = new Runnable() {
|
||||
@Override
|
||||
@ -94,6 +96,7 @@ public class LocalChannel extends AbstractChannel {
|
||||
AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
|
||||
}
|
||||
FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
|
||||
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
|
||||
}
|
||||
|
||||
public LocalChannel() {
|
||||
@ -216,10 +219,6 @@ public class LocalChannel extends AbstractChannel {
|
||||
protected void doClose() throws Exception {
|
||||
final LocalChannel peer = this.peer;
|
||||
if (state <= 2) {
|
||||
// To preserve ordering of events we must process any pending reads
|
||||
if (writeInProgress && peer != null) {
|
||||
finishPeerRead(peer);
|
||||
}
|
||||
// Update all internal state before the closeFuture is notified.
|
||||
if (localAddress != null) {
|
||||
if (parent() == null) {
|
||||
@ -227,7 +226,22 @@ public class LocalChannel extends AbstractChannel {
|
||||
}
|
||||
localAddress = null;
|
||||
}
|
||||
|
||||
// State change must happen before finishPeerRead to ensure writes are released either in doWrite or
|
||||
// channelRead.
|
||||
state = 3;
|
||||
|
||||
ChannelPromise promise = connectPromise;
|
||||
if (promise != null) {
|
||||
// Use tryFailure() instead of setFailure() to avoid the race against cancel().
|
||||
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
|
||||
connectPromise = null;
|
||||
}
|
||||
|
||||
// To preserve ordering of events we must process any pending reads
|
||||
if (writeInProgress && peer != null) {
|
||||
finishPeerRead(peer);
|
||||
}
|
||||
}
|
||||
|
||||
if (peer != null && peer.isActive()) {
|
||||
@ -239,12 +253,18 @@ public class LocalChannel extends AbstractChannel {
|
||||
} else {
|
||||
// This value may change, and so we should save it before executing the Runnable.
|
||||
final boolean peerWriteInProgress = peer.writeInProgress;
|
||||
try {
|
||||
peer.eventLoop().execute(new OneTimeTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
doPeerClose(peer, peerWriteInProgress);
|
||||
}
|
||||
});
|
||||
} catch (RuntimeException e) {
|
||||
// The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained.
|
||||
releaseInboundBuffers();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
this.peer = null;
|
||||
}
|
||||
@ -293,7 +313,12 @@ public class LocalChannel extends AbstractChannel {
|
||||
threadLocals.setLocalChannelReaderStackDepth(stackDepth);
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
eventLoop().execute(readTask);
|
||||
} catch (RuntimeException e) {
|
||||
releaseInboundBuffers();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -303,7 +328,7 @@ public class LocalChannel extends AbstractChannel {
|
||||
throw new NotYetConnectedException();
|
||||
}
|
||||
if (state > 2) {
|
||||
throw new ClosedChannelException();
|
||||
throw CLOSED_CHANNEL_EXCEPTION;
|
||||
}
|
||||
|
||||
final LocalChannel peer = this.peer;
|
||||
@ -316,8 +341,14 @@ public class LocalChannel extends AbstractChannel {
|
||||
break;
|
||||
}
|
||||
try {
|
||||
// It is possible the peer could have closed while we are writing, and in this case we should
|
||||
// simulate real socket behavior and ensure the write operation is failed.
|
||||
if (peer.state == 2) {
|
||||
peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
|
||||
in.remove();
|
||||
} else {
|
||||
in.remove(CLOSED_CHANNEL_EXCEPTION);
|
||||
}
|
||||
} catch (Throwable cause) {
|
||||
in.remove(cause);
|
||||
}
|
||||
@ -352,11 +383,26 @@ public class LocalChannel extends AbstractChannel {
|
||||
finishPeerRead0(peer);
|
||||
}
|
||||
};
|
||||
try {
|
||||
if (peer.writeInProgress) {
|
||||
peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
|
||||
} else {
|
||||
peer.eventLoop().execute(finishPeerReadTask);
|
||||
}
|
||||
} catch (RuntimeException e) {
|
||||
peer.releaseInboundBuffers();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private void releaseInboundBuffers() {
|
||||
for (;;) {
|
||||
Object o = inboundBuffer.poll();
|
||||
if (o == null) {
|
||||
break;
|
||||
}
|
||||
ReferenceCountUtil.release(o);
|
||||
}
|
||||
}
|
||||
|
||||
private void finishPeerRead0(LocalChannel peer) {
|
||||
|
@ -339,8 +339,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg.equals(data)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -408,8 +408,8 @@ public class LocalChannelTest {
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
final long count = messageLatch.getCount();
|
||||
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -468,8 +468,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data2.equals(msg)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -485,8 +485,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data.equals(msg)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -550,8 +550,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data2.equals(msg) && messageLatch.getCount() == 1) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -567,8 +567,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data.equals(msg) && messageLatch.getCount() == 2) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -641,8 +641,8 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg.equals(data)) {
|
||||
messageLatch.countDown();
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
messageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
@ -697,6 +697,130 @@ public class LocalChannelTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException {
|
||||
Bootstrap cb = new Bootstrap();
|
||||
ServerBootstrap sb = new ServerBootstrap();
|
||||
final CountDownLatch serverMessageLatch = new CountDownLatch(1);
|
||||
final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1);
|
||||
final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1);
|
||||
final CountDownLatch writeFailLatch = new CountDownLatch(1);
|
||||
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
|
||||
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
|
||||
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
|
||||
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
|
||||
|
||||
try {
|
||||
cb.group(group1)
|
||||
.channel(LocalChannel.class)
|
||||
.handler(new TestHandler());
|
||||
|
||||
sb.group(group2)
|
||||
.channel(LocalServerChannel.class)
|
||||
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||
@Override
|
||||
public void initChannel(LocalChannel ch) throws Exception {
|
||||
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (data.equals(msg)) {
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
serverMessageLatch.countDown();
|
||||
} else {
|
||||
super.channelRead(ctx, msg);
|
||||
}
|
||||
}
|
||||
});
|
||||
serverChannelRef.set(ch);
|
||||
serverChannelLatch.countDown();
|
||||
}
|
||||
});
|
||||
|
||||
Channel sc = null;
|
||||
Channel cc = null;
|
||||
try {
|
||||
// Start server
|
||||
sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();
|
||||
|
||||
// Connect to the server
|
||||
cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();
|
||||
assertTrue(serverChannelLatch.await(5, SECONDS));
|
||||
|
||||
final Channel ccCpy = cc;
|
||||
final Channel serverChannelCpy = serverChannelRef.get();
|
||||
serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch);
|
||||
ccCpy.closeFuture().addListener(clientChannelCloseLatch);
|
||||
|
||||
// Make sure a write operation is executed in the eventloop
|
||||
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise())
|
||||
.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
serverChannelCpy.eventLoop().execute(new OneTimeTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
// The point of this test is to write while the peer is closed, so we should
|
||||
// ensure the peer is actually closed before we write.
|
||||
int waitCount = 0;
|
||||
while (ccCpy.isOpen()) {
|
||||
try {
|
||||
Thread.sleep(50);
|
||||
} catch (InterruptedException ignored) {
|
||||
// ignored
|
||||
}
|
||||
if (++waitCount > 5) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
|
||||
serverChannelCpy.newPromise())
|
||||
.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
if (!future.isSuccess() &&
|
||||
future.cause() instanceof ClosedChannelException) {
|
||||
writeFailLatch.countDown();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
ccCpy.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
assertTrue(serverMessageLatch.await(5, SECONDS));
|
||||
assertTrue(writeFailLatch.await(5, SECONDS));
|
||||
assertTrue(serverChannelCloseLatch.await(5, SECONDS));
|
||||
assertTrue(clientChannelCloseLatch.await(5, SECONDS));
|
||||
assertFalse(ccCpy.isOpen());
|
||||
assertFalse(serverChannelCpy.isOpen());
|
||||
} finally {
|
||||
closeChannel(cc);
|
||||
closeChannel(sc);
|
||||
}
|
||||
} finally {
|
||||
data.release();
|
||||
data2.release();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener {
|
||||
public LatchChannelFutureListener(int count) {
|
||||
super(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
countDown();
|
||||
}
|
||||
}
|
||||
|
||||
private static void closeChannel(Channel cc) {
|
||||
if (cc != null) {
|
||||
cc.close().syncUninterruptibly();
|
||||
@ -707,6 +831,7 @@ public class LocalChannelTest {
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
logger.info(String.format("Received mesage: %s", msg));
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user