diff --git a/transport/src/main/java/io/netty/channel/ChannelInitializer.java b/transport/src/main/java/io/netty/channel/ChannelInitializer.java index a23b071227..bb82a799ea 100644 --- a/transport/src/main/java/io/netty/channel/ChannelInitializer.java +++ b/transport/src/main/java/io/netty/channel/ChannelInitializer.java @@ -18,9 +18,12 @@ package io.netty.channel; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelHandler.Sharable; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.concurrent.ConcurrentMap; + /** * A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was * registered to its {@link EventLoop}. @@ -50,6 +53,9 @@ import io.netty.util.internal.logging.InternalLoggerFactory; public abstract class ChannelInitializer extends ChannelInboundHandlerAdapter { private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class); + // We use a ConcurrentMap as a ChannelInitializer is usually shared between all Channels in a Bootstrap / + // ServerBootstrap. This way we can reduce the memory usage compared to use Attributes. + private final ConcurrentMap initMap = PlatformDependent.newConcurrentHashMap(); /** * This method will be called once the {@link Channel} was registered. After the method returns this instance @@ -65,9 +71,16 @@ public abstract class ChannelInitializer extends ChannelInbou @Override @SuppressWarnings("unchecked") public final void channelRegistered(ChannelHandlerContext ctx) throws Exception { - initChannel((C) ctx.channel()); - ctx.pipeline().remove(this); - ctx.pipeline().fireChannelRegistered(); + // Normally this method will never be called as handlerAdded(...) should call initChannel(...) and remove + // the handler. + if (initChannel(ctx)) { + // we called initChannel(...) so we need to call now pipeline.fireChannelRegistered() to ensure we not + // miss an event. + ctx.pipeline().fireChannelRegistered(); + } else { + // Called initChannel(...) before which is the expected behavior, so just forward the event. + ctx.fireChannelRegistered(); + } } /** @@ -76,13 +89,48 @@ public abstract class ChannelInitializer extends ChannelInbou @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { logger.warn("Failed to initialize a channel. Closing: " + ctx.channel(), cause); + ctx.close(); + } + + /** + * {@inheritDoc} If override this method ensure you call super! + */ + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isRegistered()) { + // This should always be true with our current DefaultChannelPipeline implementation. + // The good thing about calling initChannel(...) in handlerAdded(...) is that there will be no ordering + // suprises if a ChannelInitializer will add another ChannelInitializer. This is as all handlers + // will be added in the expected order. + initChannel(ctx); + } + } + + @SuppressWarnings("unchecked") + private boolean initChannel(ChannelHandlerContext ctx) throws Exception { + if (initMap.putIfAbsent(ctx, Boolean.TRUE) == null) { // Guard against re-entrance. + try { + initChannel((C) ctx.channel()); + } catch (Throwable cause) { + // Explicitly call exceptionCaught(...) as we removed the handler before calling initChannel(...). + // We do so to prevent multiple calls to initChannel(...). + exceptionCaught(ctx, cause); + } finally { + remove(ctx); + } + return true; + } + return false; + } + + private void remove(ChannelHandlerContext ctx) { try { ChannelPipeline pipeline = ctx.pipeline(); if (pipeline.context(this) != null) { pipeline.remove(this); } } finally { - ctx.close(); + initMap.remove(ctx); } } } diff --git a/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java b/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java index f16e9c80a4..5ab48867bc 100644 --- a/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java @@ -25,12 +25,16 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import java.util.Iterator; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; public class ChannelInitializerTest { private static final int TIMEOUT_MILLIS = 1000; @@ -59,6 +63,85 @@ public class ChannelInitializerTest { group.shutdownGracefully(0, TIMEOUT_MILLIS, TimeUnit.MILLISECONDS).syncUninterruptibly(); } + @Test + public void testChannelInitializerInInitializerCorrectOrdering() { + final ChannelInboundHandlerAdapter handler1 = new ChannelInboundHandlerAdapter(); + final ChannelInboundHandlerAdapter handler2 = new ChannelInboundHandlerAdapter(); + final ChannelInboundHandlerAdapter handler3 = new ChannelInboundHandlerAdapter(); + final ChannelInboundHandlerAdapter handler4 = new ChannelInboundHandlerAdapter(); + + client.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(handler1); + ch.pipeline().addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(handler2); + ch.pipeline().addLast(handler3); + } + }); + ch.pipeline().addLast(handler4); + } + }).localAddress(LocalAddress.ANY); + + Channel channel = client.bind().syncUninterruptibly().channel(); + try { + // Execute some task on the EventLoop and wait until its done to be sure all handlers are added to the + // pipeline. + channel.eventLoop().submit(new Runnable() { + @Override + public void run() { + // NOOP + } + }).syncUninterruptibly(); + Iterator> handlers = channel.pipeline().iterator(); + assertSame(handler1, handlers.next().getValue()); + assertSame(handler2, handlers.next().getValue()); + assertSame(handler3, handlers.next().getValue()); + assertSame(handler4, handlers.next().getValue()); + assertFalse(handlers.hasNext()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test + public void testChannelInitializerReentrance() { + final AtomicInteger registeredCalled = new AtomicInteger(0); + final ChannelInboundHandlerAdapter handler1 = new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + registeredCalled.incrementAndGet(); + } + }; + final AtomicInteger initChannelCalled = new AtomicInteger(0); + client.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + initChannelCalled.incrementAndGet(); + ch.pipeline().addLast(handler1); + ch.pipeline().fireChannelRegistered(); + } + }).localAddress(LocalAddress.ANY); + + Channel channel = client.bind().syncUninterruptibly().channel(); + try { + // Execute some task on the EventLoop and wait until its done to be sure all handlers are added to the + // pipeline. + channel.eventLoop().submit(new Runnable() { + @Override + public void run() { + // NOOP + } + }).syncUninterruptibly(); + assertEquals(1, initChannelCalled.get()); + assertEquals(2, registeredCalled.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + @Test(timeout = TIMEOUT_MILLIS) public void firstHandlerInPipelineShouldReceiveChannelRegisteredEvent() { testChannelRegisteredEventPropagation(new ChannelInitializer() {