Correctly propagate exceptions from inbound operations in all cases (#10176)

Motivation:

In DefaultChannelHandlerContext we had some code where we tried to guard against endless loops caused by exceptions thrown by exceptionCaught(...) that would trigger exceptionCaught again. This code was proplematic for two reasons:
 - It is quite expensive as we need to compare all the stacks
 - We may end up not notify another handlers exceptionCaught(...) if in our exeuction stack we triggered actions that will cause an exceptionCaught somewhere else in the pipeline

Modifications:

- Just remove the detection code as we already handle everything correctly when we invoke exceptionCaught(...)
- Add unit tests

Result:

Ensure we always notify correctly and also fixes performance issue reported as https://github.com/netty/netty/issues/10165
This commit is contained in:
Norman Maurer 2020-04-16 08:56:45 +02:00
parent 87db5803a9
commit 9077acb6ab
2 changed files with 91 additions and 42 deletions

View File

@ -159,7 +159,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelRegistered(this);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -187,7 +187,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelUnregistered(this);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -215,7 +215,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelActive(this);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -243,7 +243,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelInactive(this);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -320,7 +320,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().userEventTriggered(this, event);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -356,7 +356,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelRead(this, m);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -385,7 +385,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelReadComplete(this);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -414,7 +414,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
handler().channelWritabilityChanged(this);
} catch (Throwable t) {
notifyHandlerException(t);
invokeExceptionCaught(t);
}
}
@ -694,7 +694,7 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
private void invokeExceptionCaughtFromOutbound(Throwable t) {
if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) {
notifyHandlerException(t);
invokeExceptionCaught(t);
} else {
DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (ctx == null) {
@ -820,39 +820,6 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
PromiseNotificationUtil.tryFailure(promise, cause, promise instanceof VoidChannelPromise ? null : logger);
}
private void notifyHandlerException(Throwable cause) {
if (inExceptionCaught(cause)) {
if (logger.isWarnEnabled()) {
logger.warn(
"An exception was thrown by a user handler " +
"while handling an exceptionCaught event", cause);
}
return;
}
invokeExceptionCaught(cause);
}
private static boolean inExceptionCaught(Throwable cause) {
do {
StackTraceElement[] trace = cause.getStackTrace();
if (trace != null) {
for (StackTraceElement t : trace) {
if (t == null) {
break;
}
if ("exceptionCaught".equals(t.getMethodName())) {
return true;
}
}
}
cause = cause.getCause();
} while (cause != null);
return false;
}
@Override
public ChannelPromise newPromise() {
return pipeline().newPromise();

View File

@ -57,6 +57,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
@ -331,6 +332,87 @@ public class DefaultChannelPipelineTest {
verifyContextNumber(pipeline, HANDLER_ARRAY_LEN * 2);
}
@Test(timeout = 3000)
public void testThrowInExceptionCaught() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger counter = new AtomicInteger();
Channel channel = newLocalChannel();
try {
channel.register().syncUninterruptibly();
channel.pipeline().addLast(new ChannelHandler() {
class TestException extends Exception { }
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
throw new TestException();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause instanceof TestException) {
ctx.executor().execute(new Runnable() {
@Override
public void run() {
latch.countDown();
}
});
}
counter.incrementAndGet();
throw new Exception();
}
});
channel.pipeline().fireChannelReadComplete();
latch.await();
assertEquals(1, counter.get());
} finally {
channel.close().syncUninterruptibly();
}
}
@Test(timeout = 3000)
public void testThrowInOtherHandlerAfterInvokedFromExceptionCaught() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger counter = new AtomicInteger();
Channel channel = newLocalChannel();
try {
channel.register().syncUninterruptibly();
channel.pipeline().addLast(new ChannelHandler() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
ctx.fireChannelReadComplete();
}
}, new ChannelHandler() {
class TestException extends Exception { }
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
throw new TestException();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause instanceof TestException) {
ctx.executor().execute(new Runnable() {
@Override
public void run() {
latch.countDown();
}
});
}
counter.incrementAndGet();
throw new Exception();
}
});
channel.pipeline().fireExceptionCaught(new Exception());
latch.await();
assertEquals(1, counter.get());
} finally {
channel.close().syncUninterruptibly();
}
}
@Test
public void testFireChannelRegistered() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);