From eb2118eeb44959ab2f23a5809acd60a8ff1a38d8 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 5 Dec 2018 19:30:17 +0100 Subject: [PATCH] ChannelInitializer may be invoked multiple times when used with custom EventExecutor. (#8620) Motivation: The ChannelInitializer may be invoked multipled times when used with a custom EventExecutor as removal operation may be done asynchronously. We need to guard against this. Modifications: - Change Map to Set which is more correct in terms of how we use it. - Ensure we only modify the internal Set when the handler was removed yet - Add unit test. Result: Fixes https://github.com/netty/netty/issues/8616. --- .../io/netty/channel/ChannelInitializer.java | 32 ++++- .../netty/channel/ChannelInitializerTest.java | 126 ++++++++++++++++++ 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/ChannelInitializer.java b/transport/src/main/java/io/netty/channel/ChannelInitializer.java index 9ea1b18221..18344d200f 100644 --- a/transport/src/main/java/io/netty/channel/ChannelInitializer.java +++ b/transport/src/main/java/io/netty/channel/ChannelInitializer.java @@ -18,11 +18,12 @@ package io.netty.channel; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelHandler.Sharable; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import java.util.concurrent.ConcurrentMap; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; /** * A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was @@ -53,9 +54,10 @@ import java.util.concurrent.ConcurrentMap; public abstract class ChannelInitializer extends ChannelInboundHandlerAdapter { private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class); - // We use a ConcurrentMap as a ChannelInitializer is usually shared between all Channels in a Bootstrap / + // We use a Set as a ChannelInitializer is usually shared between all Channels in a Bootstrap / // ServerBootstrap. This way we can reduce the memory usage compared to use Attributes. - private final ConcurrentMap initMap = PlatformDependent.newConcurrentHashMap(); + private final Set initMap = Collections.newSetFromMap( + new ConcurrentHashMap()); /** * This method will be called once the {@link Channel} was registered. After the method returns this instance @@ -108,9 +110,14 @@ public abstract class ChannelInitializer extends ChannelInbou } } + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + initMap.remove(ctx); + } + @SuppressWarnings("unchecked") private boolean initChannel(ChannelHandlerContext ctx) throws Exception { - if (initMap.putIfAbsent(ctx, Boolean.TRUE) == null) { // Guard against re-entrance. + if (initMap.add(ctx)) { // Guard against re-entrance. try { initChannel((C) ctx.channel()); } catch (Throwable cause) { @@ -125,14 +132,25 @@ public abstract class ChannelInitializer extends ChannelInbou return false; } - private void remove(ChannelHandlerContext ctx) { + private void remove(final ChannelHandlerContext ctx) { try { ChannelPipeline pipeline = ctx.pipeline(); if (pipeline.context(this) != null) { pipeline.remove(this); } } finally { - initMap.remove(ctx); + // The removal may happen in an async fashion if the EventExecutor we use does something funky. + if (ctx.isRemoved()) { + initMap.remove(ctx); + } else { + // Ensure we always remove from the Map in all cases to not produce a memory leak. + ctx.channel().closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + initMap.remove(ctx); + } + }); + } } } } diff --git a/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java b/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java index 26b5e4e9fc..2ac1bcdefa 100644 --- a/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelInitializerTest.java @@ -21,12 +21,16 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.util.Iterator; import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -35,6 +39,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; public class ChannelInitializerTest { @@ -249,6 +254,127 @@ public class ChannelInitializerTest { } } + @SuppressWarnings("deprecation") + @Test(timeout = 10000) + public void testChannelInitializerEventExecutor() throws Throwable { + final AtomicInteger invokeCount = new AtomicInteger(); + final AtomicInteger completeCount = new AtomicInteger(); + final AtomicReference errorRef = new AtomicReference(); + LocalAddress addr = new LocalAddress("test"); + + final EventExecutor executor = new DefaultEventLoop() { + private final ScheduledExecutorService execService = Executors.newSingleThreadScheduledExecutor(); + + @Override + public void shutdown() { + execService.shutdown(); + } + + @Override + public boolean inEventLoop(Thread thread) { + // Always return false which will ensure we always call execute(...) + return false; + } + + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + throw new IllegalStateException(); + } + + @Override + public Future terminationFuture() { + throw new IllegalStateException(); + } + + @Override + public boolean isShutdown() { + return execService.isShutdown(); + } + + @Override + public boolean isTerminated() { + return execService.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return execService.awaitTermination(timeout, unit); + } + + @Override + public void execute(Runnable command) { + execService.execute(command); + } + }; + + ServerBootstrap serverBootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(group) + .localAddress(addr) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) { + ch.pipeline().addLast(executor, new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + invokeCount.incrementAndGet(); + ChannelHandlerContext ctx = ch.pipeline().context(this); + assertNotNull(ctx); + ch.pipeline().addAfter(ctx.executor(), + ctx.name(), null, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + // just drop on the floor. + } + }); + completeCount.incrementAndGet(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.set(cause); + } + }); + } + }); + + Channel server = serverBootstrap.bind().sync().channel(); + + Bootstrap clientBootstrap = new Bootstrap() + .channel(LocalChannel.class) + .group(group) + .remoteAddress(addr) + .handler(new ChannelInboundHandlerAdapter()); + + Channel client = clientBootstrap.connect().sync().channel(); + client.writeAndFlush("Hello World").sync(); + + client.close().sync(); + server.close().sync(); + + client.closeFuture().sync(); + server.closeFuture().sync(); + + // Give some time to execute everything that was submitted before. + Thread.sleep(1000); + + executor.shutdown(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + + assertEquals(invokeCount.get(), 1); + assertEquals(invokeCount.get(), completeCount.get()); + + Throwable cause = errorRef.get(); + if (cause != null) { + throw cause; + } + } + private static void closeChannel(Channel c) { if (c != null) { c.close().syncUninterruptibly();