diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index 92c912a500..e9d97cd247 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -165,7 +165,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelRegistered(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelRegistered(); @@ -197,7 +197,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelUnregistered(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelUnregistered(); @@ -229,7 +229,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelActive(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelActive(); @@ -261,7 +261,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelInactive(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelInactive(); @@ -345,7 +345,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).userEventTriggered(this, event); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireUserEventTriggered(event); @@ -378,7 +378,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelRead(this, msg); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelRead(msg); @@ -409,7 +409,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelReadComplete(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelReadComplete(); @@ -440,7 +440,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelInboundHandler) handler()).channelWritabilityChanged(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { fireChannelWritabilityChanged(); @@ -685,7 +685,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelOutboundHandler) handler()).read(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } else { read(); @@ -749,7 +749,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R try { ((ChannelOutboundHandler) handler()).flush(this); } catch (Throwable t) { - notifyHandlerException(t); + invokeExceptionCaught(t); } } @@ -814,39 +814,6 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R PromiseNotificationUtil.tryFailure(promise, cause, promise instanceof VoidChannelPromise ? null : logger); } - private void notifyHandlerException(Throwable cause) { - if (inExceptionCaught(cause)) { - if (logger.isWarnEnabled()) { - logger.warn( - "An exception was thrown by a user handler " + - "while handling an exceptionCaught event", cause); - } - return; - } - - invokeExceptionCaught(cause); - } - - private static boolean inExceptionCaught(Throwable cause) { - do { - StackTraceElement[] trace = cause.getStackTrace(); - if (trace != null) { - for (StackTraceElement t : trace) { - if (t == null) { - break; - } - if ("exceptionCaught".equals(t.getMethodName())) { - return true; - } - } - } - - cause = cause.getCause(); - } while (cause != null); - - return false; - } - @Override public ChannelPromise newPromise() { return new DefaultChannelPromise(channel(), executor()); diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 4d49fdc514..69fe9e7a89 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -60,6 +60,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.LockSupport; @@ -333,6 +334,87 @@ public class DefaultChannelPipelineTest { verifyContextNumber(pipeline, HANDLER_ARRAY_LEN * 2); } + @Test(timeout = 3000) + public void testThrowInExceptionCaught() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + Channel channel = new LocalChannel(); + try { + group.register(channel).syncUninterruptibly(); + channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + class TestException extends Exception { } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + throw new TestException(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TestException) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + } + counter.incrementAndGet(); + throw new Exception(); + } + }); + + channel.pipeline().fireChannelReadComplete(); + latch.await(); + assertEquals(1, counter.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test(timeout = 3000) + public void testThrowInOtherHandlerAfterInvokedFromExceptionCaught() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + Channel channel = new LocalChannel(); + try { + group.register(channel).syncUninterruptibly(); + channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireChannelReadComplete(); + } + }, new ChannelInboundHandlerAdapter() { + class TestException extends Exception { } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + throw new TestException(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TestException) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + } + counter.incrementAndGet(); + throw new Exception(); + } + }); + + channel.pipeline().fireExceptionCaught(new Exception()); + latch.await(); + assertEquals(1, counter.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + @Test public void testFireChannelRegistered() throws Exception { final CountDownLatch latch = new CountDownLatch(1);