ChannelInitializer may be invoked multiple times when used with custom EventExecutor. (#8620)

Motivation:

The ChannelInitializer may be invoked multipled times when used with a custom EventExecutor as removal operation may be done asynchronously. We need to guard against this.

Modifications:

- Change Map to Set which is more correct in terms of how we use it.
- Ensure we only modify the internal Set when the handler was removed yet
- Add unit test.

Result:

Fixes https://github.com/netty/netty/issues/8616.
This commit is contained in:
Norman Maurer 2018-12-05 19:30:17 +01:00
parent 564d5833cc
commit eb2118eeb4
2 changed files with 151 additions and 7 deletions

View File

@ -18,11 +18,12 @@ package io.netty.channel;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandler.Sharable;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.concurrent.ConcurrentMap; import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/** /**
* A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was * A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was
@ -53,9 +54,10 @@ import java.util.concurrent.ConcurrentMap;
public abstract class ChannelInitializer<C extends Channel> extends ChannelInboundHandlerAdapter { public abstract class ChannelInitializer<C extends Channel> extends ChannelInboundHandlerAdapter {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class); private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class);
// We use a ConcurrentMap as a ChannelInitializer is usually shared between all Channels in a Bootstrap / // We use a Set as a ChannelInitializer is usually shared between all Channels in a Bootstrap /
// ServerBootstrap. This way we can reduce the memory usage compared to use Attributes. // ServerBootstrap. This way we can reduce the memory usage compared to use Attributes.
private final ConcurrentMap<ChannelHandlerContext, Boolean> initMap = PlatformDependent.newConcurrentHashMap(); private final Set<ChannelHandlerContext> initMap = Collections.newSetFromMap(
new ConcurrentHashMap<ChannelHandlerContext, Boolean>());
/** /**
* This method will be called once the {@link Channel} was registered. After the method returns this instance * This method will be called once the {@link Channel} was registered. After the method returns this instance
@ -108,9 +110,14 @@ public abstract class ChannelInitializer<C extends Channel> extends ChannelInbou
} }
} }
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
initMap.remove(ctx);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private boolean initChannel(ChannelHandlerContext ctx) throws Exception { private boolean initChannel(ChannelHandlerContext ctx) throws Exception {
if (initMap.putIfAbsent(ctx, Boolean.TRUE) == null) { // Guard against re-entrance. if (initMap.add(ctx)) { // Guard against re-entrance.
try { try {
initChannel((C) ctx.channel()); initChannel((C) ctx.channel());
} catch (Throwable cause) { } catch (Throwable cause) {
@ -125,14 +132,25 @@ public abstract class ChannelInitializer<C extends Channel> extends ChannelInbou
return false; return false;
} }
private void remove(ChannelHandlerContext ctx) { private void remove(final ChannelHandlerContext ctx) {
try { try {
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();
if (pipeline.context(this) != null) { if (pipeline.context(this) != null) {
pipeline.remove(this); pipeline.remove(this);
} }
} finally { } finally {
// The removal may happen in an async fashion if the EventExecutor we use does something funky.
if (ctx.isRemoved()) {
initMap.remove(ctx);
} else {
// Ensure we always remove from the Map in all cases to not produce a memory leak.
ctx.channel().closeFuture().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
initMap.remove(ctx); initMap.remove(ctx);
} }
});
}
}
} }
} }

View File

@ -21,12 +21,16 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel; import io.netty.channel.local.LocalServerChannel;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -35,6 +39,7 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
public class ChannelInitializerTest { public class ChannelInitializerTest {
@ -249,6 +254,127 @@ public class ChannelInitializerTest {
} }
} }
@SuppressWarnings("deprecation")
@Test(timeout = 10000)
public void testChannelInitializerEventExecutor() throws Throwable {
final AtomicInteger invokeCount = new AtomicInteger();
final AtomicInteger completeCount = new AtomicInteger();
final AtomicReference<Throwable> errorRef = new AtomicReference<Throwable>();
LocalAddress addr = new LocalAddress("test");
final EventExecutor executor = new DefaultEventLoop() {
private final ScheduledExecutorService execService = Executors.newSingleThreadScheduledExecutor();
@Override
public void shutdown() {
execService.shutdown();
}
@Override
public boolean inEventLoop(Thread thread) {
// Always return false which will ensure we always call execute(...)
return false;
}
@Override
public boolean isShuttingDown() {
return false;
}
@Override
public Future<?> shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) {
throw new IllegalStateException();
}
@Override
public Future<?> terminationFuture() {
throw new IllegalStateException();
}
@Override
public boolean isShutdown() {
return execService.isShutdown();
}
@Override
public boolean isTerminated() {
return execService.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return execService.awaitTermination(timeout, unit);
}
@Override
public void execute(Runnable command) {
execService.execute(command);
}
};
ServerBootstrap serverBootstrap = new ServerBootstrap()
.channel(LocalServerChannel.class)
.group(group)
.localAddress(addr)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) {
ch.pipeline().addLast(executor, new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
invokeCount.incrementAndGet();
ChannelHandlerContext ctx = ch.pipeline().context(this);
assertNotNull(ctx);
ch.pipeline().addAfter(ctx.executor(),
ctx.name(), null, new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
// just drop on the floor.
}
});
completeCount.incrementAndGet();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
errorRef.set(cause);
}
});
}
});
Channel server = serverBootstrap.bind().sync().channel();
Bootstrap clientBootstrap = new Bootstrap()
.channel(LocalChannel.class)
.group(group)
.remoteAddress(addr)
.handler(new ChannelInboundHandlerAdapter());
Channel client = clientBootstrap.connect().sync().channel();
client.writeAndFlush("Hello World").sync();
client.close().sync();
server.close().sync();
client.closeFuture().sync();
server.closeFuture().sync();
// Give some time to execute everything that was submitted before.
Thread.sleep(1000);
executor.shutdown();
assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS));
assertEquals(invokeCount.get(), 1);
assertEquals(invokeCount.get(), completeCount.get());
Throwable cause = errorRef.get();
if (cause != null) {
throw cause;
}
}
private static void closeChannel(Channel c) { private static void closeChannel(Channel c) {
if (c != null) { if (c != null) {
c.close().syncUninterruptibly(); c.close().syncUninterruptibly();