LocalChannel Event Ordering Error

Motivation:
When a LocalChannel write operation occurs, the promise associated with the write operation is marked successful when the data is added to the peer's queue, but before the peer has actually received the data. If the promise callback closes the channel, a race condition exists where the close event may be processed before the data is delivered. We should preserve the ordering of events.

Modifications:
- LocalChannel should track when a write is in progress, and if a close operation happens, make sure the peer gets all pending read operations.

Result:
LocalChannel preserves the order of operations.
Fixes https://github.com/netty/netty/issues/4118
This commit is contained in:
Scott Mitchell 2015-08-20 19:19:51 -07:00
parent 48662bf41d
commit e37069b947
2 changed files with 582 additions and 126 deletions

View File

@ -81,6 +81,7 @@ public class LocalChannel extends AbstractChannel {
private volatile ChannelPromise connectPromise; private volatile ChannelPromise connectPromise;
private volatile boolean readInProgress; private volatile boolean readInProgress;
private volatile boolean registerInProgress; private volatile boolean registerInProgress;
private volatile boolean writeInProgress;
public LocalChannel() { public LocalChannel() {
super(null); super(null);
@ -173,7 +174,7 @@ public class LocalChannel extends AbstractChannel {
// This ensures that if both channels are on the same event loop, the peer's channelActive // This ensures that if both channels are on the same event loop, the peer's channelActive
// event is triggered *after* this channel's channelRegistered event, so that this channel's // event is triggered *after* this channel's channelRegistered event, so that this channel's
// pipeline is fully initialized by ChannelInitializer before any channelRead events. // pipeline is fully initialized by ChannelInitializer before any channelRead events.
peer.eventLoop().execute(new Runnable() { peer.eventLoop().execute(new OneTimeTask() {
@Override @Override
public void run() { public void run() {
registerInProgress = false; registerInProgress = false;
@ -200,7 +201,12 @@ public class LocalChannel extends AbstractChannel {
@Override @Override
protected void doClose() throws Exception { protected void doClose() throws Exception {
final LocalChannel peer = this.peer;
if (state != State.CLOSED) { if (state != State.CLOSED) {
// To preserve ordering of events we must process any pending reads
if (writeInProgress && peer != null) {
finishPeerRead(peer);
}
// Update all internal state before the closeFuture is notified. // Update all internal state before the closeFuture is notified.
if (localAddress != null) { if (localAddress != null) {
if (parent() == null) { if (parent() == null) {
@ -211,23 +217,19 @@ public class LocalChannel extends AbstractChannel {
state = State.CLOSED; state = State.CLOSED;
} }
final LocalChannel peer = this.peer;
if (peer != null && peer.isActive()) { if (peer != null && peer.isActive()) {
// Need to execute the close in the correct EventLoop // Need to execute the close in the correct EventLoop (see https://github.com/netty/netty/issues/1777).
// See https://github.com/netty/netty/issues/1777
EventLoop eventLoop = peer.eventLoop();
// Also check if the registration was not done yet. In this case we submit the close to the EventLoop // Also check if the registration was not done yet. In this case we submit the close to the EventLoop
// to make sure it is run after the registration completes. // to make sure its run after the registration completes (see https://github.com/netty/netty/issues/2144).
// if (peer.eventLoop().inEventLoop() && !registerInProgress) {
// See https://github.com/netty/netty/issues/2144 doPeerClose(peer, peer.writeInProgress);
if (eventLoop.inEventLoop() && !registerInProgress) {
peer.unsafe().close(unsafe().voidPromise());
} else { } else {
peer.eventLoop().execute(new Runnable() { // This value may change, and so we should save it before executing the Runnable.
final boolean peerWriteInProgress = peer.writeInProgress;
peer.eventLoop().execute(new OneTimeTask() {
@Override @Override
public void run() { public void run() {
peer.unsafe().close(unsafe().voidPromise()); doPeerClose(peer, peerWriteInProgress);
} }
}); });
} }
@ -235,6 +237,13 @@ public class LocalChannel extends AbstractChannel {
} }
} }
private void doPeerClose(LocalChannel peer, boolean peerWriteInProgress) {
if (peerWriteInProgress) {
finishPeerRead0(this);
}
peer.unsafe().close(peer.unsafe().voidPromise());
}
@Override @Override
protected void doDeregister() throws Exception { protected void doDeregister() throws Exception {
// Just remove the shutdownHook as this Channel may be closed later or registered to another EventLoop // Just remove the shutdownHook as this Channel may be closed later or registered to another EventLoop
@ -288,9 +297,9 @@ public class LocalChannel extends AbstractChannel {
} }
final LocalChannel peer = this.peer; final LocalChannel peer = this.peer;
final ChannelPipeline peerPipeline = peer.pipeline();
final EventLoop peerLoop = peer.eventLoop();
writeInProgress = true;
try {
for (;;) { for (;;) {
Object msg = in.current(); Object msg = in.current();
if (msg == null) { if (msg == null) {
@ -303,20 +312,34 @@ public class LocalChannel extends AbstractChannel {
in.remove(cause); in.remove(cause);
} }
} }
} finally {
// The following situation may cause trouble:
// 1. Write (with promise X)
// 2. promise X is completed when in.remove() is called, and a listener on this promise calls close()
// 3. Then the close event will be executed for the peer before the write events, when the write events
// actually happened before the close event.
writeInProgress = false;
}
if (peerLoop == eventLoop()) { finishPeerRead(peer);
finishPeerRead(peer, peerPipeline); }
private void finishPeerRead(final LocalChannel peer) {
// If the peer is also writing, then we must schedule the event on the event loop to preserve read order.
if (peer.eventLoop() == eventLoop() && !peer.writeInProgress) {
finishPeerRead0(peer);
} else { } else {
peerLoop.execute(new OneTimeTask() { peer.eventLoop().execute(new OneTimeTask() {
@Override @Override
public void run() { public void run() {
finishPeerRead(peer, peerPipeline); finishPeerRead0(peer);
} }
}); });
} }
} }
private static void finishPeerRead(LocalChannel peer, ChannelPipeline peerPipeline) { private static void finishPeerRead0(LocalChannel peer) {
ChannelPipeline peerPipeline = peer.pipeline();
if (peer.readInProgress) { if (peer.readInProgress) {
peer.readInProgress = false; peer.readInProgress = false;
for (;;) { for (;;) {

View File

@ -17,28 +17,43 @@ package io.netty.channel.local;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.AbstractChannel; import io.netty.channel.AbstractChannel;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoop; import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.SingleThreadEventLoop; import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.*; import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.*; import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class LocalChannelTest { public class LocalChannelTest {
@ -46,20 +61,39 @@ public class LocalChannelTest {
private static final String LOCAL_ADDR_ID = "test.id"; private static final String LOCAL_ADDR_ID = "test.id";
private static EventLoopGroup group1;
private static EventLoopGroup group2;
private static EventLoopGroup sharedGroup;
@BeforeClass
public static void beforeClass() {
group1 = new DefaultEventLoopGroup(2);
group2 = new DefaultEventLoopGroup(2);
sharedGroup = new DefaultEventLoopGroup(1);
}
@AfterClass
public static void afterClass() throws InterruptedException {
Future<?> group1Future = group1.shutdownGracefully(0, 0, SECONDS);
Future<?> group2Future = group2.shutdownGracefully(0, 0, SECONDS);
Future<?> sharedGroupFuture = sharedGroup.shutdownGracefully(0, 0, SECONDS);
group1Future.await();
group2Future.await();
sharedGroupFuture.await();
}
@Test @Test
public void testLocalAddressReuse() throws Exception { public void testLocalAddressReuse() throws Exception {
for (int i = 0; i < 2; i ++) { for (int i = 0; i < 2; i ++) {
EventLoopGroup clientGroup = new DefaultEventLoopGroup();
EventLoopGroup serverGroup = new DefaultEventLoopGroup();
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID); LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap(); Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap(); ServerBootstrap sb = new ServerBootstrap();
cb.group(clientGroup) cb.group(group1)
.channel(LocalChannel.class) .channel(LocalChannel.class)
.handler(new TestHandler()); .handler(new TestHandler());
sb.group(serverGroup) sb.group(group2)
.channel(LocalServerChannel.class) .channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() { .childHandler(new ChannelInitializer<LocalChannel>() {
@Override @Override
@ -68,52 +102,52 @@ public class LocalChannelTest {
} }
}); });
Channel sc = null;
Channel cc = null;
try {
// Start server // Start server
Channel sc = sb.bind(addr).sync().channel(); sc = sb.bind(addr).sync().channel();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
// Connect to the server // Connect to the server
final Channel cc = cb.connect(addr).sync().channel(); cc = cb.connect(addr).sync().channel();
final Channel ccCpy = cc;
cc.eventLoop().execute(new Runnable() { cc.eventLoop().execute(new Runnable() {
@Override @Override
public void run() { public void run() {
// Send a message event up the pipeline. // Send a message event up the pipeline.
cc.pipeline().fireChannelRead("Hello, World"); ccCpy.pipeline().fireChannelRead("Hello, World");
latch.countDown(); latch.countDown();
} }
}); });
latch.await(); assertTrue(latch.await(5, SECONDS));
// Close the channel // Close the channel
cc.close().sync(); closeChannel(cc);
closeChannel(sc);
serverGroup.shutdownGracefully();
clientGroup.shutdownGracefully();
sc.closeFuture().sync(); sc.closeFuture().sync();
assertNull(String.format( assertNull(String.format(
"Expected null, got channel '%s' for local address '%s'", "Expected null, got channel '%s' for local address '%s'",
LocalChannelRegistry.get(addr), addr), LocalChannelRegistry.get(addr)); LocalChannelRegistry.get(addr), addr), LocalChannelRegistry.get(addr));
} finally {
serverGroup.terminationFuture().sync(); closeChannel(cc);
clientGroup.terminationFuture().sync(); closeChannel(sc);
}
} }
} }
@Test @Test
public void testWriteFailsFastOnClosedChannel() throws Exception { public void testWriteFailsFastOnClosedChannel() throws Exception {
EventLoopGroup clientGroup = new DefaultEventLoopGroup();
EventLoopGroup serverGroup = new DefaultEventLoopGroup();
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID); LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap(); Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap(); ServerBootstrap sb = new ServerBootstrap();
cb.group(clientGroup) cb.group(group1)
.channel(LocalChannel.class) .channel(LocalChannel.class)
.handler(new TestHandler()); .handler(new TestHandler());
sb.group(serverGroup) sb.group(group2)
.channel(LocalServerChannel.class) .channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() { .childHandler(new ChannelInitializer<LocalChannel>() {
@Override @Override
@ -122,11 +156,14 @@ public class LocalChannelTest {
} }
}); });
Channel sc = null;
Channel cc = null;
try {
// Start server // Start server
sb.bind(addr).sync(); sc = sb.bind(addr).sync().channel();
// Connect to the server // Connect to the server
final Channel cc = cb.connect(addr).sync().channel(); cc = cb.connect(addr).sync().channel();
// Close the channel and write something. // Close the channel and write something.
cc.close().sync(); cc.close().sync();
@ -139,24 +176,23 @@ public class LocalChannelTest {
// the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations. // the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations.
if (e.getStackTrace().length > 0) { if (e.getStackTrace().length > 0) {
assertThat( assertThat(
e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() + "$AbstractUnsafe")); e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() +
"$AbstractUnsafe"));
e.printStackTrace(); e.printStackTrace();
} }
} }
} finally {
serverGroup.shutdownGracefully(); closeChannel(cc);
clientGroup.shutdownGracefully(); closeChannel(sc);
serverGroup.terminationFuture().sync(); }
clientGroup.terminationFuture().sync();
} }
@Test @Test
public void testServerCloseChannelSameEventLoop() throws Exception { public void testServerCloseChannelSameEventLoop() throws Exception {
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID); LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
EventLoopGroup group = new DefaultEventLoopGroup(1);
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
ServerBootstrap sb = new ServerBootstrap() ServerBootstrap sb = new ServerBootstrap()
.group(group) .group(group2)
.channel(LocalServerChannel.class) .channel(LocalServerChannel.class)
.childHandler(new SimpleChannelInboundHandler<Object>() { .childHandler(new SimpleChannelInboundHandler<Object>() {
@Override @Override
@ -165,10 +201,13 @@ public class LocalChannelTest {
latch.countDown(); latch.countDown();
} }
}); });
sb.bind(addr).sync(); Channel sc = null;
Channel cc = null;
try {
sc = sb.bind(addr).sync().channel();
Bootstrap b = new Bootstrap() Bootstrap b = new Bootstrap()
.group(group) .group(group2)
.channel(LocalChannel.class) .channel(LocalChannel.class)
.handler(new SimpleChannelInboundHandler<Object>() { .handler(new SimpleChannelInboundHandler<Object>() {
@Override @Override
@ -176,18 +215,19 @@ public class LocalChannelTest {
// discard // discard
} }
}); });
Channel channel = b.connect(addr).sync().channel(); cc = b.connect(addr).sync().channel();
channel.writeAndFlush(new Object()); cc.writeAndFlush(new Object());
latch.await(); assertTrue(latch.await(5, SECONDS));
group.shutdownGracefully(); } finally {
group.terminationFuture().sync(); closeChannel(cc);
closeChannel(sc);
}
} }
@Test @Test
public void localChannelRaceCondition() throws Exception { public void localChannelRaceCondition() throws Exception {
final LocalAddress address = new LocalAddress("test"); final LocalAddress address = new LocalAddress(LOCAL_ADDR_ID);
final CountDownLatch closeLatch = new CountDownLatch(1); final CountDownLatch closeLatch = new CountDownLatch(1);
final EventLoopGroup serverGroup = new DefaultEventLoopGroup(1);
final EventLoopGroup clientGroup = new DefaultEventLoopGroup(1) { final EventLoopGroup clientGroup = new DefaultEventLoopGroup(1) {
@Override @Override
protected EventLoop newChild(Executor threadFactory, Object... args) protected EventLoop newChild(Executor threadFactory, Object... args)
@ -218,9 +258,11 @@ public class LocalChannelTest {
}; };
} }
}; };
Channel sc = null;
Channel cc = null;
try { try {
ServerBootstrap sb = new ServerBootstrap(); ServerBootstrap sb = new ServerBootstrap();
sb.group(serverGroup). sc = sb.group(group2).
channel(LocalServerChannel.class). channel(LocalServerChannel.class).
childHandler(new ChannelInitializer<Channel>() { childHandler(new ChannelInitializer<Channel>() {
@Override @Override
@ -230,7 +272,7 @@ public class LocalChannelTest {
} }
}). }).
bind(address). bind(address).
sync(); sync().channel();
Bootstrap bootstrap = new Bootstrap(); Bootstrap bootstrap = new Bootstrap();
bootstrap.group(clientGroup). bootstrap.group(clientGroup).
channel(LocalChannel.class). channel(LocalChannel.class).
@ -242,16 +284,16 @@ public class LocalChannelTest {
}); });
ChannelFuture future = bootstrap.connect(address); ChannelFuture future = bootstrap.connect(address);
assertTrue("Connection should finish, not time out", future.await(200)); assertTrue("Connection should finish, not time out", future.await(200));
cc = future.channel();
} finally { } finally {
serverGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await(); closeChannel(cc);
clientGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).await(); closeChannel(sc);
clientGroup.shutdownGracefully(0, 0, SECONDS).await();
} }
} }
@Test @Test
public void testReRegister() { public void testReRegister() {
EventLoopGroup group1 = new DefaultEventLoopGroup();
EventLoopGroup group2 = new DefaultEventLoopGroup();
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID); LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap(); Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap(); ServerBootstrap sb = new ServerBootstrap();
@ -269,18 +311,409 @@ public class LocalChannelTest {
} }
}); });
Channel sc = null;
Channel cc = null;
try {
// Start server // Start server
final Channel sc = sb.bind(addr).syncUninterruptibly().channel(); sc = sb.bind(addr).syncUninterruptibly().channel();
// Connect to the server // Connect to the server
final Channel cc = cb.connect(addr).syncUninterruptibly().channel(); cc = cb.connect(addr).syncUninterruptibly().channel();
cc.deregister().syncUninterruptibly(); cc.deregister().syncUninterruptibly();
// Change event loop group. } finally {
group2.register(cc).syncUninterruptibly(); closeChannel(cc);
cc.close().syncUninterruptibly(); closeChannel(sc);
sc.close().syncUninterruptibly();
} }
}
@Test
public void testCloseInWritePromiseCompletePreservesOrder() throws InterruptedException {
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
final CountDownLatch messageLatch = new CountDownLatch(2);
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
try {
cb.group(group1)
.channel(LocalChannel.class)
.handler(new TestHandler());
sb.group(group2)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg.equals(data)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
messageLatch.countDown();
super.channelInactive(ctx);
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(addr).syncUninterruptibly().channel();
// Connect to the server
cc = cb.connect(addr).syncUninterruptibly().channel();
final Channel ccCpy = cc;
// Make sure a write operation is executed in the eventloop
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
@Override
public void run() {
ChannelPromise promise = ccCpy.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
ccCpy.pipeline().lastContext().close();
}
});
ccCpy.writeAndFlush(data.duplicate().retain(), promise);
}
});
assertTrue(messageLatch.await(5, SECONDS));
assertFalse(cc.isOpen());
} finally {
closeChannel(cc);
closeChannel(sc);
}
} finally {
data.release();
}
}
@Test
public void testWriteInWritePromiseCompletePreservesOrder() throws InterruptedException {
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
final CountDownLatch messageLatch = new CountDownLatch(2);
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
try {
cb.group(group1)
.channel(LocalChannel.class)
.handler(new TestHandler());
sb.group(group2)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
final long count = messageLatch.getCount();
if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(addr).syncUninterruptibly().channel();
// Connect to the server
cc = cb.connect(addr).syncUninterruptibly().channel();
final Channel ccCpy = cc;
// Make sure a write operation is executed in the eventloop
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
@Override
public void run() {
ChannelPromise promise = ccCpy.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
ccCpy.writeAndFlush(data2.duplicate().retain(), ccCpy.newPromise());
}
});
ccCpy.writeAndFlush(data.duplicate().retain(), promise);
}
});
assertTrue(messageLatch.await(5, SECONDS));
} finally {
closeChannel(cc);
closeChannel(sc);
}
} finally {
data.release();
data2.release();
}
}
@Test
public void testPeerWriteInWritePromiseCompleteDifferentEventLoopPreservesOrder() throws InterruptedException {
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
final CountDownLatch messageLatch = new CountDownLatch(2);
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
cb.group(group1)
.channel(LocalChannel.class)
.handler(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data2.equals(msg)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
});
sb.group(group2)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
});
serverChannelRef.set(ch);
serverChannelLatch.countDown();
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(addr).syncUninterruptibly().channel();
// Connect to the server
cc = cb.connect(addr).syncUninterruptibly().channel();
assertTrue(serverChannelLatch.await(5, SECONDS));
final Channel ccCpy = cc;
// Make sure a write operation is executed in the eventloop
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
@Override
public void run() {
ChannelPromise promise = ccCpy.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
Channel serverChannelCpy = serverChannelRef.get();
serverChannelCpy.writeAndFlush(data2.duplicate().retain(), serverChannelCpy.newPromise());
}
});
ccCpy.writeAndFlush(data.duplicate().retain(), promise);
}
});
assertTrue(messageLatch.await(5, SECONDS));
} finally {
closeChannel(cc);
closeChannel(sc);
data.release();
data2.release();
}
}
@Test
public void testPeerWriteInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException {
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
final CountDownLatch messageLatch = new CountDownLatch(2);
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
try {
cb.group(sharedGroup)
.channel(LocalChannel.class)
.handler(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data2.equals(msg) && messageLatch.getCount() == 1) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
});
sb.group(sharedGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (data.equals(msg) && messageLatch.getCount() == 2) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
});
serverChannelRef.set(ch);
serverChannelLatch.countDown();
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(addr).syncUninterruptibly().channel();
// Connect to the server
cc = cb.connect(addr).syncUninterruptibly().channel();
assertTrue(serverChannelLatch.await(5, SECONDS));
final Channel ccCpy = cc;
// Make sure a write operation is executed in the eventloop
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
@Override
public void run() {
ChannelPromise promise = ccCpy.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
Channel serverChannelCpy = serverChannelRef.get();
serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
serverChannelCpy.newPromise());
}
});
ccCpy.writeAndFlush(data.duplicate().retain(), promise);
}
});
assertTrue(messageLatch.await(5, SECONDS));
} finally {
closeChannel(cc);
closeChannel(sc);
}
} finally {
data.release();
data2.release();
}
}
@Test
public void testClosePeerInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException {
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
final CountDownLatch messageLatch = new CountDownLatch(2);
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
try {
cb.group(sharedGroup)
.channel(LocalChannel.class)
.handler(new TestHandler());
sb.group(sharedGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg.equals(data)) {
messageLatch.countDown();
ReferenceCountUtil.safeRelease(msg);
} else {
super.channelRead(ctx, msg);
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
messageLatch.countDown();
super.channelInactive(ctx);
}
});
serverChannelRef.set(ch);
serverChannelLatch.countDown();
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(addr).syncUninterruptibly().channel();
// Connect to the server
cc = cb.connect(addr).syncUninterruptibly().channel();
assertTrue(serverChannelLatch.await(5, SECONDS));
final Channel ccCpy = cc;
// Make sure a write operation is executed in the eventloop
cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
@Override
public void run() {
ChannelPromise promise = ccCpy.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
serverChannelRef.get().close();
}
});
ccCpy.writeAndFlush(data.duplicate().retain(), promise);
}
});
assertTrue(messageLatch.await(5, SECONDS));
assertFalse(cc.isOpen());
assertFalse(serverChannelRef.get().isOpen());
} finally {
closeChannel(cc);
closeChannel(sc);
}
} finally {
data.release();
}
}
private static void closeChannel(Channel cc) {
if (cc != null) {
cc.close().syncUninterruptibly();
}
}
static class TestHandler extends ChannelInboundHandlerAdapter { static class TestHandler extends ChannelInboundHandlerAdapter {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {