Ensure LocalChannel fire channelActive after peers's channelRegistered

- Also:
  - Made the test case more robust
  - Added a simple concurrent buffer modification test (needs more work)
This commit is contained in:
Trustin Lee 2012-06-03 12:54:26 -07:00
parent 361cb417e0
commit 234c4c70db
6 changed files with 196 additions and 64 deletions

View File

@ -424,10 +424,13 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
}
try {
doRegister();
Runnable postRegisterTask = doRegister();
registered = true;
future.setSuccess();
pipeline.fireChannelRegistered();
if (postRegisterTask != null) {
postRegisterTask.run();
}
if (isActive()) {
pipeline.fireChannelActive();
}
@ -687,7 +690,7 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
protected abstract SocketAddress localAddress0();
protected abstract SocketAddress remoteAddress0();
protected abstract void doRegister() throws Exception;
protected abstract Runnable doRegister() throws Exception;
protected abstract void doBind(SocketAddress localAddress) throws Exception;
protected abstract void doDisconnect() throws Exception;
protected abstract void doClose() throws Exception;

View File

@ -124,22 +124,39 @@ public class LocalChannel extends AbstractChannel {
}
@Override
protected void doRegister() throws Exception {
protected Runnable doRegister() throws Exception {
final LocalChannel peer = this.peer;
Runnable postRegisterTask;
if (peer != null) {
state = 2;
peer.remoteAddress = parent().localAddress();
peer.state = 2;
peer.eventLoop().execute(new Runnable() {
// Ensure the peer's channelActive event is triggered *after* this channel's
// channelRegistered event is triggered, so that this channel's pipeline is fully
// initialized by ChannelInitializer.
final EventLoop peerEventLoop = peer.eventLoop();
postRegisterTask = new Runnable() {
@Override
public void run() {
peer.connectFuture.setSuccess();
peer.pipeline().fireChannelActive();
peerEventLoop.execute(new Runnable() {
@Override
public void run() {
peer.connectFuture.setSuccess();
peer.pipeline().fireChannelActive();
}
});
}
});
};
} else {
postRegisterTask = null;
}
((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook);
return postRegisterTask;
}
@Override

View File

@ -84,8 +84,9 @@ public class LocalServerChannel extends AbstractServerChannel {
}
@Override
protected void doRegister() throws Exception {
protected Runnable doRegister() throws Exception {
((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook);
return null;
}
@Override

View File

@ -206,10 +206,11 @@ public abstract class AbstractNioChannel extends AbstractChannel {
}
@Override
protected void doRegister() throws Exception {
protected Runnable doRegister() throws Exception {
NioChildEventLoop loop = (NioChildEventLoop) eventLoop();
selectionKey = javaChannel().register(
loop.selector, isActive()? defaultInterestOps : 0, this);
return null;
}
@Override

View File

@ -74,8 +74,9 @@ abstract class AbstractOioChannel extends AbstractChannel {
}
@Override
protected void doRegister() throws Exception {
protected Runnable doRegister() throws Exception {
// NOOP
return null;
}
@Override

View File

@ -6,6 +6,7 @@ import io.netty.channel.ChannelBufferHolder;
import io.netty.channel.ChannelBufferHolders;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.ChannelInitializer;
@ -13,13 +14,16 @@ import io.netty.channel.ChannelOutboundHandlerContext;
import io.netty.channel.DefaultEventExecutor;
import io.netty.channel.EventExecutor;
import io.netty.channel.EventLoop;
import io.netty.util.internal.QueueFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.AfterClass;
import org.junit.Assert;
@ -59,13 +63,13 @@ public class LocalTransportThreadModelTest {
}
@Test
public void testSimple() throws Exception {
public void testStagedExecution() throws Throwable {
EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l"));
EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1"));
EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2"));
TestHandler h1 = new TestHandler();
TestHandler h2 = new TestHandler();
TestHandler h3 = new TestHandler();
ThreadNameAuditor h1 = new ThreadNameAuditor();
ThreadNameAuditor h2 = new ThreadNameAuditor();
ThreadNameAuditor h3 = new ThreadNameAuditor();
Channel ch = new LocalChannel();
// With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'.
@ -90,63 +94,105 @@ public class LocalTransportThreadModelTest {
String currentName = Thread.currentThread().getName();
// Events should never be handled from the current thread.
Assert.assertFalse(h1.inboundThreadNames.contains(currentName));
Assert.assertFalse(h2.inboundThreadNames.contains(currentName));
Assert.assertFalse(h3.inboundThreadNames.contains(currentName));
Assert.assertFalse(h1.outboundThreadNames.contains(currentName));
Assert.assertFalse(h2.outboundThreadNames.contains(currentName));
Assert.assertFalse(h3.outboundThreadNames.contains(currentName));
try {
// Events should never be handled from the current thread.
Assert.assertFalse(h1.inboundThreadNames.contains(currentName));
Assert.assertFalse(h2.inboundThreadNames.contains(currentName));
Assert.assertFalse(h3.inboundThreadNames.contains(currentName));
Assert.assertFalse(h1.outboundThreadNames.contains(currentName));
Assert.assertFalse(h2.outboundThreadNames.contains(currentName));
Assert.assertFalse(h3.outboundThreadNames.contains(currentName));
// Assert that events were handled by the correct executor.
for (String name: h1.inboundThreadNames) {
Assert.assertTrue(name.startsWith("l-"));
}
for (String name: h2.inboundThreadNames) {
Assert.assertTrue(name.startsWith("e1-"));
}
for (String name: h3.inboundThreadNames) {
Assert.assertTrue(name.startsWith("e2-"));
}
for (String name: h1.outboundThreadNames) {
Assert.assertTrue(name.startsWith("l-"));
}
for (String name: h2.outboundThreadNames) {
Assert.assertTrue(name.startsWith("e1-"));
}
for (String name: h3.outboundThreadNames) {
Assert.assertTrue(name.startsWith("e2-"));
}
// Assert that events were handled by the correct executor.
for (String name: h1.inboundThreadNames) {
Assert.assertTrue(name.startsWith("l-"));
}
for (String name: h2.inboundThreadNames) {
Assert.assertTrue(name.startsWith("e1-"));
}
for (String name: h3.inboundThreadNames) {
Assert.assertTrue(name.startsWith("e2-"));
}
for (String name: h1.outboundThreadNames) {
Assert.assertTrue(name.startsWith("l-"));
}
for (String name: h2.outboundThreadNames) {
Assert.assertTrue(name.startsWith("e1-"));
}
for (String name: h3.outboundThreadNames) {
Assert.assertTrue(name.startsWith("e2-"));
}
// Assert that the events for the same handler were handled by the same thread.
Set<String> names = new HashSet<String>();
names.addAll(h1.inboundThreadNames);
names.addAll(h1.outboundThreadNames);
Assert.assertEquals(1, names.size());
// Assert that the events for the same handler were handled by the same thread.
Set<String> names = new HashSet<String>();
names.addAll(h1.inboundThreadNames);
names.addAll(h1.outboundThreadNames);
Assert.assertEquals(1, names.size());
names.clear();
names.addAll(h2.inboundThreadNames);
names.addAll(h2.outboundThreadNames);
Assert.assertEquals(1, names.size());
names.clear();
names.addAll(h2.inboundThreadNames);
names.addAll(h2.outboundThreadNames);
Assert.assertEquals(1, names.size());
names.clear();
names.addAll(h3.inboundThreadNames);
names.addAll(h3.outboundThreadNames);
Assert.assertEquals(1, names.size());
names.clear();
names.addAll(h3.inboundThreadNames);
names.addAll(h3.outboundThreadNames);
Assert.assertEquals(1, names.size());
// Count the number of events
Assert.assertEquals(1, h1.inboundThreadNames.size());
Assert.assertEquals(2, h2.inboundThreadNames.size());
Assert.assertEquals(3, h3.inboundThreadNames.size());
Assert.assertEquals(3, h1.outboundThreadNames.size());
Assert.assertEquals(2, h2.outboundThreadNames.size());
Assert.assertEquals(1, h3.outboundThreadNames.size());
// Count the number of events
Assert.assertEquals(1, h1.inboundThreadNames.size());
Assert.assertEquals(2, h2.inboundThreadNames.size());
Assert.assertEquals(3, h3.inboundThreadNames.size());
Assert.assertEquals(3, h1.outboundThreadNames.size());
Assert.assertEquals(2, h2.outboundThreadNames.size());
Assert.assertEquals(1, h3.outboundThreadNames.size());
if (h1.exception.get() != null) {
throw h1.exception.get();
}
if (h2.exception.get() != null) {
throw h2.exception.get();
}
if (h3.exception.get() != null) {
throw h3.exception.get();
}
} catch (AssertionError e) {
System.out.println("H1I: " + h1.inboundThreadNames);
System.out.println("H2I: " + h2.inboundThreadNames);
System.out.println("H3I: " + h3.inboundThreadNames);
System.out.println("H1O: " + h1.outboundThreadNames);
System.out.println("H2O: " + h2.outboundThreadNames);
System.out.println("H3O: " + h3.outboundThreadNames);
throw e;
}
}
private static class TestHandler extends ChannelHandlerAdapter<Object, Object> {
@Test
public void testConcurrentMessageBufferAccess() throws Exception {
EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l"));
EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1"));
EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2"));
MessageForwarder h1 = new MessageForwarder();
MessageForwarder h2 = new MessageForwarder();
MessageDiscarder h3 = new MessageDiscarder();
private final Queue<String> inboundThreadNames = QueueFactory.createQueue();
private final Queue<String> outboundThreadNames = QueueFactory.createQueue();
Channel ch = new LocalChannel();
ch.pipeline().addLast(h1).addLast(e1, h2).addLast(e2, h3);
l.register(ch).sync().channel().connect(ADDR).sync();
for (int i = 0; i < 10000; i ++) {
ch.pipeline().inboundMessageBuffer().add(Integer.valueOf(i));
ch.pipeline().fireInboundBufferUpdated();
}
}
private static class ThreadNameAuditor extends ChannelHandlerAdapter<Object, Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private final List<String> inboundThreadNames = Collections.synchronizedList(new ArrayList<String>());
private final List<String> outboundThreadNames = Collections.synchronizedList(new ArrayList<String>());
@Override
public ChannelBufferHolder<Object> newInboundBuffer(
@ -175,6 +221,69 @@ public class LocalTransportThreadModelTest {
outboundThreadNames.add(Thread.currentThread().getName());
ctx.flush(future);
}
@Override
public void exceptionCaught(ChannelInboundHandlerContext<Object> ctx,
Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
System.err.print("[" + Thread.currentThread().getName() + "] ");
cause.printStackTrace();
super.exceptionCaught(ctx, cause);
}
}
private static class MessageForwarder extends ChannelInboundMessageHandlerAdapter<Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private int counter;
@Override
public void messageReceived(ChannelInboundHandlerContext<Object> ctx,
Object msg) throws Exception {
Assert.assertEquals(counter ++, msg);
ctx.nextInboundMessageBuffer().add(msg);
}
@Override
public void exceptionCaught(ChannelInboundHandlerContext<Object> ctx,
Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
System.err.print("[" + Thread.currentThread().getName() + "] ");
cause.printStackTrace();
super.exceptionCaught(ctx, cause);
}
}
private static class MessageDiscarder extends ChannelInboundHandlerAdapter<Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private int counter;
@Override
public ChannelBufferHolder<Object> newInboundBuffer(
ChannelInboundHandlerContext<Object> ctx) throws Exception {
return ChannelBufferHolders.messageBuffer();
}
@Override
public void inboundBufferUpdated(
ChannelInboundHandlerContext<Object> ctx) throws Exception {
Queue<Object> in = ctx.inbound().messageBuffer();
for (;;) {
Object msg = in.poll();
Assert.assertEquals(counter ++, msg);
}
}
@Override
public void exceptionCaught(ChannelInboundHandlerContext<Object> ctx,
Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
System.err.print("[" + Thread.currentThread().getName() + "] ");
cause.printStackTrace();
super.exceptionCaught(ctx, cause);
}
}
private static class PrefixThreadFactory implements ThreadFactory {