LocalChannel write when peer closed leak
Motivation: If LocalChannel doWrite executes while the peer's state changes from CONNECTED to CLOSED it is possible that some promise's won't be completed and buffers will be leaked. Modifications: - Check the peer's state in doWrite to avoid a race condition Result: All write operations should release, and the associated promise should be completed.
This commit is contained in:
parent
fd70dd658e
commit
71308376ca
@ -27,8 +27,9 @@ import io.netty.channel.DefaultChannelConfig;
|
|||||||
import io.netty.channel.EventLoop;
|
import io.netty.channel.EventLoop;
|
||||||
import io.netty.channel.SingleThreadEventLoop;
|
import io.netty.channel.SingleThreadEventLoop;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import io.netty.util.concurrent.SingleThreadEventExecutor;
|
|
||||||
import io.netty.util.concurrent.Future;
|
import io.netty.util.concurrent.Future;
|
||||||
|
import io.netty.util.concurrent.SingleThreadEventExecutor;
|
||||||
|
import io.netty.util.internal.EmptyArrays;
|
||||||
import io.netty.util.internal.InternalThreadLocalMap;
|
import io.netty.util.internal.InternalThreadLocalMap;
|
||||||
import io.netty.util.internal.OneTimeTask;
|
import io.netty.util.internal.OneTimeTask;
|
||||||
import io.netty.util.internal.PlatformDependent;
|
import io.netty.util.internal.PlatformDependent;
|
||||||
@ -50,9 +51,10 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
|
private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
|
||||||
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
|
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
|
||||||
private static final int MAX_READER_STACK_DEPTH = 8;
|
private static final int MAX_READER_STACK_DEPTH = 8;
|
||||||
|
private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
|
||||||
|
|
||||||
private final ChannelConfig config = new DefaultChannelConfig(this);
|
private final ChannelConfig config = new DefaultChannelConfig(this);
|
||||||
// To futher optimize this we could write our own SPSC queue.
|
// To further optimize this we could write our own SPSC queue.
|
||||||
private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
|
private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
|
||||||
private final Runnable readTask = new Runnable() {
|
private final Runnable readTask = new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
@ -94,6 +96,7 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
|
AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
|
||||||
}
|
}
|
||||||
FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
|
FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
|
||||||
|
CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LocalChannel() {
|
public LocalChannel() {
|
||||||
@ -216,10 +219,6 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
protected void doClose() throws Exception {
|
protected void doClose() throws Exception {
|
||||||
final LocalChannel peer = this.peer;
|
final LocalChannel peer = this.peer;
|
||||||
if (state <= 2) {
|
if (state <= 2) {
|
||||||
// To preserve ordering of events we must process any pending reads
|
|
||||||
if (writeInProgress && peer != null) {
|
|
||||||
finishPeerRead(peer);
|
|
||||||
}
|
|
||||||
// Update all internal state before the closeFuture is notified.
|
// Update all internal state before the closeFuture is notified.
|
||||||
if (localAddress != null) {
|
if (localAddress != null) {
|
||||||
if (parent() == null) {
|
if (parent() == null) {
|
||||||
@ -227,7 +226,22 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
}
|
}
|
||||||
localAddress = null;
|
localAddress = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// State change must happen before finishPeerRead to ensure writes are released either in doWrite or
|
||||||
|
// channelRead.
|
||||||
state = 3;
|
state = 3;
|
||||||
|
|
||||||
|
ChannelPromise promise = connectPromise;
|
||||||
|
if (promise != null) {
|
||||||
|
// Use tryFailure() instead of setFailure() to avoid the race against cancel().
|
||||||
|
promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
|
||||||
|
connectPromise = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// To preserve ordering of events we must process any pending reads
|
||||||
|
if (writeInProgress && peer != null) {
|
||||||
|
finishPeerRead(peer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (peer != null && peer.isActive()) {
|
if (peer != null && peer.isActive()) {
|
||||||
@ -239,12 +253,18 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
} else {
|
} else {
|
||||||
// This value may change, and so we should save it before executing the Runnable.
|
// This value may change, and so we should save it before executing the Runnable.
|
||||||
final boolean peerWriteInProgress = peer.writeInProgress;
|
final boolean peerWriteInProgress = peer.writeInProgress;
|
||||||
peer.eventLoop().execute(new OneTimeTask() {
|
try {
|
||||||
@Override
|
peer.eventLoop().execute(new OneTimeTask() {
|
||||||
public void run() {
|
@Override
|
||||||
doPeerClose(peer, peerWriteInProgress);
|
public void run() {
|
||||||
}
|
doPeerClose(peer, peerWriteInProgress);
|
||||||
});
|
}
|
||||||
|
});
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
// The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained.
|
||||||
|
releaseInboundBuffers();
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
this.peer = null;
|
this.peer = null;
|
||||||
}
|
}
|
||||||
@ -293,7 +313,12 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
threadLocals.setLocalChannelReaderStackDepth(stackDepth);
|
threadLocals.setLocalChannelReaderStackDepth(stackDepth);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
eventLoop().execute(readTask);
|
try {
|
||||||
|
eventLoop().execute(readTask);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
releaseInboundBuffers();
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -303,7 +328,7 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
throw new NotYetConnectedException();
|
throw new NotYetConnectedException();
|
||||||
}
|
}
|
||||||
if (state > 2) {
|
if (state > 2) {
|
||||||
throw new ClosedChannelException();
|
throw CLOSED_CHANNEL_EXCEPTION;
|
||||||
}
|
}
|
||||||
|
|
||||||
final LocalChannel peer = this.peer;
|
final LocalChannel peer = this.peer;
|
||||||
@ -316,8 +341,14 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
|
// It is possible the peer could have closed while we are writing, and in this case we should
|
||||||
in.remove();
|
// simulate real socket behavior and ensure the write operation is failed.
|
||||||
|
if (peer.state == 2) {
|
||||||
|
peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
|
||||||
|
in.remove();
|
||||||
|
} else {
|
||||||
|
in.remove(CLOSED_CHANNEL_EXCEPTION);
|
||||||
|
}
|
||||||
} catch (Throwable cause) {
|
} catch (Throwable cause) {
|
||||||
in.remove(cause);
|
in.remove(cause);
|
||||||
}
|
}
|
||||||
@ -352,10 +383,25 @@ public class LocalChannel extends AbstractChannel {
|
|||||||
finishPeerRead0(peer);
|
finishPeerRead0(peer);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if (peer.writeInProgress) {
|
try {
|
||||||
peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
|
if (peer.writeInProgress) {
|
||||||
} else {
|
peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
|
||||||
peer.eventLoop().execute(finishPeerReadTask);
|
} else {
|
||||||
|
peer.eventLoop().execute(finishPeerReadTask);
|
||||||
|
}
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
peer.releaseInboundBuffers();
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void releaseInboundBuffers() {
|
||||||
|
for (;;) {
|
||||||
|
Object o = inboundBuffer.poll();
|
||||||
|
if (o == null) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
ReferenceCountUtil.release(o);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -339,8 +339,8 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (msg.equals(data)) {
|
if (msg.equals(data)) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -408,8 +408,8 @@ public class LocalChannelTest {
|
|||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
final long count = messageLatch.getCount();
|
final long count = messageLatch.getCount();
|
||||||
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
|
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -468,8 +468,8 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (data2.equals(msg)) {
|
if (data2.equals(msg)) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -485,8 +485,8 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (data.equals(msg)) {
|
if (data.equals(msg)) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -550,8 +550,8 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (data2.equals(msg) && messageLatch.getCount() == 1) {
|
if (data2.equals(msg) && messageLatch.getCount() == 1) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -567,8 +567,8 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (data.equals(msg) && messageLatch.getCount() == 2) {
|
if (data.equals(msg) && messageLatch.getCount() == 2) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -641,8 +641,8 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
if (msg.equals(data)) {
|
if (msg.equals(data)) {
|
||||||
messageLatch.countDown();
|
|
||||||
ReferenceCountUtil.safeRelease(msg);
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
messageLatch.countDown();
|
||||||
} else {
|
} else {
|
||||||
super.channelRead(ctx, msg);
|
super.channelRead(ctx, msg);
|
||||||
}
|
}
|
||||||
@ -697,6 +697,130 @@ public class LocalChannelTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException {
|
||||||
|
Bootstrap cb = new Bootstrap();
|
||||||
|
ServerBootstrap sb = new ServerBootstrap();
|
||||||
|
final CountDownLatch serverMessageLatch = new CountDownLatch(1);
|
||||||
|
final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1);
|
||||||
|
final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1);
|
||||||
|
final CountDownLatch writeFailLatch = new CountDownLatch(1);
|
||||||
|
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
|
||||||
|
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
|
||||||
|
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
|
||||||
|
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
|
||||||
|
|
||||||
|
try {
|
||||||
|
cb.group(group1)
|
||||||
|
.channel(LocalChannel.class)
|
||||||
|
.handler(new TestHandler());
|
||||||
|
|
||||||
|
sb.group(group2)
|
||||||
|
.channel(LocalServerChannel.class)
|
||||||
|
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||||
|
@Override
|
||||||
|
public void initChannel(LocalChannel ch) throws Exception {
|
||||||
|
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
|
if (data.equals(msg)) {
|
||||||
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
|
serverMessageLatch.countDown();
|
||||||
|
} else {
|
||||||
|
super.channelRead(ctx, msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
serverChannelRef.set(ch);
|
||||||
|
serverChannelLatch.countDown();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Channel sc = null;
|
||||||
|
Channel cc = null;
|
||||||
|
try {
|
||||||
|
// Start server
|
||||||
|
sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();
|
||||||
|
|
||||||
|
// Connect to the server
|
||||||
|
cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();
|
||||||
|
assertTrue(serverChannelLatch.await(5, SECONDS));
|
||||||
|
|
||||||
|
final Channel ccCpy = cc;
|
||||||
|
final Channel serverChannelCpy = serverChannelRef.get();
|
||||||
|
serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch);
|
||||||
|
ccCpy.closeFuture().addListener(clientChannelCloseLatch);
|
||||||
|
|
||||||
|
// Make sure a write operation is executed in the eventloop
|
||||||
|
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise())
|
||||||
|
.addListener(new ChannelFutureListener() {
|
||||||
|
@Override
|
||||||
|
public void operationComplete(ChannelFuture future) throws Exception {
|
||||||
|
serverChannelCpy.eventLoop().execute(new OneTimeTask() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
// The point of this test is to write while the peer is closed, so we should
|
||||||
|
// ensure the peer is actually closed before we write.
|
||||||
|
int waitCount = 0;
|
||||||
|
while (ccCpy.isOpen()) {
|
||||||
|
try {
|
||||||
|
Thread.sleep(50);
|
||||||
|
} catch (InterruptedException ignored) {
|
||||||
|
// ignored
|
||||||
|
}
|
||||||
|
if (++waitCount > 5) {
|
||||||
|
fail();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
|
||||||
|
serverChannelCpy.newPromise())
|
||||||
|
.addListener(new ChannelFutureListener() {
|
||||||
|
@Override
|
||||||
|
public void operationComplete(ChannelFuture future) throws Exception {
|
||||||
|
if (!future.isSuccess() &&
|
||||||
|
future.cause() instanceof ClosedChannelException) {
|
||||||
|
writeFailLatch.countDown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
ccCpy.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
assertTrue(serverMessageLatch.await(5, SECONDS));
|
||||||
|
assertTrue(writeFailLatch.await(5, SECONDS));
|
||||||
|
assertTrue(serverChannelCloseLatch.await(5, SECONDS));
|
||||||
|
assertTrue(clientChannelCloseLatch.await(5, SECONDS));
|
||||||
|
assertFalse(ccCpy.isOpen());
|
||||||
|
assertFalse(serverChannelCpy.isOpen());
|
||||||
|
} finally {
|
||||||
|
closeChannel(cc);
|
||||||
|
closeChannel(sc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
data.release();
|
||||||
|
data2.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener {
|
||||||
|
public LatchChannelFutureListener(int count) {
|
||||||
|
super(count);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void operationComplete(ChannelFuture future) throws Exception {
|
||||||
|
countDown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static void closeChannel(Channel cc) {
|
private static void closeChannel(Channel cc) {
|
||||||
if (cc != null) {
|
if (cc != null) {
|
||||||
cc.close().syncUninterruptibly();
|
cc.close().syncUninterruptibly();
|
||||||
@ -707,6 +831,7 @@ public class LocalChannelTest {
|
|||||||
@Override
|
@Override
|
||||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||||
logger.info(String.format("Received mesage: %s", msg));
|
logger.info(String.format("Received mesage: %s", msg));
|
||||||
|
ReferenceCountUtil.safeRelease(msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user