diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index 4080816db3..5842cff093 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelMetadata; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPipeline; @@ -73,22 +74,41 @@ public class EmbeddedChannel extends AbstractChannel { * * @param handlers the @link ChannelHandler}s which will be add in the {@link ChannelPipeline} */ - public EmbeddedChannel(ChannelHandler... handlers) { + public EmbeddedChannel(final ChannelHandler... handlers) { super(null, EmbeddedChannelId.INSTANCE); if (handlers == null) { throw new NullPointerException("handlers"); } - ChannelPipeline p = pipeline(); + int nHandlers = 0; for (ChannelHandler h: handlers) { if (h == null) { break; } - p.addLast(h); + nHandlers ++; } - loop.register(this); + if (nHandlers == 0) { + throw new IllegalArgumentException("handlers is empty."); + } + + ChannelPipeline p = pipeline(); + p.addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + for (ChannelHandler h: handlers) { + if (h == null) { + break; + } + pipeline.addLast(h); + } + } + }); + + ChannelFuture future = loop.register(this); + assert future.isDone(); p.addLast(new LastInboundHandler()); } diff --git a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java index e32acfd1f7..b062a134c6 100644 --- a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java +++ b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java @@ -17,6 +17,7 @@ package io.netty.channel.embedded; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; @@ -29,6 +30,7 @@ import org.junit.Test; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; public class EmbeddedChannelTest { @@ -95,4 +97,29 @@ public class EmbeddedChannelTest { ch.finish(); Assert.assertTrue(future.isCancelled()); } + + @Test(timeout = 3000) + public void testHandlerAddedExecutedInEventLoop() throws Throwable { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference error = new AtomicReference(); + final ChannelHandler handler = new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + try { + Assert.assertTrue(ctx.executor().inEventLoop()); + } catch (Throwable cause) { + error.set(cause); + } finally { + latch.countDown(); + } + } + }; + EmbeddedChannel channel = new EmbeddedChannel(handler); + Assert.assertFalse(channel.finish()); + latch.await(); + Throwable cause = error.get(); + if (cause != null) { + throw cause; + } + } }