diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index 09d3d629c5..5390d535de 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -752,35 +752,36 @@ final class DefaultChannelPipeline implements ChannelPipeline { * See: https://github.com/netty/netty/issues/3156 */ private void destroy() { - destroyUp(head.next); + destroyUp(head.next, false); } - private void destroyUp(AbstractChannelHandlerContext ctx) { + private void destroyUp(AbstractChannelHandlerContext ctx, boolean inEventLoop) { final Thread currentThread = Thread.currentThread(); final AbstractChannelHandlerContext tail = this.tail; for (;;) { if (ctx == tail) { - destroyDown(currentThread, tail.prev); + destroyDown(currentThread, tail.prev, inEventLoop); break; } final EventExecutor executor = ctx.executor(); - if (!executor.inEventLoop(currentThread)) { + if (!inEventLoop && !executor.inEventLoop(currentThread)) { final AbstractChannelHandlerContext finalCtx = ctx; executor.execute(new OneTimeTask() { @Override public void run() { - destroyUp(finalCtx); + destroyUp(finalCtx, true); } }); break; } ctx = ctx.next; + inEventLoop = false; } } - private void destroyDown(Thread currentThread, AbstractChannelHandlerContext ctx) { + private void destroyDown(Thread currentThread, AbstractChannelHandlerContext ctx, boolean inEventLoop) { // We have reached at tail; now traverse backwards. final AbstractChannelHandlerContext head = this.head; for (;;) { @@ -789,7 +790,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } final EventExecutor executor = ctx.executor(); - if (executor.inEventLoop(currentThread)) { + if (inEventLoop || executor.inEventLoop(currentThread)) { synchronized (this) { remove0(ctx); } @@ -798,13 +799,14 @@ final class DefaultChannelPipeline implements ChannelPipeline { executor.execute(new OneTimeTask() { @Override public void run() { - destroyDown(Thread.currentThread(), finalCtx); + destroyDown(Thread.currentThread(), finalCtx, true); } }); break; } ctx = ctx.prev; + inEventLoop = false; } } diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 11503bd90f..487557c5b1 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -29,6 +29,9 @@ import io.netty.channel.local.LocalServerChannel; import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.AbstractEventExecutor; +import io.netty.util.concurrent.EventExecutorGroup; +import io.netty.util.concurrent.Future; import org.junit.After; import org.junit.AfterClass; import org.junit.Test; @@ -39,6 +42,8 @@ import java.util.Collections; import java.util.List; import java.util.Queue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -559,6 +564,37 @@ public class DefaultChannelPipelineTest { assertSame(exception, error.get()); } + @Test + public void testChannelUnregistrationWithCustomExecutor() throws Exception { + final CountDownLatch channelLatch = new CountDownLatch(1); + final CountDownLatch handlerLatch = new CountDownLatch(1); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new WrapperExecutor(), + new ChannelInboundHandlerAdapter() { + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + channelLatch.countDown(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + handlerLatch.countDown(); + } + }); + } + }); + Channel channel = pipeline.channel(); + group.register(channel); + channel.close(); + channel.deregister(); + assertTrue(channelLatch.await(2, TimeUnit.SECONDS)); + assertTrue(handlerLatch.await(2, TimeUnit.SECONDS)); + } + private static int next(AbstractChannelHandlerContext ctx) { AbstractChannelHandlerContext next = ctx.next; if (next == null) { @@ -670,4 +706,64 @@ public class DefaultChannelPipelineTest { afterRemove = true; } } + + private static final class WrapperExecutor extends AbstractEventExecutor { + + private final ExecutorService wrapped = Executors.newSingleThreadExecutor(); + + @Override + public boolean isShuttingDown() { + return wrapped.isShutdown(); + } + + @Override + public Future shutdownGracefully(long l, long l2, TimeUnit timeUnit) { + throw new IllegalStateException(); + } + + @Override + public Future terminationFuture() { + throw new IllegalStateException(); + } + + @Override + public void shutdown() { + wrapped.shutdown(); + } + + @Override + public List shutdownNow() { + return wrapped.shutdownNow(); + } + + @Override + public boolean isShutdown() { + return wrapped.isShutdown(); + } + + @Override + public boolean isTerminated() { + return wrapped.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return wrapped.awaitTermination(timeout, unit); + } + + @Override + public EventExecutorGroup parent() { + return null; + } + + @Override + public boolean inEventLoop(Thread thread) { + return false; + } + + @Override + public void execute(Runnable command) { + wrapped.execute(command); + } + } }