diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index f238bd465e..67bcee7d80 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -424,10 +424,13 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha } try { - doRegister(); + Runnable postRegisterTask = doRegister(); registered = true; future.setSuccess(); pipeline.fireChannelRegistered(); + if (postRegisterTask != null) { + postRegisterTask.run(); + } if (isActive()) { pipeline.fireChannelActive(); } @@ -687,7 +690,7 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha protected abstract SocketAddress localAddress0(); protected abstract SocketAddress remoteAddress0(); - protected abstract void doRegister() throws Exception; + protected abstract Runnable doRegister() throws Exception; protected abstract void doBind(SocketAddress localAddress) throws Exception; protected abstract void doDisconnect() throws Exception; protected abstract void doClose() throws Exception; diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 6979b6da75..c0bcfcae09 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ -124,22 +124,39 @@ public class LocalChannel extends AbstractChannel { } @Override - protected void doRegister() throws Exception { + protected Runnable doRegister() throws Exception { + final LocalChannel peer = this.peer; + Runnable postRegisterTask; + if (peer != null) { state = 2; peer.remoteAddress = parent().localAddress(); peer.state = 2; - peer.eventLoop().execute(new Runnable() { + + // Ensure the peer's channelActive event is triggered *after* this channel's + // channelRegistered event is triggered, so that this channel's pipeline is fully + // initialized by ChannelInitializer. + final EventLoop peerEventLoop = peer.eventLoop(); + postRegisterTask = new Runnable() { @Override public void run() { - peer.connectFuture.setSuccess(); - peer.pipeline().fireChannelActive(); + peerEventLoop.execute(new Runnable() { + @Override + public void run() { + peer.connectFuture.setSuccess(); + peer.pipeline().fireChannelActive(); + } + }); } - }); + }; + } else { + postRegisterTask = null; } ((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook); + + return postRegisterTask; } @Override diff --git a/transport/src/main/java/io/netty/channel/local/LocalServerChannel.java b/transport/src/main/java/io/netty/channel/local/LocalServerChannel.java index a6b0a60a5e..945a5a3ace 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalServerChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalServerChannel.java @@ -84,8 +84,9 @@ public class LocalServerChannel extends AbstractServerChannel { } @Override - protected void doRegister() throws Exception { + protected Runnable doRegister() throws Exception { ((SingleThreadEventLoop) eventLoop()).addShutdownHook(shutdownHook); + return null; } @Override diff --git a/transport/src/main/java/io/netty/channel/socket/nio/AbstractNioChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/AbstractNioChannel.java index 45c1d29e76..5e4a3945d6 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/AbstractNioChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/AbstractNioChannel.java @@ -206,10 +206,11 @@ public abstract class AbstractNioChannel extends AbstractChannel { } @Override - protected void doRegister() throws Exception { + protected Runnable doRegister() throws Exception { NioChildEventLoop loop = (NioChildEventLoop) eventLoop(); selectionKey = javaChannel().register( loop.selector, isActive()? defaultInterestOps : 0, this); + return null; } @Override diff --git a/transport/src/main/java/io/netty/channel/socket/oio/AbstractOioChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/AbstractOioChannel.java index d30dff40ab..18840658a4 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/AbstractOioChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/AbstractOioChannel.java @@ -74,8 +74,9 @@ abstract class AbstractOioChannel extends AbstractChannel { } @Override - protected void doRegister() throws Exception { + protected Runnable doRegister() throws Exception { // NOOP + return null; } @Override diff --git a/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java b/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java index 6476b573a3..0c64fc4756 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java @@ -6,6 +6,7 @@ import io.netty.channel.ChannelBufferHolder; import io.netty.channel.ChannelBufferHolders; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerContext; import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.ChannelInitializer; @@ -13,13 +14,16 @@ import io.netty.channel.ChannelOutboundHandlerContext; import io.netty.channel.DefaultEventExecutor; import io.netty.channel.EventExecutor; import io.netty.channel.EventLoop; -import io.netty.util.internal.QueueFactory; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Queue; import java.util.Set; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import org.junit.AfterClass; import org.junit.Assert; @@ -59,13 +63,13 @@ public class LocalTransportThreadModelTest { } @Test - public void testSimple() throws Exception { + public void testStagedExecution() throws Throwable { EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l")); EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1")); EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2")); - TestHandler h1 = new TestHandler(); - TestHandler h2 = new TestHandler(); - TestHandler h3 = new TestHandler(); + ThreadNameAuditor h1 = new ThreadNameAuditor(); + ThreadNameAuditor h2 = new ThreadNameAuditor(); + ThreadNameAuditor h3 = new ThreadNameAuditor(); Channel ch = new LocalChannel(); // With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'. @@ -90,63 +94,105 @@ public class LocalTransportThreadModelTest { String currentName = Thread.currentThread().getName(); - // Events should never be handled from the current thread. - Assert.assertFalse(h1.inboundThreadNames.contains(currentName)); - Assert.assertFalse(h2.inboundThreadNames.contains(currentName)); - Assert.assertFalse(h3.inboundThreadNames.contains(currentName)); - Assert.assertFalse(h1.outboundThreadNames.contains(currentName)); - Assert.assertFalse(h2.outboundThreadNames.contains(currentName)); - Assert.assertFalse(h3.outboundThreadNames.contains(currentName)); + try { + // Events should never be handled from the current thread. + Assert.assertFalse(h1.inboundThreadNames.contains(currentName)); + Assert.assertFalse(h2.inboundThreadNames.contains(currentName)); + Assert.assertFalse(h3.inboundThreadNames.contains(currentName)); + Assert.assertFalse(h1.outboundThreadNames.contains(currentName)); + Assert.assertFalse(h2.outboundThreadNames.contains(currentName)); + Assert.assertFalse(h3.outboundThreadNames.contains(currentName)); - // Assert that events were handled by the correct executor. - for (String name: h1.inboundThreadNames) { - Assert.assertTrue(name.startsWith("l-")); - } - for (String name: h2.inboundThreadNames) { - Assert.assertTrue(name.startsWith("e1-")); - } - for (String name: h3.inboundThreadNames) { - Assert.assertTrue(name.startsWith("e2-")); - } - for (String name: h1.outboundThreadNames) { - Assert.assertTrue(name.startsWith("l-")); - } - for (String name: h2.outboundThreadNames) { - Assert.assertTrue(name.startsWith("e1-")); - } - for (String name: h3.outboundThreadNames) { - Assert.assertTrue(name.startsWith("e2-")); - } + // Assert that events were handled by the correct executor. + for (String name: h1.inboundThreadNames) { + Assert.assertTrue(name.startsWith("l-")); + } + for (String name: h2.inboundThreadNames) { + Assert.assertTrue(name.startsWith("e1-")); + } + for (String name: h3.inboundThreadNames) { + Assert.assertTrue(name.startsWith("e2-")); + } + for (String name: h1.outboundThreadNames) { + Assert.assertTrue(name.startsWith("l-")); + } + for (String name: h2.outboundThreadNames) { + Assert.assertTrue(name.startsWith("e1-")); + } + for (String name: h3.outboundThreadNames) { + Assert.assertTrue(name.startsWith("e2-")); + } - // Assert that the events for the same handler were handled by the same thread. - Set names = new HashSet(); - names.addAll(h1.inboundThreadNames); - names.addAll(h1.outboundThreadNames); - Assert.assertEquals(1, names.size()); + // Assert that the events for the same handler were handled by the same thread. + Set names = new HashSet(); + names.addAll(h1.inboundThreadNames); + names.addAll(h1.outboundThreadNames); + Assert.assertEquals(1, names.size()); - names.clear(); - names.addAll(h2.inboundThreadNames); - names.addAll(h2.outboundThreadNames); - Assert.assertEquals(1, names.size()); + names.clear(); + names.addAll(h2.inboundThreadNames); + names.addAll(h2.outboundThreadNames); + Assert.assertEquals(1, names.size()); - names.clear(); - names.addAll(h3.inboundThreadNames); - names.addAll(h3.outboundThreadNames); - Assert.assertEquals(1, names.size()); + names.clear(); + names.addAll(h3.inboundThreadNames); + names.addAll(h3.outboundThreadNames); + Assert.assertEquals(1, names.size()); - // Count the number of events - Assert.assertEquals(1, h1.inboundThreadNames.size()); - Assert.assertEquals(2, h2.inboundThreadNames.size()); - Assert.assertEquals(3, h3.inboundThreadNames.size()); - Assert.assertEquals(3, h1.outboundThreadNames.size()); - Assert.assertEquals(2, h2.outboundThreadNames.size()); - Assert.assertEquals(1, h3.outboundThreadNames.size()); + // Count the number of events + Assert.assertEquals(1, h1.inboundThreadNames.size()); + Assert.assertEquals(2, h2.inboundThreadNames.size()); + Assert.assertEquals(3, h3.inboundThreadNames.size()); + Assert.assertEquals(3, h1.outboundThreadNames.size()); + Assert.assertEquals(2, h2.outboundThreadNames.size()); + Assert.assertEquals(1, h3.outboundThreadNames.size()); + + if (h1.exception.get() != null) { + throw h1.exception.get(); + } + if (h2.exception.get() != null) { + throw h2.exception.get(); + } + if (h3.exception.get() != null) { + throw h3.exception.get(); + } + } catch (AssertionError e) { + System.out.println("H1I: " + h1.inboundThreadNames); + System.out.println("H2I: " + h2.inboundThreadNames); + System.out.println("H3I: " + h3.inboundThreadNames); + System.out.println("H1O: " + h1.outboundThreadNames); + System.out.println("H2O: " + h2.outboundThreadNames); + System.out.println("H3O: " + h3.outboundThreadNames); + throw e; + } } - private static class TestHandler extends ChannelHandlerAdapter { + @Test + public void testConcurrentMessageBufferAccess() throws Exception { + EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l")); + EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1")); + EventExecutor e2 = new DefaultEventExecutor(4, new PrefixThreadFactory("e2")); + MessageForwarder h1 = new MessageForwarder(); + MessageForwarder h2 = new MessageForwarder(); + MessageDiscarder h3 = new MessageDiscarder(); - private final Queue inboundThreadNames = QueueFactory.createQueue(); - private final Queue outboundThreadNames = QueueFactory.createQueue(); + Channel ch = new LocalChannel(); + ch.pipeline().addLast(h1).addLast(e1, h2).addLast(e2, h3); + + l.register(ch).sync().channel().connect(ADDR).sync(); + + for (int i = 0; i < 10000; i ++) { + ch.pipeline().inboundMessageBuffer().add(Integer.valueOf(i)); + ch.pipeline().fireInboundBufferUpdated(); + } + } + + private static class ThreadNameAuditor extends ChannelHandlerAdapter { + + private final AtomicReference exception = new AtomicReference(); + + private final List inboundThreadNames = Collections.synchronizedList(new ArrayList()); + private final List outboundThreadNames = Collections.synchronizedList(new ArrayList()); @Override public ChannelBufferHolder newInboundBuffer( @@ -175,6 +221,69 @@ public class LocalTransportThreadModelTest { outboundThreadNames.add(Thread.currentThread().getName()); ctx.flush(future); } + + @Override + public void exceptionCaught(ChannelInboundHandlerContext ctx, + Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + System.err.print("[" + Thread.currentThread().getName() + "] "); + cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } + } + + private static class MessageForwarder extends ChannelInboundMessageHandlerAdapter { + + private final AtomicReference exception = new AtomicReference(); + private int counter; + + @Override + public void messageReceived(ChannelInboundHandlerContext ctx, + Object msg) throws Exception { + Assert.assertEquals(counter ++, msg); + ctx.nextInboundMessageBuffer().add(msg); + } + + @Override + public void exceptionCaught(ChannelInboundHandlerContext ctx, + Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + System.err.print("[" + Thread.currentThread().getName() + "] "); + cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } + } + + private static class MessageDiscarder extends ChannelInboundHandlerAdapter { + + private final AtomicReference exception = new AtomicReference(); + private int counter; + + @Override + public ChannelBufferHolder newInboundBuffer( + ChannelInboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.messageBuffer(); + } + + @Override + public void inboundBufferUpdated( + ChannelInboundHandlerContext ctx) throws Exception { + Queue in = ctx.inbound().messageBuffer(); + for (;;) { + Object msg = in.poll(); + Assert.assertEquals(counter ++, msg); + } + + } + + @Override + public void exceptionCaught(ChannelInboundHandlerContext ctx, + Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + System.err.print("[" + Thread.currentThread().getName() + "] "); + cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } } private static class PrefixThreadFactory implements ThreadFactory {