Protect ChannelHandler from reentrancee issues (#9358)

Motivation:

At the moment it is quite easy to hit reentrance issues when you have multiple handlers in the pipeline and each of the handlers does not correctly protect against these. To make it easier for the user we should try to protect from these. The issue is usually if and inbound event will trigger and outbound event and this outbound event then against triggeres an inbound event. This may result in having methods in a ChannelHandler re-enter some method and so state can be corrupted or messages be re-ordered.

Modifications:

- Keep track of inbound / outbound operations in DefaultChannelHandlerContext and if reentrancy is detected break it by scheduling the action on the EventLoop. This will then be picked up once the method returns and so the reentrancy is broken up.
- Adjust tests which made strange assumptions about execution order

Result:

No more reentrancy of handlers possible.
This commit is contained in:
Norman Maurer 2019-09-03 10:28:08 +02:00 committed by GitHub
parent b3e6e41384
commit 48634f1466
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 465 additions and 68 deletions

View File

@ -73,6 +73,8 @@ public class WebSocketProtocolHandlerTest {
// When
channel.read();
channel.runPendingTasks();
// Then - pong frame was written to the outbound
PongWebSocketFrame response1 = channel.readOutbound();
assertEquals(text1, response1.content().toString(UTF_8));

View File

@ -420,6 +420,8 @@ public abstract class Http2MultiplexTest<C extends Http2FrameCodec> {
Http2StreamChannel childChannel = newOutboundStream(handler);
assertTrue(childChannel.isActive());
parentChannel.runPendingTasks();
childChannel.close();
verify(frameWriter).writeRstStream(eqCodecCtx(),
eqStreamId(childChannel), eq(Http2Error.CANCEL.code()), anyChannelPromise());
@ -451,6 +453,7 @@ public abstract class Http2MultiplexTest<C extends Http2FrameCodec> {
ctx.fireChannelActive();
}
});
parentChannel.runPendingTasks();
assertFalse(childChannel.isActive());
@ -530,6 +533,8 @@ public abstract class Http2MultiplexTest<C extends Http2FrameCodec> {
Http2Headers headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt");
childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers));
parentChannel.runPendingTasks();
// Read from the child channel
frameInboundWriter.writeInboundHeaders(childChannel.stream().id(), headers, 0, false);

View File

@ -357,8 +357,7 @@ public class SocketHalfClosedTest extends AbstractSocketTest {
public void channelActive(ChannelHandlerContext ctx) throws Exception {
ByteBuf buf = ctx.alloc().buffer(expectedBytes);
buf.writerIndex(buf.writerIndex() + expectedBytes);
ctx.writeAndFlush(buf.retainedDuplicate());
ctx.writeAndFlush(buf.retainedDuplicate()).addListener((ChannelFutureListener) f -> {
// We wait here to ensure that we write before we have a chance to process the outbound
// shutdown event.
followerCloseLatch.await();
@ -370,6 +369,7 @@ public class SocketHalfClosedTest extends AbstractSocketTest {
doneLatch.countDown();
}
});
});
}
@Override

View File

@ -25,9 +25,9 @@ import io.netty.util.ReferenceCountUtil;
import io.netty.util.ResourceLeakHint;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.internal.PromiseNotificationUtil;
import io.netty.util.internal.ThrowableUtil;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.ThrowableUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -72,6 +72,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
DefaultChannelHandlerContext prev;
private volatile int handlerState = INIT;
// Keeps track of processing different events
private short outboundOperations;
private short inboundOperations;
DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name,
ChannelHandler handler) {
this.name = requireNonNull(name, "name");
@ -118,6 +122,54 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return name;
}
private boolean isProcessInboundDirectly() {
assert inboundOperations >= 0;
return inboundOperations == 0;
}
private boolean isProcessOutboundDirectly() {
assert outboundOperations >= 0;
return outboundOperations == 0;
}
private void incrementOutboundOperations() {
assert outboundOperations >= 0;
outboundOperations++;
}
private void decrementOutboundOperations() {
assert outboundOperations > 0;
outboundOperations--;
}
private void incrementInboundOperations() {
assert inboundOperations >= 0;
inboundOperations++;
}
private void decrementInboundOperations() {
assert inboundOperations > 0;
inboundOperations--;
}
private static void executeInboundReentrance(DefaultChannelHandlerContext context, Runnable task) {
context.incrementInboundOperations();
try {
context.executor().execute(task);
} catch (Throwable cause) {
context.decrementInboundOperations();
throw cause;
}
}
private static void executeOutboundReentrance(
DefaultChannelHandlerContext context, Runnable task, ChannelPromise promise, Object msg) {
context.incrementOutboundOperations();
if (!safeExecute(context.executor(), task, promise, msg)) {
context.decrementOutboundOperations();
}
}
@Override
public ChannelHandlerContext fireChannelRegistered() {
EventExecutor executor = executor();
@ -130,14 +182,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelRegistered() {
findContextInbound(MASK_CHANNEL_REGISTERED).invokeChannelRegistered();
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_REGISTERED);
if (context.isProcessInboundDirectly()) {
context.invokeChannelRegistered();
} else {
executeInboundReentrance(context, context::invokeChannelRegistered0);
}
}
void invokeChannelRegistered() {
incrementInboundOperations();
invokeChannelRegistered0();
}
private void invokeChannelRegistered0() {
try {
handler().channelRegistered(this);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -153,14 +217,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelUnregistered() {
findContextInbound(MASK_CHANNEL_UNREGISTERED).invokeChannelUnregistered();
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_UNREGISTERED);
if (context.isProcessInboundDirectly()) {
context.invokeChannelUnregistered();
} else {
executeInboundReentrance(context, context::invokeChannelUnregistered0);
}
}
void invokeChannelUnregistered() {
incrementInboundOperations();
invokeChannelUnregistered0();
}
private void invokeChannelUnregistered0() {
try {
handler().channelUnregistered(this);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -176,14 +252,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelActive() {
findContextInbound(MASK_CHANNEL_ACTIVE).invokeChannelActive();
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_ACTIVE);
if (context.isProcessInboundDirectly()) {
context.invokeChannelActive();
} else {
executeInboundReentrance(context, context::invokeChannelActive0);
}
}
void invokeChannelActive() {
incrementInboundOperations();
invokeChannelActive0();
}
private void invokeChannelActive0() {
try {
handler().channelActive(this);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -199,14 +287,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelInactive() {
findContextInbound(MASK_CHANNEL_INACTIVE).invokeChannelInactive();
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_INACTIVE);
if (context.isProcessInboundDirectly()) {
context.invokeChannelInactive();
} else {
executeInboundReentrance(context, context::invokeChannelInactive0);
}
}
void invokeChannelInactive() {
incrementInboundOperations();
invokeChannelInactive0();
}
private void invokeChannelInactive0() {
try {
handler().channelInactive(this);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -230,10 +330,20 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeExceptionCaught(Throwable cause) {
findContextInbound(MASK_EXCEPTION_CAUGHT).invokeExceptionCaught(cause);
DefaultChannelHandlerContext context = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (context.isProcessInboundDirectly()) {
context.invokeExceptionCaught(cause);
} else {
executeInboundReentrance(context, () -> context.invokeExceptionCaught0(cause));
}
}
void invokeExceptionCaught(final Throwable cause) {
incrementInboundOperations();
invokeExceptionCaught0(cause);
}
private void invokeExceptionCaught0(final Throwable cause) {
try {
handler().exceptionCaught(this, cause);
} catch (Throwable error) {
@ -249,6 +359,8 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
"was thrown by a user handler's exceptionCaught() " +
"method while handling the following exception:", error, cause);
}
} finally {
decrementInboundOperations();
}
}
@ -265,14 +377,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeUserEventTriggered(Object event) {
findContextInbound(MASK_USER_EVENT_TRIGGERED).invokeUserEventTriggered(event);
DefaultChannelHandlerContext context = findContextInbound(MASK_USER_EVENT_TRIGGERED);
if (context.isProcessInboundDirectly()) {
context.invokeUserEventTriggered(event);
} else {
executeInboundReentrance(context, () -> context.invokeUserEventTriggered0(event));
}
}
void invokeUserEventTriggered(Object event) {
incrementInboundOperations();
invokeUserEventTriggered0(event);
}
private void invokeUserEventTriggered0(Object event) {
try {
handler().userEventTriggered(this, event);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -294,15 +418,27 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelRead(Object msg) {
findContextInbound(MASK_CHANNEL_READ).invokeChannelRead(msg);
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_READ);
if (context.isProcessInboundDirectly()) {
context.invokeChannelRead(msg);
} else {
executeInboundReentrance(context, () -> context.invokeChannelRead0(msg));
}
}
void invokeChannelRead(Object msg) {
incrementInboundOperations();
invokeChannelRead0(msg);
}
private void invokeChannelRead0(Object msg) {
final Object m = pipeline.touch(requireNonNull(msg, "msg"), this);
try {
handler().channelRead(this, m);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -319,14 +455,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelReadComplete() {
findContextInbound(MASK_CHANNEL_READ_COMPLETE).invokeChannelReadComplete();
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_READ_COMPLETE);
if (context.isProcessInboundDirectly()) {
context.invokeChannelReadComplete();
} else {
executeInboundReentrance(context, context::invokeChannelReadComplete0);
}
}
void invokeChannelReadComplete() {
incrementInboundOperations();
invokeChannelReadComplete0();
}
private void invokeChannelReadComplete0() {
try {
handler().channelReadComplete(this);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -343,14 +491,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelWritabilityChanged() {
findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED).invokeChannelWritabilityChanged();
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED);
if (context.isProcessInboundDirectly()) {
context.invokeChannelWritabilityChanged();
} else {
executeInboundReentrance(context, context::invokeChannelWritabilityChanged0);
}
}
void invokeChannelWritabilityChanged() {
incrementInboundOperations();
invokeChannelWritabilityChanged0();
}
private void invokeChannelWritabilityChanged0() {
try {
handler().channelWritabilityChanged(this);
} catch (Throwable t) {
notifyHandlerException(t);
} finally {
decrementInboundOperations();
}
}
@ -407,14 +567,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) {
findContextOutbound(MASK_BIND).invokeBind(localAddress, promise);
DefaultChannelHandlerContext context = findContextOutbound(MASK_BIND);
if (context.isProcessOutboundDirectly()) {
context.invokeBind(localAddress, promise);
} else {
executeOutboundReentrance(context, () -> context.invokeBind0(localAddress, promise), promise, null);
}
}
private void invokeBind(SocketAddress localAddress, ChannelPromise promise) {
incrementOutboundOperations();
invokeBind0(localAddress, promise);
}
private void invokeBind0(SocketAddress localAddress, ChannelPromise promise) {
try {
handler().bind(this, localAddress, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -442,14 +614,27 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
findContextOutbound(MASK_CONNECT).invokeConnect(remoteAddress, localAddress, promise);
DefaultChannelHandlerContext context = findContextOutbound(MASK_CONNECT);
if (context.isProcessOutboundDirectly()) {
context.invokeConnect(remoteAddress, localAddress, promise);
} else {
executeOutboundReentrance(context, () -> context.invokeConnect0(remoteAddress, localAddress, promise),
promise, null);
}
}
private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
incrementOutboundOperations();
invokeConnect0(remoteAddress, localAddress, promise);
}
private void invokeConnect0(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
try {
handler().connect(this, remoteAddress, localAddress, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -476,14 +661,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeDisconnect(ChannelPromise promise) {
findContextOutbound(MASK_DISCONNECT).invokeDisconnect(promise);
DefaultChannelHandlerContext context = findContextOutbound(MASK_DISCONNECT);
if (context.isProcessOutboundDirectly()) {
context.invokeDisconnect(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeDisconnect0(promise), promise, null);
}
}
private void invokeDisconnect(ChannelPromise promise) {
incrementOutboundOperations();
invokeDisconnect0(promise);
}
private void invokeDisconnect0(ChannelPromise promise) {
try {
handler().disconnect(this, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -504,14 +701,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeClose(ChannelPromise promise) {
findContextOutbound(MASK_CLOSE).invokeClose(promise);
DefaultChannelHandlerContext context = findContextOutbound(MASK_CLOSE);
if (context.isProcessOutboundDirectly()) {
context.invokeClose(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeClose0(promise), promise, null);
}
}
private void invokeClose(ChannelPromise promise) {
incrementOutboundOperations();
invokeClose0(promise);
}
private void invokeClose0(ChannelPromise promise) {
try {
handler().close(this, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -532,14 +741,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeRegister(ChannelPromise promise) {
findContextOutbound(MASK_REGISTER).invokeRegister(promise);
DefaultChannelHandlerContext context = findContextOutbound(MASK_REGISTER);
if (context.isProcessOutboundDirectly()) {
context.invokeRegister(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeRegister0(promise), promise, null);
}
}
private void invokeRegister(ChannelPromise promise) {
incrementOutboundOperations();
invokeRegister0(promise);
}
private void invokeRegister0(ChannelPromise promise) {
try {
handler().register(this, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -560,14 +781,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeDeregister(ChannelPromise promise) {
findContextOutbound(MASK_DEREGISTER).invokeDeregister(promise);
DefaultChannelHandlerContext context = findContextOutbound(MASK_DEREGISTER);
if (context.isProcessOutboundDirectly()) {
context.invokeDeregister(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeDeregister0(promise), promise, null);
}
}
private void invokeDeregister(ChannelPromise promise) {
incrementOutboundOperations();
invokeDeregister0(promise);
}
private void invokeDeregister0(ChannelPromise promise) {
try {
handler().deregister(this, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -584,14 +817,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeRead() {
findContextOutbound(MASK_READ).invokeRead();
DefaultChannelHandlerContext context = findContextOutbound(MASK_READ);
if (context.isProcessOutboundDirectly()) {
context.invokeRead();
} else {
executeOutboundReentrance(context, context::invokeRead0, null, null);
}
}
private void invokeRead() {
incrementOutboundOperations();
invokeRead0();
}
private void invokeRead0() {
try {
handler().read(this);
} catch (Throwable t) {
invokeExceptionCaughtFromOutbound(t);
} finally {
decrementOutboundOperations();
}
}
@ -599,7 +844,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) {
notifyHandlerException(t);
} else {
findContextInbound(MASK_EXCEPTION_CAUGHT).notifyHandlerException(t);
DefaultChannelHandlerContext context = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (context.isProcessInboundDirectly()) {
context.invokeExceptionCaught(t);
} else {
executeInboundReentrance(context, () -> context.invokeExceptionCaught0(t));
}
}
}
@ -616,11 +866,18 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void invokeWrite(Object msg, ChannelPromise promise) {
incrementOutboundOperations();
invokeWrite0(msg, promise);
}
private void invokeWrite0(Object msg, ChannelPromise promise) {
final Object m = pipeline.touch(msg, this);
try {
handler().write(this, m, promise);
} catch (Throwable t) {
notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
}
}
@ -638,14 +895,26 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeFlush() {
findContextOutbound(MASK_FLUSH).invokeFlush();
DefaultChannelHandlerContext context = findContextOutbound(MASK_FLUSH);
if (context.isProcessOutboundDirectly()) {
context.invokeFlush();
} else {
executeOutboundReentrance(context, context::invokeFlush0, null, null);
}
}
private void invokeFlush() {
incrementOutboundOperations();
invokeFlush0();
}
private void invokeFlush0() {
try {
handler().flush(this);
} catch (Throwable t) {
invokeExceptionCaughtFromOutbound(t);
} finally {
decrementOutboundOperations();
}
}
@ -655,11 +924,6 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return promise;
}
private void invokeWriteAndFlush(Object msg, ChannelPromise promise) {
invokeWrite(msg, promise);
invokeFlush();
}
private void write(Object msg, boolean flush, ChannelPromise promise) {
requireNonNull(msg, "msg");
try {
@ -678,9 +942,19 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
final DefaultChannelHandlerContext next = findContextOutbound(flush ?
(MASK_WRITE | MASK_FLUSH) : MASK_WRITE);
if (flush) {
next.invokeWriteAndFlush(msg, promise);
} else {
if (next.isProcessOutboundDirectly()) {
next.invokeWrite(msg, promise);
next.invokeFlush();
} else {
executeOutboundReentrance(next, () -> next.invokeWrite0(msg, promise), promise, msg);
executeOutboundReentrance(next, next::invokeFlush0, null, null);
}
} else {
if (next.isProcessOutboundDirectly()) {
next.invokeWrite(msg, promise);
} else {
executeOutboundReentrance(next, () -> next.invokeWrite0(msg, promise), promise, msg);
}
}
} else {
final AbstractWriteTask task;
@ -719,7 +993,6 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
return;
}
invokeExceptionCaught(cause);
}
@ -802,7 +1075,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
DefaultChannelHandlerContext ctx = this;
do {
ctx = ctx.next;
} while ((ctx.executionMask & mask) == 0);
} while ((ctx.executionMask & mask) == 0 && ctx.isProcessInboundDirectly());
return ctx;
}
@ -810,7 +1083,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
DefaultChannelHandlerContext ctx = this;
do {
ctx = ctx.prev;
} while ((ctx.executionMask & mask) == 0);
} while ((ctx.executionMask & mask) == 0 && ctx.isProcessOutboundDirectly());
return ctx;
}
@ -871,7 +1144,9 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return true;
} catch (Throwable cause) {
try {
if (promise != null) {
promise.setFailure(cause);
}
} finally {
if (msg != null) {
ReferenceCountUtil.release(msg);

View File

@ -69,7 +69,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private final VoidChannelPromise voidPromise;
private final boolean touch = ResourceLeakDetector.isEnabled();
private final List<DefaultChannelHandlerContext> handlers = new ArrayList<>(4);
private volatile MessageSizeEstimator.Handle estimatorHandle;
public DefaultChannelPipeline(Channel channel) {

View File

@ -49,10 +49,12 @@ import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@ -1370,6 +1372,87 @@ public class DefaultChannelPipelineTest {
channel2.close().syncUninterruptibly();
}
@Test
public void testReentranceInbound() throws Exception {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(newLocalChannel());
BlockingQueue<Integer> queue = new LinkedBlockingDeque<>();
pipeline.addLast(new ChannelHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ctx.fireChannelRead(1);
ctx.fireChannelRead(2);
}
});
pipeline.addLast(new ChannelHandler() {
boolean called;
@Override
public void read(ChannelHandlerContext ctx) {
if (!called) {
called = true;
ctx.fireChannelRead(3);
}
ctx.read();
}
});
pipeline.addLast(new ChannelHandler() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
ctx.read();
queue.add((Integer) msg);
}
});
pipeline.fireChannelActive();
assertEquals(1, (int) queue.take());
assertEquals(3, (int) queue.take());
assertEquals(2, (int) queue.take());
pipeline.close().syncUninterruptibly();
assertNull(queue.poll());
}
@Test
public void testReentranceOutbound() throws Exception {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(newLocalChannel());
BlockingQueue<Integer> queue = new LinkedBlockingDeque<>();
pipeline.addLast(new ChannelHandler() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
ctx.fireUserEventTriggered("");
queue.add((Integer) msg);
}
});
pipeline.addLast(new ChannelHandler() {
boolean called;
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (!called) {
called = true;
ctx.write(3);
}
ctx.fireUserEventTriggered(evt);
}
});
pipeline.addLast(new ChannelHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ctx.write(1);
ctx.write(2);
}
});
pipeline.fireChannelActive();
assertEquals(1, (int) queue.take());
assertEquals(3, (int) queue.take());
assertEquals(2, (int) queue.take());
pipeline.close().syncUninterruptibly();
assertNull(queue.poll());
}
@Test(timeout = 5000)
public void handlerAddedStateUpdatedBeforeHandlerAddedDoneForceEventLoop() throws InterruptedException {
handlerAddedStateUpdatedBeforeHandlerAddedDone(true);

View File

@ -21,6 +21,7 @@ import io.netty.channel.LoggingHandler.Event;
import io.netty.channel.local.LocalAddress;
import org.hamcrest.Matchers;
import org.junit.Ignore;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
@ -142,6 +143,22 @@ public class ReentrantChannelTest extends BaseChannelTest {
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n" +
"FLUSH\n",
// Case 3:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 4:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
@ -151,6 +168,7 @@ public class ReentrantChannelTest extends BaseChannelTest {
"WRITABILITY: writable=true\n");
}
@Ignore("The whole test is questionable so ignore for now")
@Test
public void testWriteFlushPingPong() throws Exception {
@ -186,13 +204,16 @@ public class ReentrantChannelTest extends BaseChannelTest {
ctx.channel().write(createTestBuf(2000));
}
ctx.flush();
if (flushCount == 5) {
ctx.close();
}
}
});
clientChannel.writeAndFlush(createTestBuf(2000));
clientChannel.close().sync();
clientChannel.write(createTestBuf(2000));
clientChannel.closeFuture().syncUninterruptibly();
assertLog(
// Case 1:
"WRITE\n" +
"FLUSH\n" +
"WRITE\n" +
@ -205,6 +226,18 @@ public class ReentrantChannelTest extends BaseChannelTest {
"FLUSH\n" +
"WRITE\n" +
"FLUSH\n" +
"CLOSE\n",
// Case 2:
"WRITE\n" +
"FLUSH\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITE\n" +
"FLUSH\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITE\n" +
"FLUSH\n" +
"CLOSE\n");
}