Revert "Protect ChannelHandler from reentrancee issues (#9358)"

This reverts commit 48634f14665c864a47c6552f59fbce8d009780be.
This commit is contained in:
Norman Maurer 2019-11-28 15:30:38 +01:00
parent 585ed4d08f
commit ee593ace33
7 changed files with 68 additions and 465 deletions

View File

@ -73,8 +73,6 @@ public class WebSocketProtocolHandlerTest {
// When // When
channel.read(); channel.read();
channel.runPendingTasks();
// Then - pong frame was written to the outbound // Then - pong frame was written to the outbound
PongWebSocketFrame response1 = channel.readOutbound(); PongWebSocketFrame response1 = channel.readOutbound();
assertEquals(text1, response1.content().toString(UTF_8)); assertEquals(text1, response1.content().toString(UTF_8));

View File

@ -417,8 +417,6 @@ public abstract class Http2MultiplexTest<C extends Http2FrameCodec> {
Http2StreamChannel childChannel = newOutboundStream(handler); Http2StreamChannel childChannel = newOutboundStream(handler);
assertTrue(childChannel.isActive()); assertTrue(childChannel.isActive());
parentChannel.runPendingTasks();
childChannel.close(); childChannel.close();
verify(frameWriter).writeRstStream(eqCodecCtx(), verify(frameWriter).writeRstStream(eqCodecCtx(),
eqStreamId(childChannel), eq(Http2Error.CANCEL.code()), anyChannelPromise()); eqStreamId(childChannel), eq(Http2Error.CANCEL.code()), anyChannelPromise());
@ -450,7 +448,6 @@ public abstract class Http2MultiplexTest<C extends Http2FrameCodec> {
ctx.fireChannelActive(); ctx.fireChannelActive();
} }
}); });
parentChannel.runPendingTasks();
assertFalse(childChannel.isActive()); assertFalse(childChannel.isActive());
@ -529,8 +526,6 @@ public abstract class Http2MultiplexTest<C extends Http2FrameCodec> {
Http2Headers headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt"); Http2Headers headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt");
childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers));
parentChannel.runPendingTasks();
// Read from the child channel // Read from the child channel
frameInboundWriter.writeInboundHeaders(childChannel.stream().id(), headers, 0, false); frameInboundWriter.writeInboundHeaders(childChannel.stream().id(), headers, 0, false);

View File

@ -357,18 +357,18 @@ public class SocketHalfClosedTest extends AbstractSocketTest {
public void channelActive(ChannelHandlerContext ctx) throws Exception { public void channelActive(ChannelHandlerContext ctx) throws Exception {
ByteBuf buf = ctx.alloc().buffer(expectedBytes); ByteBuf buf = ctx.alloc().buffer(expectedBytes);
buf.writerIndex(buf.writerIndex() + expectedBytes); buf.writerIndex(buf.writerIndex() + expectedBytes);
ctx.writeAndFlush(buf.retainedDuplicate()).addListener((ChannelFutureListener) f -> { ctx.writeAndFlush(buf.retainedDuplicate());
// We wait here to ensure that we write before we have a chance to process the outbound
// shutdown event.
followerCloseLatch.await();
// This write should fail, but we should still be allowed to read the peer's data // We wait here to ensure that we write before we have a chance to process the outbound
ctx.writeAndFlush(buf).addListener((ChannelFutureListener) future -> { // shutdown event.
if (future.cause() == null) { followerCloseLatch.await();
causeRef.set(new IllegalStateException("second write should have failed!"));
doneLatch.countDown(); // This write should fail, but we should still be allowed to read the peer's data
} ctx.writeAndFlush(buf).addListener((ChannelFutureListener) future -> {
}); if (future.cause() == null) {
causeRef.set(new IllegalStateException("second write should have failed!"));
doneLatch.countDown();
}
}); });
} }

View File

@ -25,9 +25,9 @@ import io.netty.util.ResourceLeakHint;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
import io.netty.util.internal.ObjectPool; import io.netty.util.internal.ObjectPool;
import io.netty.util.internal.PromiseNotificationUtil; import io.netty.util.internal.PromiseNotificationUtil;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.ThrowableUtil; import io.netty.util.internal.ThrowableUtil;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -68,10 +68,6 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
DefaultChannelHandlerContext prev; DefaultChannelHandlerContext prev;
private int handlerState = INIT; private int handlerState = INIT;
// Keeps track of processing different events
private short outboundOperations;
private short inboundOperations;
DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name, DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name,
ChannelHandler handler) { ChannelHandler handler) {
this.name = requireNonNull(name, "name"); this.name = requireNonNull(name, "name");
@ -118,54 +114,6 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return name; return name;
} }
private boolean isProcessInboundDirectly() {
assert inboundOperations >= 0;
return inboundOperations == 0;
}
private boolean isProcessOutboundDirectly() {
assert outboundOperations >= 0;
return outboundOperations == 0;
}
private void incrementOutboundOperations() {
assert outboundOperations >= 0;
outboundOperations++;
}
private void decrementOutboundOperations() {
assert outboundOperations > 0;
outboundOperations--;
}
private void incrementInboundOperations() {
assert inboundOperations >= 0;
inboundOperations++;
}
private void decrementInboundOperations() {
assert inboundOperations > 0;
inboundOperations--;
}
private static void executeInboundReentrance(DefaultChannelHandlerContext context, Runnable task) {
context.incrementInboundOperations();
try {
context.executor().execute(task);
} catch (Throwable cause) {
context.decrementInboundOperations();
throw cause;
}
}
private static void executeOutboundReentrance(
DefaultChannelHandlerContext context, Runnable task, ChannelPromise promise, Object msg) {
context.incrementOutboundOperations();
if (!safeExecute(context.executor(), task, promise, msg)) {
context.decrementOutboundOperations();
}
}
@Override @Override
public ChannelHandlerContext fireChannelRegistered() { public ChannelHandlerContext fireChannelRegistered() {
EventExecutor executor = executor(); EventExecutor executor = executor();
@ -178,26 +126,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelRegistered() { private void findAndInvokeChannelRegistered() {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_REGISTERED); findContextInbound(MASK_CHANNEL_REGISTERED).invokeChannelRegistered();
if (context.isProcessInboundDirectly()) {
context.invokeChannelRegistered();
} else {
executeInboundReentrance(context, context::invokeChannelRegistered0);
}
} }
void invokeChannelRegistered() { void invokeChannelRegistered() {
incrementInboundOperations();
invokeChannelRegistered0();
}
private void invokeChannelRegistered0() {
try { try {
handler().channelRegistered(this); handler().channelRegistered(this);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -213,26 +149,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelUnregistered() { private void findAndInvokeChannelUnregistered() {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_UNREGISTERED); findContextInbound(MASK_CHANNEL_UNREGISTERED).invokeChannelUnregistered();
if (context.isProcessInboundDirectly()) {
context.invokeChannelUnregistered();
} else {
executeInboundReentrance(context, context::invokeChannelUnregistered0);
}
} }
void invokeChannelUnregistered() { void invokeChannelUnregistered() {
incrementInboundOperations();
invokeChannelUnregistered0();
}
private void invokeChannelUnregistered0() {
try { try {
handler().channelUnregistered(this); handler().channelUnregistered(this);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -248,26 +172,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelActive() { private void findAndInvokeChannelActive() {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_ACTIVE); findContextInbound(MASK_CHANNEL_ACTIVE).invokeChannelActive();
if (context.isProcessInboundDirectly()) {
context.invokeChannelActive();
} else {
executeInboundReentrance(context, context::invokeChannelActive0);
}
} }
void invokeChannelActive() { void invokeChannelActive() {
incrementInboundOperations();
invokeChannelActive0();
}
private void invokeChannelActive0() {
try { try {
handler().channelActive(this); handler().channelActive(this);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -283,26 +195,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelInactive() { private void findAndInvokeChannelInactive() {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_INACTIVE); findContextInbound(MASK_CHANNEL_INACTIVE).invokeChannelInactive();
if (context.isProcessInboundDirectly()) {
context.invokeChannelInactive();
} else {
executeInboundReentrance(context, context::invokeChannelInactive0);
}
} }
void invokeChannelInactive() { void invokeChannelInactive() {
incrementInboundOperations();
invokeChannelInactive0();
}
private void invokeChannelInactive0() {
try { try {
handler().channelInactive(this); handler().channelInactive(this);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -326,20 +226,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeExceptionCaught(Throwable cause) { private void findAndInvokeExceptionCaught(Throwable cause) {
DefaultChannelHandlerContext context = findContextInbound(MASK_EXCEPTION_CAUGHT); findContextInbound(MASK_EXCEPTION_CAUGHT).invokeExceptionCaught(cause);
if (context.isProcessInboundDirectly()) {
context.invokeExceptionCaught(cause);
} else {
executeInboundReentrance(context, () -> context.invokeExceptionCaught0(cause));
}
} }
void invokeExceptionCaught(final Throwable cause) { void invokeExceptionCaught(final Throwable cause) {
incrementInboundOperations();
invokeExceptionCaught0(cause);
}
private void invokeExceptionCaught0(final Throwable cause) {
try { try {
handler().exceptionCaught(this, cause); handler().exceptionCaught(this, cause);
} catch (Throwable error) { } catch (Throwable error) {
@ -355,8 +245,6 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
"was thrown by a user handler's exceptionCaught() " + "was thrown by a user handler's exceptionCaught() " +
"method while handling the following exception:", error, cause); "method while handling the following exception:", error, cause);
} }
} finally {
decrementInboundOperations();
} }
} }
@ -373,26 +261,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeUserEventTriggered(Object event) { private void findAndInvokeUserEventTriggered(Object event) {
DefaultChannelHandlerContext context = findContextInbound(MASK_USER_EVENT_TRIGGERED); findContextInbound(MASK_USER_EVENT_TRIGGERED).invokeUserEventTriggered(event);
if (context.isProcessInboundDirectly()) {
context.invokeUserEventTriggered(event);
} else {
executeInboundReentrance(context, () -> context.invokeUserEventTriggered0(event));
}
} }
void invokeUserEventTriggered(Object event) { void invokeUserEventTriggered(Object event) {
incrementInboundOperations();
invokeUserEventTriggered0(event);
}
private void invokeUserEventTriggered0(Object event) {
try { try {
handler().userEventTriggered(this, event); handler().userEventTriggered(this, event);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -414,27 +290,15 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelRead(Object msg) { private void findAndInvokeChannelRead(Object msg) {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_READ); findContextInbound(MASK_CHANNEL_READ).invokeChannelRead(msg);
if (context.isProcessInboundDirectly()) {
context.invokeChannelRead(msg);
} else {
executeInboundReentrance(context, () -> context.invokeChannelRead0(msg));
}
} }
void invokeChannelRead(Object msg) { void invokeChannelRead(Object msg) {
incrementInboundOperations();
invokeChannelRead0(msg);
}
private void invokeChannelRead0(Object msg) {
final Object m = pipeline.touch(requireNonNull(msg, "msg"), this); final Object m = pipeline.touch(requireNonNull(msg, "msg"), this);
try { try {
handler().channelRead(this, m); handler().channelRead(this, m);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -451,26 +315,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelReadComplete() { private void findAndInvokeChannelReadComplete() {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_READ_COMPLETE); findContextInbound(MASK_CHANNEL_READ_COMPLETE).invokeChannelReadComplete();
if (context.isProcessInboundDirectly()) {
context.invokeChannelReadComplete();
} else {
executeInboundReentrance(context, context::invokeChannelReadComplete0);
}
} }
void invokeChannelReadComplete() { void invokeChannelReadComplete() {
incrementInboundOperations();
invokeChannelReadComplete0();
}
private void invokeChannelReadComplete0() {
try { try {
handler().channelReadComplete(this); handler().channelReadComplete(this);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -487,26 +339,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelWritabilityChanged() { private void findAndInvokeChannelWritabilityChanged() {
DefaultChannelHandlerContext context = findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED); findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED).invokeChannelWritabilityChanged();
if (context.isProcessInboundDirectly()) {
context.invokeChannelWritabilityChanged();
} else {
executeInboundReentrance(context, context::invokeChannelWritabilityChanged0);
}
} }
void invokeChannelWritabilityChanged() { void invokeChannelWritabilityChanged() {
incrementInboundOperations();
invokeChannelWritabilityChanged0();
}
private void invokeChannelWritabilityChanged0() {
try { try {
handler().channelWritabilityChanged(this); handler().channelWritabilityChanged(this);
} catch (Throwable t) { } catch (Throwable t) {
notifyHandlerException(t); notifyHandlerException(t);
} finally {
decrementInboundOperations();
} }
} }
@ -563,26 +403,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) { private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) {
DefaultChannelHandlerContext context = findContextOutbound(MASK_BIND); findContextOutbound(MASK_BIND).invokeBind(localAddress, promise);
if (context.isProcessOutboundDirectly()) {
context.invokeBind(localAddress, promise);
} else {
executeOutboundReentrance(context, () -> context.invokeBind0(localAddress, promise), promise, null);
}
} }
private void invokeBind(SocketAddress localAddress, ChannelPromise promise) { private void invokeBind(SocketAddress localAddress, ChannelPromise promise) {
incrementOutboundOperations();
invokeBind0(localAddress, promise);
}
private void invokeBind0(SocketAddress localAddress, ChannelPromise promise) {
try { try {
handler().bind(this, localAddress, promise); handler().bind(this, localAddress, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -610,27 +438,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
DefaultChannelHandlerContext context = findContextOutbound(MASK_CONNECT); findContextOutbound(MASK_CONNECT).invokeConnect(remoteAddress, localAddress, promise);
if (context.isProcessOutboundDirectly()) {
context.invokeConnect(remoteAddress, localAddress, promise);
} else {
executeOutboundReentrance(context, () -> context.invokeConnect0(remoteAddress, localAddress, promise),
promise, null);
}
} }
private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
incrementOutboundOperations();
invokeConnect0(remoteAddress, localAddress, promise);
}
private void invokeConnect0(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
try { try {
handler().connect(this, remoteAddress, localAddress, promise); handler().connect(this, remoteAddress, localAddress, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -657,26 +472,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeDisconnect(ChannelPromise promise) { private void findAndInvokeDisconnect(ChannelPromise promise) {
DefaultChannelHandlerContext context = findContextOutbound(MASK_DISCONNECT); findContextOutbound(MASK_DISCONNECT).invokeDisconnect(promise);
if (context.isProcessOutboundDirectly()) {
context.invokeDisconnect(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeDisconnect0(promise), promise, null);
}
} }
private void invokeDisconnect(ChannelPromise promise) { private void invokeDisconnect(ChannelPromise promise) {
incrementOutboundOperations();
invokeDisconnect0(promise);
}
private void invokeDisconnect0(ChannelPromise promise) {
try { try {
handler().disconnect(this, promise); handler().disconnect(this, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -697,26 +500,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeClose(ChannelPromise promise) { private void findAndInvokeClose(ChannelPromise promise) {
DefaultChannelHandlerContext context = findContextOutbound(MASK_CLOSE); findContextOutbound(MASK_CLOSE).invokeClose(promise);
if (context.isProcessOutboundDirectly()) {
context.invokeClose(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeClose0(promise), promise, null);
}
} }
private void invokeClose(ChannelPromise promise) { private void invokeClose(ChannelPromise promise) {
incrementOutboundOperations();
invokeClose0(promise);
}
private void invokeClose0(ChannelPromise promise) {
try { try {
handler().close(this, promise); handler().close(this, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -737,26 +528,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeRegister(ChannelPromise promise) { private void findAndInvokeRegister(ChannelPromise promise) {
DefaultChannelHandlerContext context = findContextOutbound(MASK_REGISTER); findContextOutbound(MASK_REGISTER).invokeRegister(promise);
if (context.isProcessOutboundDirectly()) {
context.invokeRegister(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeRegister0(promise), promise, null);
}
} }
private void invokeRegister(ChannelPromise promise) { private void invokeRegister(ChannelPromise promise) {
incrementOutboundOperations();
invokeRegister0(promise);
}
private void invokeRegister0(ChannelPromise promise) {
try { try {
handler().register(this, promise); handler().register(this, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -777,26 +556,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeDeregister(ChannelPromise promise) { private void findAndInvokeDeregister(ChannelPromise promise) {
DefaultChannelHandlerContext context = findContextOutbound(MASK_DEREGISTER); findContextOutbound(MASK_DEREGISTER).invokeDeregister(promise);
if (context.isProcessOutboundDirectly()) {
context.invokeDeregister(promise);
} else {
executeOutboundReentrance(context, () -> context.invokeDeregister0(promise), promise, null);
}
} }
private void invokeDeregister(ChannelPromise promise) { private void invokeDeregister(ChannelPromise promise) {
incrementOutboundOperations();
invokeDeregister0(promise);
}
private void invokeDeregister0(ChannelPromise promise) {
try { try {
handler().deregister(this, promise); handler().deregister(this, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -813,26 +580,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeRead() { private void findAndInvokeRead() {
DefaultChannelHandlerContext context = findContextOutbound(MASK_READ); findContextOutbound(MASK_READ).invokeRead();
if (context.isProcessOutboundDirectly()) {
context.invokeRead();
} else {
executeOutboundReentrance(context, context::invokeRead0, null, null);
}
} }
private void invokeRead() { private void invokeRead() {
incrementOutboundOperations();
invokeRead0();
}
private void invokeRead0() {
try { try {
handler().read(this); handler().read(this);
} catch (Throwable t) { } catch (Throwable t) {
invokeExceptionCaughtFromOutbound(t); invokeExceptionCaughtFromOutbound(t);
} finally {
decrementOutboundOperations();
} }
} }
@ -840,12 +595,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) { if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) {
notifyHandlerException(t); notifyHandlerException(t);
} else { } else {
DefaultChannelHandlerContext context = findContextInbound(MASK_EXCEPTION_CAUGHT); findContextInbound(MASK_EXCEPTION_CAUGHT).notifyHandlerException(t);
if (context.isProcessInboundDirectly()) {
context.invokeExceptionCaught(t);
} else {
executeInboundReentrance(context, () -> context.invokeExceptionCaught0(t));
}
} }
} }
@ -862,18 +612,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void invokeWrite(Object msg, ChannelPromise promise) { private void invokeWrite(Object msg, ChannelPromise promise) {
incrementOutboundOperations();
invokeWrite0(msg, promise);
}
private void invokeWrite0(Object msg, ChannelPromise promise) {
final Object m = pipeline.touch(msg, this); final Object m = pipeline.touch(msg, this);
try { try {
handler().write(this, m, promise); handler().write(this, m, promise);
} catch (Throwable t) { } catch (Throwable t) {
notifyOutboundHandlerException(t, promise); notifyOutboundHandlerException(t, promise);
} finally {
decrementOutboundOperations();
} }
} }
@ -891,26 +634,14 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeFlush() { private void findAndInvokeFlush() {
DefaultChannelHandlerContext context = findContextOutbound(MASK_FLUSH); findContextOutbound(MASK_FLUSH).invokeFlush();
if (context.isProcessOutboundDirectly()) {
context.invokeFlush();
} else {
executeOutboundReentrance(context, context::invokeFlush0, null, null);
}
} }
private void invokeFlush() { private void invokeFlush() {
incrementOutboundOperations();
invokeFlush0();
}
private void invokeFlush0() {
try { try {
handler().flush(this); handler().flush(this);
} catch (Throwable t) { } catch (Throwable t) {
invokeExceptionCaughtFromOutbound(t); invokeExceptionCaughtFromOutbound(t);
} finally {
decrementOutboundOperations();
} }
} }
@ -920,6 +651,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return promise; return promise;
} }
private void invokeWriteAndFlush(Object msg, ChannelPromise promise) {
invokeWrite(msg, promise);
invokeFlush();
}
private void write(Object msg, boolean flush, ChannelPromise promise) { private void write(Object msg, boolean flush, ChannelPromise promise) {
requireNonNull(msg, "msg"); requireNonNull(msg, "msg");
try { try {
@ -938,19 +674,9 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
final DefaultChannelHandlerContext next = findContextOutbound(flush ? final DefaultChannelHandlerContext next = findContextOutbound(flush ?
(MASK_WRITE | MASK_FLUSH) : MASK_WRITE); (MASK_WRITE | MASK_FLUSH) : MASK_WRITE);
if (flush) { if (flush) {
if (next.isProcessOutboundDirectly()) { next.invokeWriteAndFlush(msg, promise);
next.invokeWrite(msg, promise);
next.invokeFlush();
} else {
executeOutboundReentrance(next, () -> next.invokeWrite0(msg, promise), promise, msg);
executeOutboundReentrance(next, next::invokeFlush0, null, null);
}
} else { } else {
if (next.isProcessOutboundDirectly()) { next.invokeWrite(msg, promise);
next.invokeWrite(msg, promise);
} else {
executeOutboundReentrance(next, () -> next.invokeWrite0(msg, promise), promise, msg);
}
} }
} else { } else {
final AbstractWriteTask task; final AbstractWriteTask task;
@ -989,6 +715,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
return; return;
} }
invokeExceptionCaught(cause); invokeExceptionCaught(cause);
} }
@ -1071,7 +798,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
DefaultChannelHandlerContext ctx = this; DefaultChannelHandlerContext ctx = this;
do { do {
ctx = ctx.next; ctx = ctx.next;
} while ((ctx.executionMask & mask) == 0 && ctx.isProcessInboundDirectly()); } while ((ctx.executionMask & mask) == 0);
return ctx; return ctx;
} }
@ -1079,7 +806,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
DefaultChannelHandlerContext ctx = this; DefaultChannelHandlerContext ctx = this;
do { do {
ctx = ctx.prev; ctx = ctx.prev;
} while ((ctx.executionMask & mask) == 0 && ctx.isProcessOutboundDirectly()); } while ((ctx.executionMask & mask) == 0);
return ctx; return ctx;
} }
@ -1144,9 +871,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return true; return true;
} catch (Throwable cause) { } catch (Throwable cause) {
try { try {
if (promise != null) { promise.setFailure(cause);
promise.setFailure(cause);
}
} finally { } finally {
if (msg != null) { if (msg != null) {
ReferenceCountUtil.release(msg); ReferenceCountUtil.release(msg);

View File

@ -69,6 +69,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private final VoidChannelPromise voidPromise; private final VoidChannelPromise voidPromise;
private final boolean touch = ResourceLeakDetector.isEnabled(); private final boolean touch = ResourceLeakDetector.isEnabled();
private final List<DefaultChannelHandlerContext> handlers = new ArrayList<>(4); private final List<DefaultChannelHandlerContext> handlers = new ArrayList<>(4);
private volatile MessageSizeEstimator.Handle estimatorHandle; private volatile MessageSizeEstimator.Handle estimatorHandle;
public DefaultChannelPipeline(Channel channel) { public DefaultChannelPipeline(Channel channel) {

View File

@ -49,12 +49,10 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -1372,87 +1370,6 @@ public class DefaultChannelPipelineTest {
channel2.close().syncUninterruptibly(); channel2.close().syncUninterruptibly();
} }
@Test
public void testReentranceInbound() throws Exception {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(newLocalChannel());
BlockingQueue<Integer> queue = new LinkedBlockingDeque<>();
pipeline.addLast(new ChannelHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ctx.fireChannelRead(1);
ctx.fireChannelRead(2);
}
});
pipeline.addLast(new ChannelHandler() {
boolean called;
@Override
public void read(ChannelHandlerContext ctx) {
if (!called) {
called = true;
ctx.fireChannelRead(3);
}
ctx.read();
}
});
pipeline.addLast(new ChannelHandler() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
ctx.read();
queue.add((Integer) msg);
}
});
pipeline.fireChannelActive();
assertEquals(1, (int) queue.take());
assertEquals(3, (int) queue.take());
assertEquals(2, (int) queue.take());
pipeline.close().syncUninterruptibly();
assertNull(queue.poll());
}
@Test
public void testReentranceOutbound() throws Exception {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(newLocalChannel());
BlockingQueue<Integer> queue = new LinkedBlockingDeque<>();
pipeline.addLast(new ChannelHandler() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
ctx.fireUserEventTriggered("");
queue.add((Integer) msg);
}
});
pipeline.addLast(new ChannelHandler() {
boolean called;
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (!called) {
called = true;
ctx.write(3);
}
ctx.fireUserEventTriggered(evt);
}
});
pipeline.addLast(new ChannelHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ctx.write(1);
ctx.write(2);
}
});
pipeline.fireChannelActive();
assertEquals(1, (int) queue.take());
assertEquals(3, (int) queue.take());
assertEquals(2, (int) queue.take());
pipeline.close().syncUninterruptibly();
assertNull(queue.poll());
}
@Test(timeout = 5000) @Test(timeout = 5000)
public void handlerAddedStateUpdatedBeforeHandlerAddedDoneForceEventLoop() throws InterruptedException { public void handlerAddedStateUpdatedBeforeHandlerAddedDoneForceEventLoop() throws InterruptedException {
handlerAddedStateUpdatedBeforeHandlerAddedDone(true); handlerAddedStateUpdatedBeforeHandlerAddedDone(true);

View File

@ -21,7 +21,6 @@ import io.netty.channel.LoggingHandler.Event;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
@ -136,39 +135,22 @@ public class ReentrantChannelTest extends BaseChannelTest {
assertLog( assertLog(
// Case 1: // Case 1:
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITE\n" + "WRITE\n" +
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITABILITY: writable=true\n", "WRITABILITY: writable=true\n",
// Case 2: // Case 2:
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITE\n" + "WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n" +
"FLUSH\n",
// Case 3:
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITE\n" + "WRITABILITY: writable=true\n" +
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=true\n");
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 4:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n");
} }
@Ignore("The whole test is questionable so ignore for now")
@Test @Test
public void testWriteFlushPingPong() throws Exception { public void testWriteFlushPingPong() throws Exception {
@ -204,41 +186,26 @@ public class ReentrantChannelTest extends BaseChannelTest {
ctx.channel().write(createTestBuf(2000)); ctx.channel().write(createTestBuf(2000));
} }
ctx.flush(); ctx.flush();
if (flushCount == 5) {
ctx.close();
}
} }
}); });
clientChannel.write(createTestBuf(2000)); clientChannel.writeAndFlush(createTestBuf(2000));
clientChannel.closeFuture().syncUninterruptibly(); clientChannel.close().sync();
assertLog( assertLog(
// Case 1:
"WRITE\n" + "WRITE\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITE\n" +
"FLUSH\n" +
"WRITE\n" +
"FLUSH\n" +
"WRITE\n" +
"FLUSH\n" +
"WRITE\n" +
"FLUSH\n" +
"WRITE\n" +
"FLUSH\n" +
"CLOSE\n",
// Case 2:
"WRITE\n" + "WRITE\n" +
"FLUSH\n" + "FLUSH\n" +
"FLUSH\n" + "WRITE\n" +
"WRITE\n" + "FLUSH\n" +
"WRITE\n" + "WRITE\n" +
"FLUSH\n" + "FLUSH\n" +
"FLUSH\n" + "WRITE\n" +
"WRITE\n" + "FLUSH\n" +
"WRITE\n" + "WRITE\n" +
"FLUSH\n" + "FLUSH\n" +
"CLOSE\n"); "CLOSE\n");
} }
@Test @Test