diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index 894d03b8b7..06494f8250 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -35,6 +35,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelConfig; import io.netty.channel.DefaultChannelPipeline; import io.netty.channel.EventLoop; +import io.netty.channel.RecvByteBufAllocator; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; @@ -701,7 +702,12 @@ public class EmbeddedChannel extends AbstractChannel { @Override protected AbstractUnsafe newUnsafe() { - return new DefaultUnsafe(); + return new EmbeddedUnsafe(); + } + + @Override + public Unsafe unsafe() { + return ((EmbeddedUnsafe) super.unsafe()).wrapped; } @Override @@ -734,7 +740,97 @@ public class EmbeddedChannel extends AbstractChannel { inboundMessages().add(msg); } - private class DefaultUnsafe extends AbstractUnsafe { + private final class EmbeddedUnsafe extends AbstractUnsafe { + + // Delegates to the EmbeddedUnsafe instance but ensures runPendingTasks() is called after each operation + // that may change the state of the Channel and may schedule tasks for later execution. + final Unsafe wrapped = new Unsafe() { + @Override + public RecvByteBufAllocator.Handle recvBufAllocHandle() { + return EmbeddedUnsafe.this.recvBufAllocHandle(); + } + + @Override + public SocketAddress localAddress() { + return EmbeddedUnsafe.this.localAddress(); + } + + @Override + public SocketAddress remoteAddress() { + return EmbeddedUnsafe.this.remoteAddress(); + } + + @Override + public void register(EventLoop eventLoop, ChannelPromise promise) { + EmbeddedUnsafe.this.register(eventLoop, promise); + runPendingTasks(); + } + + @Override + public void bind(SocketAddress localAddress, ChannelPromise promise) { + EmbeddedUnsafe.this.bind(localAddress, promise); + runPendingTasks(); + } + + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise); + runPendingTasks(); + } + + @Override + public void disconnect(ChannelPromise promise) { + EmbeddedUnsafe.this.disconnect(promise); + runPendingTasks(); + } + + @Override + public void close(ChannelPromise promise) { + EmbeddedUnsafe.this.close(promise); + runPendingTasks(); + } + + @Override + public void closeForcibly() { + EmbeddedUnsafe.this.closeForcibly(); + runPendingTasks(); + } + + @Override + public void deregister(ChannelPromise promise) { + EmbeddedUnsafe.this.deregister(promise); + runPendingTasks(); + } + + @Override + public void beginRead() { + EmbeddedUnsafe.this.beginRead(); + runPendingTasks(); + } + + @Override + public void write(Object msg, ChannelPromise promise) { + EmbeddedUnsafe.this.write(msg, promise); + runPendingTasks(); + } + + @Override + public void flush() { + EmbeddedUnsafe.this.flush(); + runPendingTasks(); + } + + @Override + public ChannelPromise voidPromise() { + return EmbeddedUnsafe.this.voidPromise(); + } + + @Override + public ChannelOutboundBuffer outboundBuffer() { + return EmbeddedUnsafe.this.outboundBuffer(); + } + }; + @Override public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { safeSetSuccess(promise); @@ -742,7 +838,7 @@ public class EmbeddedChannel extends AbstractChannel { } private final class EmbeddedChannelPipeline extends DefaultChannelPipeline { - public EmbeddedChannelPipeline(EmbeddedChannel channel) { + EmbeddedChannelPipeline(EmbeddedChannel channel) { super(channel); } diff --git a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java index 3917f8ba58..a5dae45afa 100644 --- a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java +++ b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java @@ -27,6 +27,7 @@ import java.util.ArrayDeque; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -563,6 +564,25 @@ public class EmbeddedChannelTest { } } + @Test(timeout = 5000) + public void testChannelInactiveFired() throws InterruptedException { + final AtomicBoolean inactive = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.close(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + inactive.set(true); + } + }); + channel.pipeline().fireExceptionCaught(new IllegalStateException()); + + assertTrue(inactive.get()); + } + private static void release(ByteBuf... buffers) { for (ByteBuf buffer : buffers) { if (buffer.refCnt() > 0) {