Ensure LocalChannel fire channelActive after peers's channelRegistered

- Also:
  - Made the test case more robust
  - Added a simple concurrent buffer modification test (needs more work)
This commit is contained in:
Trustin Lee 2012-06-03 12:54:26 -07:00
parent 361cb417e0
commit 234c4c70db
6 changed files with 196 additions and 64 deletions

View File

@ -424,10 +424,13 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
} }
try { try {
doRegister(); Runnable postRegisterTask = doRegister();
registered = true; registered = true;
future.setSuccess(); future.setSuccess();
pipeline.fireChannelRegistered(); pipeline.fireChannelRegistered();
if (postRegisterTask != null) {
postRegisterTask.run();
}
if (isActive()) { if (isActive()) {
pipeline.fireChannelActive(); pipeline.fireChannelActive();
} }
@ -687,7 +690,7 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
protected abstract SocketAddress localAddress0(); protected abstract SocketAddress localAddress0();
protected abstract SocketAddress remoteAddress0(); protected abstract SocketAddress remoteAddress0();
protected abstract void doRegister() throws Exception; protected abstract Runnable doRegister() throws Exception;
protected abstract void doBind(SocketAddress localAddress) throws Exception; protected abstract void doBind(SocketAddress localAddress) throws Exception;
protected abstract void doDisconnect() throws Exception; protected abstract void doDisconnect() throws Exception;
protected abstract void doClose() throws Exception; protected abstract void doClose() throws Exception;

View File

@ -124,22 +124,39 @@ public class LocalChannel extends AbstractChannel {
} }
@Override @Override
protected void doRegister() throws Exception { protected Runnable doRegister() throws Exception {
final LocalChannel peer = this.peer;
Runnable postRegisterTask;
if (peer != null) { if (peer != null) {
state = 2; state = 2;
peer.remoteAddress = parent().localAddress(); peer.remoteAddress = parent().localAddress();
peer.state = 2; peer.state = 2;
peer.eventLoop().execute(new Runnable() {
// Ensure the peer's channelActive event is triggered *after* this channel's
// channelRegistered event is triggered, so that this channel's pipeline is fully
// initialized by ChannelInitializer.
final EventLoop peerEventLoop = peer.eventLoop();
postRegisterTask = new Runnable() {
@Override @Override
public void run() { public void run() {
peer.connectFuture.setSuccess(); peerEventLoop.execute(new Runnable() {
peer.pipeline().fireChannelActive(); @Override
public void run() {
peer.connectFuture.setSuccess();
peer.pipeline().fireChannelActive();
}
});
} }
}); };
} else {
postRegisterTask = null;
} }
((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook); ((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook);
return postRegisterTask;
} }
@Override @Override

View File

@ -84,8 +84,9 @@ public class LocalServerChannel extends AbstractServerChannel {
} }
@Override @Override
protected void doRegister() throws Exception { protected Runnable doRegister() throws Exception {
((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook); ((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook);
return null;
} }
@Override @Override

View File

@ -206,10 +206,11 @@ public abstract class AbstractNioChannel extends AbstractChannel {
} }
@Override @Override
protected void doRegister() throws Exception { protected Runnable doRegister() throws Exception {
NioChildEventLoop loop = (NioChildEventLoop) eventLoop(); NioChildEventLoop loop = (NioChildEventLoop) eventLoop();
selectionKey = javaChannel().register( selectionKey = javaChannel().register(
loop.selector, isActive()? defaultInterestOps : 0, this); loop.selector, isActive()? defaultInterestOps : 0, this);
return null;
} }
@Override @Override

View File

@ -74,8 +74,9 @@ abstract class AbstractOioChannel extends AbstractChannel {
} }
@Override @Override
protected void doRegister() throws Exception { protected Runnable doRegister() throws Exception {
// NOOP // NOOP
return null;
} }
@Override @Override

View File

@ -6,6 +6,7 @@ import io.netty.channel.ChannelBufferHolder;
import io.netty.channel.ChannelBufferHolders; import io.netty.channel.ChannelBufferHolders;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInboundHandlerContext; import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
@ -13,13 +14,16 @@ import io.netty.channel.ChannelOutboundHandlerContext;
import io.netty.channel.DefaultEventExecutor; import io.netty.channel.DefaultEventExecutor;
import io.netty.channel.EventExecutor; import io.netty.channel.EventExecutor;
import io.netty.channel.EventLoop; import io.netty.channel.EventLoop;
import io.netty.util.internal.QueueFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Queue; import java.util.Queue;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.Assert; import org.junit.Assert;
@ -59,13 +63,13 @@ public class LocalTransportThreadModelTest {
} }
@Test @Test
public void testSimple() throws Exception { public void testStagedExecution() throws Throwable {
EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l")); EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l"));
EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1")); EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1"));
EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2")); EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2"));
TestHandler h1 = new TestHandler(); ThreadNameAuditor h1 = new ThreadNameAuditor();
TestHandler h2 = new TestHandler(); ThreadNameAuditor h2 = new ThreadNameAuditor();
TestHandler h3 = new TestHandler(); ThreadNameAuditor h3 = new ThreadNameAuditor();
Channel ch = new LocalChannel(); Channel ch = new LocalChannel();
// With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'. // With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'.
@ -90,63 +94,105 @@ public class LocalTransportThreadModelTest {
String currentName = Thread.currentThread().getName(); String currentName = Thread.currentThread().getName();
// Events should never be handled from the current thread. try {
Assert.assertFalse(h1.inboundThreadNames.contains(currentName)); // Events should never be handled from the current thread.
Assert.assertFalse(h2.inboundThreadNames.contains(currentName)); Assert.assertFalse(h1.inboundThreadNames.contains(currentName));
Assert.assertFalse(h3.inboundThreadNames.contains(currentName)); Assert.assertFalse(h2.inboundThreadNames.contains(currentName));
Assert.assertFalse(h1.outboundThreadNames.contains(currentName)); Assert.assertFalse(h3.inboundThreadNames.contains(currentName));
Assert.assertFalse(h2.outboundThreadNames.contains(currentName)); Assert.assertFalse(h1.outboundThreadNames.contains(currentName));
Assert.assertFalse(h3.outboundThreadNames.contains(currentName)); Assert.assertFalse(h2.outboundThreadNames.contains(currentName));
Assert.assertFalse(h3.outboundThreadNames.contains(currentName));
// Assert that events were handled by the correct executor. // Assert that events were handled by the correct executor.
for (String name: h1.inboundThreadNames) { for (String name: h1.inboundThreadNames) {
Assert.assertTrue(name.startsWith("l-")); Assert.assertTrue(name.startsWith("l-"));
} }
for (String name: h2.inboundThreadNames) { for (String name: h2.inboundThreadNames) {
Assert.assertTrue(name.startsWith("e1-")); Assert.assertTrue(name.startsWith("e1-"));
} }
for (String name: h3.inboundThreadNames) { for (String name: h3.inboundThreadNames) {
Assert.assertTrue(name.startsWith("e2-")); Assert.assertTrue(name.startsWith("e2-"));
} }
for (String name: h1.outboundThreadNames) { for (String name: h1.outboundThreadNames) {
Assert.assertTrue(name.startsWith("l-")); Assert.assertTrue(name.startsWith("l-"));
} }
for (String name: h2.outboundThreadNames) { for (String name: h2.outboundThreadNames) {
Assert.assertTrue(name.startsWith("e1-")); Assert.assertTrue(name.startsWith("e1-"));
} }
for (String name: h3.outboundThreadNames) { for (String name: h3.outboundThreadNames) {
Assert.assertTrue(name.startsWith("e2-")); Assert.assertTrue(name.startsWith("e2-"));
} }
// Assert that the events for the same handler were handled by the same thread. // Assert that the events for the same handler were handled by the same thread.
Set<String> names = new HashSet<String>(); Set<String> names = new HashSet<String>();
names.addAll(h1.inboundThreadNames); names.addAll(h1.inboundThreadNames);
names.addAll(h1.outboundThreadNames); names.addAll(h1.outboundThreadNames);
Assert.assertEquals(1, names.size()); Assert.assertEquals(1, names.size());
names.clear(); names.clear();
names.addAll(h2.inboundThreadNames); names.addAll(h2.inboundThreadNames);
names.addAll(h2.outboundThreadNames); names.addAll(h2.outboundThreadNames);
Assert.assertEquals(1, names.size()); Assert.assertEquals(1, names.size());
names.clear(); names.clear();
names.addAll(h3.inboundThreadNames); names.addAll(h3.inboundThreadNames);
names.addAll(h3.outboundThreadNames); names.addAll(h3.outboundThreadNames);
Assert.assertEquals(1, names.size()); Assert.assertEquals(1, names.size());
// Count the number of events // Count the number of events
Assert.assertEquals(1, h1.inboundThreadNames.size()); Assert.assertEquals(1, h1.inboundThreadNames.size());
Assert.assertEquals(2, h2.inboundThreadNames.size()); Assert.assertEquals(2, h2.inboundThreadNames.size());
Assert.assertEquals(3, h3.inboundThreadNames.size()); Assert.assertEquals(3, h3.inboundThreadNames.size());
Assert.assertEquals(3, h1.outboundThreadNames.size()); Assert.assertEquals(3, h1.outboundThreadNames.size());
Assert.assertEquals(2, h2.outboundThreadNames.size()); Assert.assertEquals(2, h2.outboundThreadNames.size());
Assert.assertEquals(1, h3.outboundThreadNames.size()); Assert.assertEquals(1, h3.outboundThreadNames.size());
if (h1.exception.get() != null) {
throw h1.exception.get();
}
if (h2.exception.get() != null) {
throw h2.exception.get();
}
if (h3.exception.get() != null) {
throw h3.exception.get();
}
} catch (AssertionError e) {
System.out.println("H1I: " + h1.inboundThreadNames);
System.out.println("H2I: " + h2.inboundThreadNames);
System.out.println("H3I: " + h3.inboundThreadNames);
System.out.println("H1O: " + h1.outboundThreadNames);
System.out.println("H2O: " + h2.outboundThreadNames);
System.out.println("H3O: " + h3.outboundThreadNames);
throw e;
}
} }
private static class TestHandler extends ChannelHandlerAdapter<Object, Object> { @Test
public void testConcurrentMessageBufferAccess() throws Exception {
EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l"));
EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1"));
EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2"));
MessageForwarder h1 = new MessageForwarder();
MessageForwarder h2 = new MessageForwarder();
MessageDiscarder h3 = new MessageDiscarder();
private final Queue<String> inboundThreadNames = QueueFactory.createQueue(); Channel ch = new LocalChannel();
private final Queue<String> outboundThreadNames = QueueFactory.createQueue(); ch.pipeline().addLast(h1).addLast(e1, h2).addLast(e2, h3);
l.register(ch).sync().channel().connect(ADDR).sync();
for (int i = 0; i < 10000; i ++) {
ch.pipeline().inboundMessageBuffer().add(Integer.valueOf(i));
ch.pipeline().fireInboundBufferUpdated();
}
}
private static class ThreadNameAuditor extends ChannelHandlerAdapter<Object, Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private final List<String> inboundThreadNames = Collections.synchronizedList(new ArrayList<String>());
private final List<String> outboundThreadNames = Collections.synchronizedList(new ArrayList<String>());
@Override @Override
public ChannelBufferHolder<Object> newInboundBuffer( public ChannelBufferHolder<Object> newInboundBuffer(
@ -175,6 +221,69 @@ public class LocalTransportThreadModelTest {
outboundThreadNames.add(Thread.currentThread().getName()); outboundThreadNames.add(Thread.currentThread().getName());
ctx.flush(future); ctx.flush(future);
} }
@Override
public void exceptionCaught(ChannelInboundHandlerContext<Object> ctx,
Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
System.err.print("[" + Thread.currentThread().getName() + "] ");
cause.printStackTrace();
super.exceptionCaught(ctx, cause);
}
}
private static class MessageForwarder extends ChannelInboundMessageHandlerAdapter<Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private int counter;
@Override
public void messageReceived(ChannelInboundHandlerContext<Object> ctx,
Object msg) throws Exception {
Assert.assertEquals(counter ++, msg);
ctx.nextInboundMessageBuffer().add(msg);
}
@Override
public void exceptionCaught(ChannelInboundHandlerContext<Object> ctx,
Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
System.err.print("[" + Thread.currentThread().getName() + "] ");
cause.printStackTrace();
super.exceptionCaught(ctx, cause);
}
}
private static class MessageDiscarder extends ChannelInboundHandlerAdapter<Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private int counter;
@Override
public ChannelBufferHolder<Object> newInboundBuffer(
ChannelInboundHandlerContext<Object> ctx) throws Exception {
return ChannelBufferHolders.messageBuffer();
}
@Override
public void inboundBufferUpdated(
ChannelInboundHandlerContext<Object> ctx) throws Exception {
Queue<Object> in = ctx.inbound().messageBuffer();
for (;;) {
Object msg = in.poll();
Assert.assertEquals(counter ++, msg);
}
}
@Override
public void exceptionCaught(ChannelInboundHandlerContext<Object> ctx,
Throwable cause) throws Exception {
exception.compareAndSet(null, cause);
System.err.print("[" + Thread.currentThread().getName() + "] ");
cause.printStackTrace();
super.exceptionCaught(ctx, cause);
}
} }
private static class PrefixThreadFactory implements ThreadFactory { private static class PrefixThreadFactory implements ThreadFactory {