LocalChannel Event Ordering Error

Motivation:
When a LocalChannel write operation occurs, the promise associated with the write is marked successful as soon as the message is added to the peer's inbound queue, but before the peer has actually received the data. If a listener on that promise closes the channel, a race condition exists in which the peer may observe the close event before the data is delivered, even though the write happened first. Ordering of events should be preserved.
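To illustrate, the racy pattern looks roughly like this (a minimal sketch against the Netty 4.x API; the helper class and method names are hypothetical, not part of this commit):

    import io.netty.channel.Channel;
    import io.netty.channel.ChannelFuture;
    import io.netty.channel.ChannelFutureListener;

    final class WriteThenClose {
        // Writes 'msg' and closes 'ch' from the write-promise listener. Before
        // this fix, the listener could run as soon as the message was enqueued
        // on the peer's inboundBuffer, so the peer might observe channelInactive
        // before channelRead for 'msg'.
        static void writeThenClose(final Channel ch, Object msg) {
            ch.writeAndFlush(msg).addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) {
                    ch.close();
                }
            });
        }
    }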

Modifications:
- LocalChannel should track when a write is in progress and, if a close operation happens while a write is pending, make sure the peer receives all pending read operations before the close is processed (sketched below).
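In outline, the mechanism looks like this (an illustrative simplification of the diff below, not the exact committed code; transferMessagesToPeer() is a hypothetical stand-in for the in-buffer drain loop):

    // In doWrite(): mark the write while messages are moved to the peer's
    // inboundBuffer. Completing write promises here may run listeners that
    // call close(), which is exactly the racy case.
    writeInProgress = true;
    try {
        transferMessagesToPeer(); // hypothetical stand-in for the drain loop
    } finally {
        writeInProgress = false;
    }
    finishPeerRead(peer);

    // In doClose(): before tearing the channel down, deliver anything this
    // channel has already handed to the peer, so reads are never reordered
    // after the close.
    if (writeInProgress && peer != null) {
        finishPeerRead(peer);
    }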

Result:
LocalChannel preserves order of operations.
Fixes https://github.com/netty/netty/issues/4118
Scott Mitchell 2015-08-20 19:19:51 -07:00
parent 48662bf41d
commit e37069b947
2 changed files with 582 additions and 126 deletions

transport/src/main/java/io/netty/channel/local/LocalChannel.java

@@ -81,6 +81,7 @@ public class LocalChannel extends AbstractChannel {
     private volatile ChannelPromise connectPromise;
     private volatile boolean readInProgress;
     private volatile boolean registerInProgress;
+    private volatile boolean writeInProgress;

     public LocalChannel() {
         super(null);
@@ -173,7 +174,7 @@
             // This ensures that if both channels are on the same event loop, the peer's channelActive
             // event is triggered *after* this channel's channelRegistered event, so that this channel's
             // pipeline is fully initialized by ChannelInitializer before any channelRead events.
-            peer.eventLoop().execute(new Runnable() {
+            peer.eventLoop().execute(new OneTimeTask() {
                 @Override
                 public void run() {
                     registerInProgress = false;
@@ -200,7 +201,12 @@

     @Override
     protected void doClose() throws Exception {
+        final LocalChannel peer = this.peer;
         if (state != State.CLOSED) {
+            // To preserve ordering of events we must process any pending reads
+            if (writeInProgress && peer != null) {
+                finishPeerRead(peer);
+            }
             // Update all internal state before the closeFuture is notified.
             if (localAddress != null) {
                 if (parent() == null) {
@@ -211,23 +217,19 @@
             state = State.CLOSED;
         }

-        final LocalChannel peer = this.peer;
         if (peer != null && peer.isActive()) {
-            // Need to execute the close in the correct EventLoop
-            // See https://github.com/netty/netty/issues/1777
-            EventLoop eventLoop = peer.eventLoop();
-
+            // Need to execute the close in the correct EventLoop (see https://github.com/netty/netty/issues/1777).
             // Also check if the registration was not done yet. In this case we submit the close to the EventLoop
-            // to make sure it is run after the registration completes.
-            //
-            // See https://github.com/netty/netty/issues/2144
-            if (eventLoop.inEventLoop() && !registerInProgress) {
-                peer.unsafe().close(unsafe().voidPromise());
+            // to make sure it's run after the registration completes (see https://github.com/netty/netty/issues/2144).
+            if (peer.eventLoop().inEventLoop() && !registerInProgress) {
+                doPeerClose(peer, peer.writeInProgress);
             } else {
-                peer.eventLoop().execute(new Runnable() {
+                // This value may change, and so we should save it before executing the Runnable.
+                final boolean peerWriteInProgress = peer.writeInProgress;
+                peer.eventLoop().execute(new OneTimeTask() {
                     @Override
                     public void run() {
-                        peer.unsafe().close(unsafe().voidPromise());
+                        doPeerClose(peer, peerWriteInProgress);
                     }
                 });
             }
@@ -235,6 +237,13 @@
         }
     }

+    private void doPeerClose(LocalChannel peer, boolean peerWriteInProgress) {
+        if (peerWriteInProgress) {
+            finishPeerRead0(this);
+        }
+        peer.unsafe().close(peer.unsafe().voidPromise());
+    }
+
     @Override
     protected void doDeregister() throws Exception {
         // Just remove the shutdownHook as this Channel may be closed later or registered to another EventLoop
@@ -288,35 +297,49 @@
         }

         final LocalChannel peer = this.peer;
-        final ChannelPipeline peerPipeline = peer.pipeline();
-        final EventLoop peerLoop = peer.eventLoop();

-        for (;;) {
-            Object msg = in.current();
-            if (msg == null) {
-                break;
-            }
-            try {
-                peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
-                in.remove();
-            } catch (Throwable cause) {
-                in.remove(cause);
-            }
-        }
+        writeInProgress = true;
+        try {
+            for (;;) {
+                Object msg = in.current();
+                if (msg == null) {
+                    break;
+                }
+                try {
+                    peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
+                    in.remove();
+                } catch (Throwable cause) {
+                    in.remove(cause);
+                }
+            }
+        } finally {
+            // The following situation may cause trouble:
+            // 1. Write (with promise X)
+            // 2. promise X is completed when in.remove() is called, and a listener on this promise calls close()
+            // 3. Then the close event will be executed for the peer before the write events, when the write events
+            // actually happened before the close event.
+            writeInProgress = false;
+        }

-        if (peerLoop == eventLoop()) {
-            finishPeerRead(peer, peerPipeline);
+        finishPeerRead(peer);
+    }
+
+    private void finishPeerRead(final LocalChannel peer) {
+        // If the peer is also writing, then we must schedule the event on the event loop to preserve read order.
+        if (peer.eventLoop() == eventLoop() && !peer.writeInProgress) {
+            finishPeerRead0(peer);
         } else {
-            peerLoop.execute(new OneTimeTask() {
+            peer.eventLoop().execute(new OneTimeTask() {
                 @Override
                 public void run() {
-                    finishPeerRead(peer, peerPipeline);
+                    finishPeerRead0(peer);
                 }
             });
         }
     }

-    private static void finishPeerRead(LocalChannel peer, ChannelPipeline peerPipeline) {
+    private static void finishPeerRead0(LocalChannel peer) {
+        ChannelPipeline peerPipeline = peer.pipeline();
         if (peer.readInProgress) {
             peer.readInProgress = false;
             for (;;) {

transport/src/test/java/io/netty/channel/local/LocalChannelTest.java

@@ -17,28 +17,43 @@ package io.netty.channel.local;

 import io.netty.bootstrap.Bootstrap;
 import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.AbstractChannel;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPromise;
 import io.netty.channel.DefaultEventLoopGroup;
 import io.netty.channel.EventLoop;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.SimpleChannelInboundHandler;
 import io.netty.channel.SingleThreadEventLoop;
+import io.netty.util.ReferenceCountUtil;
 import io.netty.util.concurrent.Future;
+import io.netty.util.internal.OneTimeTask;
 import io.netty.util.internal.logging.InternalLogger;
 import io.netty.util.internal.logging.InternalLoggerFactory;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 import org.junit.Test;

 import java.nio.channels.ClosedChannelException;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;

-import static org.hamcrest.CoreMatchers.*;
-import static org.junit.Assert.*;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;

 public class LocalChannelTest {
@@ -46,20 +61,39 @@ public class LocalChannelTest {

     private static final String LOCAL_ADDR_ID = "test.id";

+    private static EventLoopGroup group1;
+    private static EventLoopGroup group2;
+    private static EventLoopGroup sharedGroup;
+
+    @BeforeClass
+    public static void beforeClass() {
+        group1 = new DefaultEventLoopGroup(2);
+        group2 = new DefaultEventLoopGroup(2);
+        sharedGroup = new DefaultEventLoopGroup(1);
+    }
+
+    @AfterClass
+    public static void afterClass() throws InterruptedException {
+        Future<?> group1Future = group1.shutdownGracefully(0, 0, SECONDS);
+        Future<?> group2Future = group2.shutdownGracefully(0, 0, SECONDS);
+        Future<?> sharedGroupFuture = sharedGroup.shutdownGracefully(0, 0, SECONDS);
+        group1Future.await();
+        group2Future.await();
+        sharedGroupFuture.await();
+    }
+
     @Test
     public void testLocalAddressReuse() throws Exception {
         for (int i = 0; i < 2; i ++) {
-            EventLoopGroup clientGroup = new DefaultEventLoopGroup();
-            EventLoopGroup serverGroup = new DefaultEventLoopGroup();
             LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
             Bootstrap cb = new Bootstrap();
             ServerBootstrap sb = new ServerBootstrap();

-            cb.group(clientGroup)
+            cb.group(group1)
               .channel(LocalChannel.class)
               .handler(new TestHandler());

-            sb.group(serverGroup)
+            sb.group(group2)
               .channel(LocalServerChannel.class)
               .childHandler(new ChannelInitializer<LocalChannel>() {
                   @Override
@@ -68,52 +102,52 @@
                 }
             });

-            // Start server
-            Channel sc = sb.bind(addr).sync().channel();
-
-            final CountDownLatch latch = new CountDownLatch(1);
-            // Connect to the server
-            final Channel cc = cb.connect(addr).sync().channel();
-            cc.eventLoop().execute(new Runnable() {
-                @Override
-                public void run() {
-                    // Send a message event up the pipeline.
-                    cc.pipeline().fireChannelRead("Hello, World");
-                    latch.countDown();
-                }
-            });
-            latch.await();
-
-            // Close the channel
-            cc.close().sync();
-
-            serverGroup.shutdownGracefully();
-            clientGroup.shutdownGracefully();
-
-            sc.closeFuture().sync();
-
-            assertNull(String.format(
-                    "Expected null, got channel '%s' for local address '%s'",
-                    LocalChannelRegistry.get(addr), addr), LocalChannelRegistry.get(addr));
-
-            serverGroup.terminationFuture().sync();
-            clientGroup.terminationFuture().sync();
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).sync().channel();
+
+                final CountDownLatch latch = new CountDownLatch(1);
+                // Connect to the server
+                cc = cb.connect(addr).sync().channel();
+                final Channel ccCpy = cc;
+                cc.eventLoop().execute(new Runnable() {
+                    @Override
+                    public void run() {
+                        // Send a message event up the pipeline.
+                        ccCpy.pipeline().fireChannelRead("Hello, World");
+                        latch.countDown();
+                    }
+                });
+                assertTrue(latch.await(5, SECONDS));
+
+                // Close the channel
+                closeChannel(cc);
+                closeChannel(sc);
+                sc.closeFuture().sync();
+
+                assertNull(String.format(
+                        "Expected null, got channel '%s' for local address '%s'",
+                        LocalChannelRegistry.get(addr), addr), LocalChannelRegistry.get(addr));
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
         }
     }

     @Test
     public void testWriteFailsFastOnClosedChannel() throws Exception {
-        EventLoopGroup clientGroup = new DefaultEventLoopGroup();
-        EventLoopGroup serverGroup = new DefaultEventLoopGroup();
         LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
         Bootstrap cb = new Bootstrap();
         ServerBootstrap sb = new ServerBootstrap();

-        cb.group(clientGroup)
+        cb.group(group1)
           .channel(LocalChannel.class)
           .handler(new TestHandler());

-        sb.group(serverGroup)
+        sb.group(group2)
           .channel(LocalServerChannel.class)
           .childHandler(new ChannelInitializer<LocalChannel>() {
               @Override
@@ -122,41 +156,43 @@
             }
         });

-        // Start server
-        sb.bind(addr).sync();
-
-        // Connect to the server
-        final Channel cc = cb.connect(addr).sync().channel();
-
-        // Close the channel and write something.
-        cc.close().sync();
+        Channel sc = null;
+        Channel cc = null;
         try {
-            cc.writeAndFlush(new Object()).sync();
-            fail("must raise a ClosedChannelException");
-        } catch (Exception e) {
-            assertThat(e, is(instanceOf(ClosedChannelException.class)));
-            // Ensure that the actual write attempt on a closed channel was never made by asserting that
-            // the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations.
-            if (e.getStackTrace().length > 0) {
-                assertThat(
-                        e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() + "$AbstractUnsafe"));
-                e.printStackTrace();
-            }
-        }
-
-        serverGroup.shutdownGracefully();
-        clientGroup.shutdownGracefully();
-        serverGroup.terminationFuture().sync();
-        clientGroup.terminationFuture().sync();
+            // Start server
+            sc = sb.bind(addr).sync().channel();
+
+            // Connect to the server
+            cc = cb.connect(addr).sync().channel();
+
+            // Close the channel and write something.
+            cc.close().sync();
+            try {
+                cc.writeAndFlush(new Object()).sync();
+                fail("must raise a ClosedChannelException");
+            } catch (Exception e) {
+                assertThat(e, is(instanceOf(ClosedChannelException.class)));
+                // Ensure that the actual write attempt on a closed channel was never made by asserting that
+                // the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations.
+                if (e.getStackTrace().length > 0) {
+                    assertThat(
+                            e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() +
+                                    "$AbstractUnsafe"));
+                    e.printStackTrace();
+                }
+            }
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+        }
     }

     @Test
     public void testServerCloseChannelSameEventLoop() throws Exception {
         LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
-        EventLoopGroup group = new DefaultEventLoopGroup(1);
         final CountDownLatch latch = new CountDownLatch(1);
         ServerBootstrap sb = new ServerBootstrap()
-                .group(group)
+                .group(group2)
                 .channel(LocalServerChannel.class)
                 .childHandler(new SimpleChannelInboundHandler<Object>() {
                     @Override
@@ -165,29 +201,33 @@
                         latch.countDown();
                     }
                 });
-        sb.bind(addr).sync();
-
-        Bootstrap b = new Bootstrap()
-                .group(group)
-                .channel(LocalChannel.class)
-                .handler(new SimpleChannelInboundHandler<Object>() {
-                    @Override
-                    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
-                        // discard
-                    }
-                });
-        Channel channel = b.connect(addr).sync().channel();
-        channel.writeAndFlush(new Object());
-        latch.await();
-        group.shutdownGracefully();
-        group.terminationFuture().sync();
+        Channel sc = null;
+        Channel cc = null;
+        try {
+            sc = sb.bind(addr).sync().channel();
+
+            Bootstrap b = new Bootstrap()
+                    .group(group2)
+                    .channel(LocalChannel.class)
+                    .handler(new SimpleChannelInboundHandler<Object>() {
+                        @Override
+                        protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
+                            // discard
+                        }
+                    });
+            cc = b.connect(addr).sync().channel();
+            cc.writeAndFlush(new Object());
+            assertTrue(latch.await(5, SECONDS));
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+        }
     }

     @Test
     public void localChannelRaceCondition() throws Exception {
-        final LocalAddress address = new LocalAddress("test");
+        final LocalAddress address = new LocalAddress(LOCAL_ADDR_ID);
         final CountDownLatch closeLatch = new CountDownLatch(1);
-        final EventLoopGroup serverGroup = new DefaultEventLoopGroup(1);
         final EventLoopGroup clientGroup = new DefaultEventLoopGroup(1) {
             @Override
             protected EventLoop newChild(Executor threadFactory, Object... args)
@@ -218,9 +258,11 @@
                 };
             }
         };
+        Channel sc = null;
+        Channel cc = null;
         try {
             ServerBootstrap sb = new ServerBootstrap();
-            sb.group(serverGroup).
+            sc = sb.group(group2).
                     channel(LocalServerChannel.class).
                     childHandler(new ChannelInitializer<Channel>() {
                         @Override
@@ -230,7 +272,7 @@
                         }
                     }).
                     bind(address).
-                    sync();
+                    sync().channel();
             Bootstrap bootstrap = new Bootstrap();
             bootstrap.group(clientGroup).
                     channel(LocalChannel.class).
@@ -242,16 +284,16 @@
                     });
             ChannelFuture future = bootstrap.connect(address);
             assertTrue("Connection should finish, not time out", future.await(200));
+            cc = future.channel();
         } finally {
-            serverGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await();
-            clientGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await();
+            closeChannel(cc);
+            closeChannel(sc);
+            clientGroup.shutdownGracefully(0, 0, SECONDS).await();
         }
     }

     @Test
     public void testReRegister() {
-        EventLoopGroup group1 = new DefaultEventLoopGroup();
-        EventLoopGroup group2 = new DefaultEventLoopGroup();
         LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
         Bootstrap cb = new Bootstrap();
         ServerBootstrap sb = new ServerBootstrap();
@@ -269,18 +311,409 @@
             }
         });

-        // Start server
-        final Channel sc = sb.bind(addr).syncUninterruptibly().channel();
-
-        // Connect to the server
-        final Channel cc = cb.connect(addr).syncUninterruptibly().channel();
-
-        cc.deregister().syncUninterruptibly();
-        // Change event loop group.
-        group2.register(cc).syncUninterruptibly();
-        cc.close().syncUninterruptibly();
-        sc.close().syncUninterruptibly();
+        Channel sc = null;
+        Channel cc = null;
+        try {
+            // Start server
+            sc = sb.bind(addr).syncUninterruptibly().channel();
+
+            // Connect to the server
+            cc = cb.connect(addr).syncUninterruptibly().channel();
+
+            cc.deregister().syncUninterruptibly();
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+        }
+    }
+
+    @Test
+    public void testCloseInWritePromiseCompletePreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+
+        try {
+            cb.group(group1)
+              .channel(LocalChannel.class)
+              .handler(new TestHandler());
+
+            sb.group(group2)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInboundHandlerAdapter() {
+                  @Override
+                  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                      if (msg.equals(data)) {
+                          messageLatch.countDown();
+                          ReferenceCountUtil.safeRelease(msg);
+                      } else {
+                          super.channelRead(ctx, msg);
+                      }
+                  }
+
+                  @Override
+                  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+                      messageLatch.countDown();
+                      super.channelInactive(ctx);
+                  }
+              });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the eventloop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                ccCpy.pipeline().lastContext().close();
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+                assertFalse(cc.isOpen());
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+        }
+    }
+
+    @Test
+    public void testWriteInWritePromiseCompletePreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+
+        try {
+            cb.group(group1)
+              .channel(LocalChannel.class)
+              .handler(new TestHandler());
+
+            sb.group(group2)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInboundHandlerAdapter() {
+                  @Override
+                  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                      final long count = messageLatch.getCount();
+                      if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
+                          messageLatch.countDown();
+                          ReferenceCountUtil.safeRelease(msg);
+                      } else {
+                          super.channelRead(ctx, msg);
+                      }
+                  }
+              });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the eventloop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                ccCpy.writeAndFlush(data2.duplicate().retain(), ccCpy.newPromise());
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+            data2.release();
+        }
+    }
+
+    @Test
+    public void testPeerWriteInWritePromiseCompleteDifferentEventLoopPreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        cb.group(group1)
+          .channel(LocalChannel.class)
+          .handler(new ChannelInboundHandlerAdapter() {
+              @Override
+              public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                  if (data2.equals(msg)) {
+                      messageLatch.countDown();
+                      ReferenceCountUtil.safeRelease(msg);
+                  } else {
+                      super.channelRead(ctx, msg);
+                  }
+              }
+          });
+
+        sb.group(group2)
+          .channel(LocalServerChannel.class)
+          .childHandler(new ChannelInitializer<LocalChannel>() {
+              @Override
+              public void initChannel(LocalChannel ch) throws Exception {
+                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                      @Override
+                      public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                          if (data.equals(msg)) {
+                              messageLatch.countDown();
+                              ReferenceCountUtil.safeRelease(msg);
+                          } else {
+                              super.channelRead(ctx, msg);
+                          }
+                      }
+                  });
+                  serverChannelRef.set(ch);
+                  serverChannelLatch.countDown();
+              }
+          });
+
+        Channel sc = null;
+        Channel cc = null;
+        try {
+            // Start server
+            sc = sb.bind(addr).syncUninterruptibly().channel();
+
+            // Connect to the server
+            cc = cb.connect(addr).syncUninterruptibly().channel();
+            assertTrue(serverChannelLatch.await(5, SECONDS));
+
+            final Channel ccCpy = cc;
+            // Make sure a write operation is executed in the eventloop
+            cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                @Override
+                public void run() {
+                    ChannelPromise promise = ccCpy.newPromise();
+                    promise.addListener(new ChannelFutureListener() {
+                        @Override
+                        public void operationComplete(ChannelFuture future) throws Exception {
+                            Channel serverChannelCpy = serverChannelRef.get();
+                            serverChannelCpy.writeAndFlush(data2.duplicate().retain(), serverChannelCpy.newPromise());
+                        }
+                    });
+                    ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                }
+            });
+
+            assertTrue(messageLatch.await(5, SECONDS));
+        } finally {
+            closeChannel(cc);
+            closeChannel(sc);
+            data.release();
+            data2.release();
+        }
+    }
+
+    @Test
+    public void testPeerWriteInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        try {
+            cb.group(sharedGroup)
+              .channel(LocalChannel.class)
+              .handler(new ChannelInboundHandlerAdapter() {
+                  @Override
+                  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                      if (data2.equals(msg) && messageLatch.getCount() == 1) {
+                          messageLatch.countDown();
+                          ReferenceCountUtil.safeRelease(msg);
+                      } else {
+                          super.channelRead(ctx, msg);
+                      }
+                  }
+              });
+
+            sb.group(sharedGroup)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInitializer<LocalChannel>() {
+                  @Override
+                  public void initChannel(LocalChannel ch) throws Exception {
+                      ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                          @Override
+                          public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                              if (data.equals(msg) && messageLatch.getCount() == 2) {
+                                  messageLatch.countDown();
+                                  ReferenceCountUtil.safeRelease(msg);
+                              } else {
+                                  super.channelRead(ctx, msg);
+                              }
+                          }
+                      });
+                      serverChannelRef.set(ch);
+                      serverChannelLatch.countDown();
+                  }
+              });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+                assertTrue(serverChannelLatch.await(5, SECONDS));
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the eventloop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                Channel serverChannelCpy = serverChannelRef.get();
+                                serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
+                                        serverChannelCpy.newPromise());
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+            data2.release();
+        }
+    }
+
+    @Test
+    public void testClosePeerInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException {
+        LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch messageLatch = new CountDownLatch(2);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        try {
+            cb.group(sharedGroup)
+              .channel(LocalChannel.class)
+              .handler(new TestHandler());
+
+            sb.group(sharedGroup)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInitializer<LocalChannel>() {
+                  @Override
+                  public void initChannel(LocalChannel ch) throws Exception {
+                      ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                          @Override
+                          public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                              if (msg.equals(data)) {
+                                  messageLatch.countDown();
+                                  ReferenceCountUtil.safeRelease(msg);
+                              } else {
+                                  super.channelRead(ctx, msg);
+                              }
+                          }
+
+                          @Override
+                          public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+                              messageLatch.countDown();
+                              super.channelInactive(ctx);
+                          }
+                      });
+                      serverChannelRef.set(ch);
+                      serverChannelLatch.countDown();
+                  }
+              });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(addr).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(addr).syncUninterruptibly().channel();
+                assertTrue(serverChannelLatch.await(5, SECONDS));
+
+                final Channel ccCpy = cc;
+                // Make sure a write operation is executed in the eventloop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ChannelPromise promise = ccCpy.newPromise();
+                        promise.addListener(new ChannelFutureListener() {
+                            @Override
+                            public void operationComplete(ChannelFuture future) throws Exception {
+                                serverChannelRef.get().close();
+                            }
+                        });
+                        ccCpy.writeAndFlush(data.duplicate().retain(), promise);
+                    }
+                });
+
+                assertTrue(messageLatch.await(5, SECONDS));
+                assertFalse(cc.isOpen());
+                assertFalse(serverChannelRef.get().isOpen());
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+        }
+    }
+
+    private static void closeChannel(Channel cc) {
+        if (cc != null) {
+            cc.close().syncUninterruptibly();
+        }
+    }
+
     static class TestHandler extends ChannelInboundHandlerAdapter {
         @Override
         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {