From 2c78b4c84fb968c6d130a63bee30cbef706c3da9 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 16 Apr 2020 08:56:45 +0200 Subject: [PATCH] Correctly propagate exceptions from inbound operations in all cases (#10176) Motivation: In AbstractChannelHandlerContext we had some code where we tried to guard against endless loops caused by exceptions thrown by exceptionCaught(...) that would trigger exceptionCaught again. This code was proplematic for two reasons: - It is quite expensive as we need to compare all the stacks - We may end up not notify another handlers exceptionCaught(...) if in our exeuction stack we triggered actions that will cause an exceptionCaught somewhere else in the pipeline Modifications: - Just remove the detection code as we already handle everything correctly when we invoke exceptionCaught(...) - Add unit tests Result: Ensure we always notify correctly and also fixes performance issue reported as https://github.com/netty/netty/issues/10165 --- .../AbstractChannelHandlerContext.java | 53 +++--------- .../channel/DefaultChannelPipelineTest.java | 82 +++++++++++++++++++ 2 files changed, 92 insertions(+), 43 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 92c912a500..e9d97cd247 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -165,7 +165,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelRegistered(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelRegistered(); @@ -197,7 +197,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelUnregistered(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelUnregistered(); @@ -229,7 +229,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelActive(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelActive(); @@ -261,7 +261,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelInactive(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelInactive(); @@ -345,7 +345,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).userEventTriggered(this, event); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireUserEventTriggered(event); @@ -378,7 +378,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelRead(this, msg); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelRead(msg); @@ -409,7 +409,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelReadComplete(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelReadComplete(); @@ -440,7 +440,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelWritabilityChanged(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelWritabilityChanged(); @@ -685,7 +685,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelOutboundHandler) handler()).read(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { read(); @@ -749,7 +749,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelOutboundHandler) handler()).flush(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } @@ -814,39 +814,6 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R PromiseNotificationUtil.tryFailure(promise, cause, promise instanceof VoidChannelPromise ? null : logger); } - private void notifyHandlerException(Throwable cause) { - if (inExceptionCaught(cause)) { - if (logger.isWarnEnabled()) { - logger.warn( - "An exception was thrown by a user handler " + - "while handling an exceptionCaught event", cause); - } - return; - } - - invokeExceptionCaught(cause); - } - - private static boolean inExceptionCaught(Throwable cause) { - do { - StackTraceElement[] trace = cause.getStackTrace(); - if (trace != null) { - for (StackTraceElement t : trace) { - if (t == null) { - break; - } - if ("exceptionCaught".equals(t.getMethodName())) { - return true; - } - } - } - - cause = cause.getCause(); - } while (cause != null); - - return false; - } - @Override public ChannelPromise newPromise() { return new DefaultChannelPromise(channel(), executor()); diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 4d49fdc514..69fe9e7a89 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -60,6 +60,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.LockSupport; @@ -333,6 +334,87 @@ public class DefaultChannelPipelineTest { verifyContextNumber(pipeline, HANDLER_ARRAY_LEN * 2); } + @Test(timeout = 3000) + public void testThrowInExceptionCaught() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + Channel channel = new LocalChannel(); + try { + group.register(channel).syncUninterruptibly(); + channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + class TestException extends Exception { } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + throw new TestException(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TestException) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + } + counter.incrementAndGet(); + throw new Exception(); + } + }); + + channel.pipeline().fireChannelReadComplete(); + latch.await(); + assertEquals(1, counter.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test(timeout = 3000) + public void testThrowInOtherHandlerAfterInvokedFromExceptionCaught() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + Channel channel = new LocalChannel(); + try { + group.register(channel).syncUninterruptibly(); + channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireChannelReadComplete(); + } + }, new ChannelInboundHandlerAdapter() { + class TestException extends Exception { } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + throw new TestException(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TestException) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + } + counter.incrementAndGet(); + throw new Exception(); + } + }); + + channel.pipeline().fireExceptionCaught(new Exception()); + latch.await(); + assertEquals(1, counter.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + @Test public void testFireChannelRegistered() throws Exception { final CountDownLatch latch = new CountDownLatch(1);