diff --git a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java index b8fcf1bc1b..a932a5b7a7 100644 --- a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java +++ b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java @@ -253,6 +253,14 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { } } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.fireExceptionCaught(cause); + if (cause instanceof HAProxyProtocolException) { + ctx.close(); // drop connection immediately per spec + } + } + @Override protected final void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { // determine the specification version @@ -325,7 +333,6 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) { finished = true; - ctx.close(); // drop connection immediately per spec HAProxyProtocolException ppex; if (errMsg != null && e != null) { ppex = new HAProxyProtocolException(errMsg, e); diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java index 177f70b4ee..e5ff93b704 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java @@ -203,8 +203,8 @@ public class HttpClientUpgradeHandler extends HttpObjectAggregator { // NOTE: not releasing the response since we're letting it propagate to the // next handler. ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED); - removeThisHandler(ctx); ctx.fireChannelRead(msg); + removeThisHandler(ctx); return; } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java index d4d7c95aac..3155d30c5c 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java @@ -266,13 +266,26 @@ public class HttpObjectAggregator }); } } else if (oversized instanceof HttpResponse) { - ctx.close(); - throw new TooLongFrameException("Response entity too large: " + oversized); + throw new ResponseTooLargeException("Response entity too large: " + oversized); } else { throw new IllegalStateException(); } } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); + if (cause instanceof ResponseTooLargeException) { + ctx.close(); + } + } + + private static final class ResponseTooLargeException extends TooLongFrameException { + ResponseTooLargeException(String message) { + super(message); + } + } + private abstract static class AggregatedFullHttpMessage implements FullHttpMessage { protected final HttpMessage message; private final ByteBuf content; diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java index 7fd1c5a61d..e5f8103046 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java @@ -328,13 +328,13 @@ public class HttpServerUpgradeHandler extends HttpObjectAggregator { sourceCodec.upgradeFrom(ctx); upgradeCodec.upgradeTo(ctx, request); - // Remove this handler from the pipeline. - ctx.pipeline().remove(HttpServerUpgradeHandler.this); - // Notify that the upgrade has occurred. Retain the event to offset // the release() in the finally block. ctx.fireUserEventTriggered(event.retain()); + // Remove this handler from the pipeline. + ctx.pipeline().remove(HttpServerUpgradeHandler.this); + // Add the listener last to avoid firing upgrade logic after // the channel is already closed since the listener may fire // immediately if the write failed eagerly. diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java index d0967e3653..627e929bdc 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java @@ -300,9 +300,9 @@ public abstract class WebSocketServerHandshaker { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { // Remove ourself and fail the handshake promise. - ctx.pipeline().remove(this); promise.tryFailure(cause); ctx.fireExceptionCaught(cause); + ctx.pipeline().remove(this); } @Override diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java index b6ea2aa8af..b971135d49 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -45,7 +45,6 @@ import static io.netty.handler.codec.http.HttpVersion.*; class WebSocketServerProtocolHandshakeHandler implements ChannelHandler { private final WebSocketServerProtocolConfig serverConfig; - private ChannelHandlerContext ctx; private ChannelPromise handshakePromise; WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) { @@ -54,7 +53,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler { @Override public void handlerAdded(ChannelHandlerContext ctx) { - this.ctx = ctx; handshakePromise = ctx.newPromise(); } @@ -86,7 +84,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler { // // See https://github.com/netty/netty/issues/9471. WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); - ctx.pipeline().remove(this); final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); handshakeFuture.addListener((ChannelFutureListener) future -> { @@ -102,8 +99,9 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler { new WebSocketServerProtocolHandler.HandshakeComplete( req.uri(), req.headers(), handshaker.selectedSubprotocol())); } + ctx.pipeline().remove(this); }); - applyHandshakeTimeout(); + applyHandshakeTimeout(ctx); } } finally { req.release(); @@ -132,7 +130,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler { return protocol + "://" + host + path; } - private void applyHandshakeTimeout() { + private void applyHandshakeTimeout(ChannelHandlerContext ctx) { final ChannelPromise localHandshakePromise = handshakePromise; final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis(); if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java index bb146d52c0..92cf11dbe1 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java @@ -120,8 +120,9 @@ public class WebSocketClientExtensionHandler implements ChannelHandler { ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder); } } - - ctx.pipeline().remove(ctx.name()); + ctx.fireChannelRead(msg); + ctx.pipeline().remove(this); + return; } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java index 2e26309d6b..2ffac41e7b 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java @@ -66,7 +66,7 @@ public class WebSocketServerProtocolHandlerTest { } @Test - public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() { + public void testWebSocketServerProtocolHandshakeHandlerRemovedAfterHandshake() { EmbeddedChannel ch = createChannel(new MockOutboundHandler()); ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); ch.pipeline().addLast(new ChannelHandler() { @@ -74,7 +74,7 @@ public class WebSocketServerProtocolHandlerTest { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { // We should have removed the handler already. - assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class)); + ctx.executor().execute(() -> ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class)); } } }); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java index 8cf31ba230..41d9edecab 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java @@ -87,9 +87,9 @@ public final class CleartextHttp2ServerUpgradeHandler extends ByteToMessageDecod .remove(httpServerUpgradeHandler); ctx.pipeline().addAfter(ctx.name(), null, http2ServerHandler); - ctx.pipeline().remove(this); - ctx.fireUserEventTriggered(PriorKnowledgeUpgradeEvent.INSTANCE); + + ctx.pipeline().remove(this); } } diff --git a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java index 1c5df26c31..461cc8d1ae 100644 --- a/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -154,10 +154,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { } }; - private static final byte STATE_INIT = 0; - private static final byte STATE_CALLING_CHILD_DECODE = 1; - private static final byte STATE_HANDLER_REMOVED_PENDING = 2; - ByteBuf cumulation; private Cumulator cumulator = MERGE_CUMULATOR; private boolean singleDecode; @@ -169,15 +165,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { */ private boolean firedChannelRead; - /** - * A bitmask where the bits are defined as - * - */ - private byte decodeState = STATE_INIT; private int discardAfterReads = 16; private int numReads; private ByteToMessageDecoderContext context; @@ -257,10 +244,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { @Override public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - if (decodeState == STATE_CALLING_CHILD_DECODE) { - decodeState = STATE_HANDLER_REMOVED_PENDING; - return; - } ByteBuf buf = cumulation; if (buf != null) { // Directly set this to null so we are sure we not access it in any other method here anymore. @@ -469,16 +452,7 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter { */ final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in) throws Exception { - decodeState = STATE_CALLING_CHILD_DECODE; - try { - decode(ctx, in); - } finally { - boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; - decodeState = STATE_INIT; - if (removePending) { - handlerRemoved(ctx); - } - } + decode(ctx, in); } /** diff --git a/example/src/main/java/io/netty/example/http2/helloworld/frame/server/Http2ServerInitializer.java b/example/src/main/java/io/netty/example/http2/helloworld/frame/server/Http2ServerInitializer.java index 212b21819f..bf5cba9a04 100644 --- a/example/src/main/java/io/netty/example/http2/helloworld/frame/server/Http2ServerInitializer.java +++ b/example/src/main/java/io/netty/example/http2/helloworld/frame/server/Http2ServerInitializer.java @@ -97,8 +97,9 @@ public class Http2ServerInitializer extends ChannelInitializer { System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)"); ChannelPipeline pipeline = ctx.pipeline(); pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); - pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); + pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); + pipeline.remove(this); } }); diff --git a/example/src/main/java/io/netty/example/http2/helloworld/multiplex/server/Http2ServerInitializer.java b/example/src/main/java/io/netty/example/http2/helloworld/multiplex/server/Http2ServerInitializer.java index 08389b1ec8..9543605d2f 100644 --- a/example/src/main/java/io/netty/example/http2/helloworld/multiplex/server/Http2ServerInitializer.java +++ b/example/src/main/java/io/netty/example/http2/helloworld/multiplex/server/Http2ServerInitializer.java @@ -99,8 +99,9 @@ public class Http2ServerInitializer extends ChannelInitializer { System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)"); ChannelPipeline pipeline = ctx.pipeline(); pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); - pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); + pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); + pipeline.remove(this); } }); diff --git a/example/src/main/java/io/netty/example/http2/helloworld/server/Http2ServerInitializer.java b/example/src/main/java/io/netty/example/http2/helloworld/server/Http2ServerInitializer.java index 826fd8f410..b9f01b9a0e 100644 --- a/example/src/main/java/io/netty/example/http2/helloworld/server/Http2ServerInitializer.java +++ b/example/src/main/java/io/netty/example/http2/helloworld/server/Http2ServerInitializer.java @@ -98,8 +98,9 @@ public class Http2ServerInitializer extends ChannelInitializer { System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)"); ChannelPipeline pipeline = ctx.pipeline(); pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); - pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); + pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); + pipeline.remove(this); } }); diff --git a/example/src/main/java/io/netty/example/socksproxy/SocksServerHandler.java b/example/src/main/java/io/netty/example/socksproxy/SocksServerHandler.java index 2cb2655bf8..3f25dda13a 100644 --- a/example/src/main/java/io/netty/example/socksproxy/SocksServerHandler.java +++ b/example/src/main/java/io/netty/example/socksproxy/SocksServerHandler.java @@ -45,8 +45,8 @@ public final class SocksServerHandler extends SimpleChannelInboundHandler imple @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { - handleNewChannel(ctx); - ctx.fireChannelRegistered(); + handleNewChannel(ctx, true); } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { - if (!handleNewChannel(ctx)) { + if (!handleNewChannel(ctx, false)) { throw new IllegalStateException("cannot determine to accept or reject a channel: " + ctx.channel()); - } else { - ctx.fireChannelActive(); } } - private boolean handleNewChannel(ChannelHandlerContext ctx) throws Exception { + private boolean handleNewChannel(ChannelHandlerContext ctx, boolean register) throws Exception { @SuppressWarnings("unchecked") T remoteAddress = (T) ctx.channel().remoteAddress(); + boolean remove = false; + try { + // If the remote address is not available yet, defer the decision. + if (remoteAddress == null) { + return false; + } - // If the remote address is not available yet, defer the decision. - if (remoteAddress == null) { - return false; - } - - // No need to keep this handler in the pipeline anymore because the decision is going to be made now. - // Also, this will prevent the subsequent events from being handled by this handler. - ctx.pipeline().remove(this); - - if (accept(ctx, remoteAddress)) { - channelAccepted(ctx, remoteAddress); - } else { - ChannelFuture rejectedFuture = channelRejected(ctx, remoteAddress); - if (rejectedFuture != null) { - rejectedFuture.addListener(ChannelFutureListener.CLOSE); + if (accept(ctx, remoteAddress)) { + channelAccepted(ctx, remoteAddress); + remove = true; } else { - ctx.close(); + ChannelFuture rejectedFuture = channelRejected(ctx, remoteAddress); + if (rejectedFuture != null && !rejectedFuture.isDone()) { + rejectedFuture.addListener(ChannelFutureListener.CLOSE); + } else { + ctx.close(); + } + } + return true; + } finally { + if (!ctx.isRemoved()) { + if (register) { + ctx.fireChannelRegistered(); + } else { + ctx.fireChannelActive(); + } + if (remove) { + // No need to keep this handler in the pipeline anymore because the decision is going to be made + // now. Also, this will prevent the subsequent events from being handled by this handler. + ctx.pipeline().remove(this); + } } } - - return true; } /** diff --git a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java index dd4361f7e5..df1afe4a23 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java @@ -97,13 +97,16 @@ public abstract class ApplicationProtocolNegotiationHandler implements ChannelHa } catch (Throwable cause) { exceptionCaught(ctx, cause); } finally { + ctx.fireUserEventTriggered(evt); + ChannelPipeline pipeline = ctx.pipeline(); if (pipeline.context(this) != null) { pipeline.remove(this); } } + } else { + ctx.fireUserEventTriggered(evt); } - ctx.fireUserEventTriggered(evt); } /** diff --git a/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java index d9c164b860..86beb927c2 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java @@ -63,7 +63,7 @@ public abstract class SslClientHelloHandler extends ByteToMessageDecoder { "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(in.readableBytes()); ctx.fireUserEventTriggered(new SniCompletionEvent(e)); - SslUtils.handleHandshakeFailure(ctx, e, true); + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(e)); throw e; } if (len == SslUtils.NOT_ENOUGH_DATA) { diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index e859b77755..a3b4076968 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1091,7 +1091,8 @@ public class SslHandler extends ByteToMessageDecoder { if (logger.isDebugEnabled()) { logger.debug( "{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " + - "while writing close_notify in response to the peer's close_notify", ctx.channel(), cause); + "while writing close_notify in response to the peer's close_notify", + ctx.channel(), cause); } // Close the connection explicitly just in case the transport @@ -1101,6 +1102,11 @@ public class SslHandler extends ByteToMessageDecoder { } } else { ctx.fireExceptionCaught(cause); + + if (cause instanceof SSLException || + ((cause instanceof DecoderException) && cause.getCause() instanceof SSLException)) { + ctx.close(); + } } } @@ -2057,9 +2063,13 @@ public class SslHandler extends ByteToMessageDecoder { } final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis; if (closeNotifyReadTimeout <= 0) { - // Trigger the close in all cases to make sure the promise is notified - // See https://github.com/netty/netty/issues/2358 - addCloseListener(ctx.close(ctx.newPromise()), promise); + if (ctx.channel().isActive()) { + // Trigger the close in all cases to make sure the promise is notified + // See https://github.com/netty/netty/issues/2358 + addCloseListener(ctx.close(ctx.newPromise()), promise); + } else { + promise.trySuccess(); + } } else { final ScheduledFuture closeNotifyReadTimeoutFuture; @@ -2083,7 +2093,11 @@ public class SslHandler extends ByteToMessageDecoder { if (closeNotifyReadTimeoutFuture != null) { closeNotifyReadTimeoutFuture.cancel(false); } - addCloseListener(ctx.close(ctx.newPromise()), promise); + if (ctx.channel().isActive()) { + addCloseListener(ctx.close(ctx.newPromise()), promise); + } else { + promise.trySuccess(); + } }); } }); diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java index 296a4d640d..5a036caa79 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -346,7 +346,6 @@ final class SslUtils { if (notify) { ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); } - ctx.close(); } /** diff --git a/handler/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java b/handler/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java index eb8615a63c..c75ee85f46 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java @@ -48,15 +48,13 @@ public abstract class OcspClientHandler implements ChannelHandler { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + ctx.fireUserEventTriggered(evt); if (evt instanceof SslHandshakeCompletionEvent) { - ctx.pipeline().remove(this); - SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; if (event.isSuccess() && !verify(ctx, engine)) { throw new SSLHandshakeException("Bad OCSP response"); } + ctx.pipeline().remove(this); } - - ctx.fireUserEventTriggered(evt); } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 9f06c683bd..ec99e477e6 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -442,10 +442,12 @@ public class SslHandlerTest { sslHandler.setHandshakeTimeoutMillis(1000); ch.pipeline().addFirst(sslHandler); sslHandler.handshakeFuture().addListener((FutureListener) future -> { - ch.pipeline().remove(sslHandler); + ch.eventLoop().execute(() -> { + ch.pipeline().remove(sslHandler); - // Schedule the close so removal has time to propagate exception if any. - ch.eventLoop().execute(ch::close); + // Schedule the close so removal has time to propagate exception if any. + ch.eventLoop().execute(ch::close); + }); }); ch.pipeline().addLast(new ChannelHandler() { @@ -457,6 +459,7 @@ public class SslHandlerTest { if (cause instanceof IllegalReferenceCountException) { promise.setFailure(cause); } + cause.printStackTrace(); } @Override diff --git a/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java b/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java index 3195aa8841..4ac61260ea 100644 --- a/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java +++ b/testsuite-http2/src/main/java/io/netty/testsuite/http2/Http2ServerInitializer.java @@ -85,8 +85,9 @@ public class Http2ServerInitializer extends ChannelInitializer { ChannelPipeline pipeline = ctx.pipeline(); ChannelHandlerContext thisCtx = pipeline.context(this); pipeline.addAfter(thisCtx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); - pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); + pipeline.addAfter(thisCtx.name(), null, new HttpObjectAggregator(maxHttpContentLength)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); + pipeline.remove(this); } }); diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index 5c2429ff0d..ab42966c0e 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -36,6 +36,7 @@ import java.net.SocketAddress; import java.net.SocketException; import java.nio.channels.ClosedChannelException; import java.nio.channels.NotYetConnectedException; +import java.util.NoSuchElementException; import java.util.concurrent.Executor; import java.util.concurrent.RejectedExecutionException; @@ -775,12 +776,25 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha pipeline.fireChannelInactive(); } // Some transports like local and AIO does not allow the deregistration of - // an open channel. Their doDeregister() calls close(). Consequently, + // an open channel. Their doDeregister() calls close(). Consequently, // close() calls deregister() again - no need to fire channelUnregistered, so check // if it was registered. if (registered) { registered = false; pipeline.fireChannelUnregistered(); + + if (!isOpen()) { + // Remove all handlers from the ChannelPipeline. This is needed to ensure + // handlerRemoved(...) is called and so resources are released. + while (!pipeline.isEmpty()) { + try { + pipeline.removeLast(); + } catch (NoSuchElementException ignore) { + // try again as there may be a race when someone outside the EventLoop removes + // handlers concurrently as well. + } + } + } } safeSetSuccess(promise); } diff --git a/transport/src/main/java/io/netty/channel/ChannelPipeline.java b/transport/src/main/java/io/netty/channel/ChannelPipeline.java index cb3280fa69..7680f65c7b 100644 --- a/transport/src/main/java/io/netty/channel/ChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/ChannelPipeline.java @@ -437,6 +437,14 @@ public interface ChannelPipeline */ ChannelHandlerContext lastContext(); + /** + * Returns {@code true} if this {@link ChannelPipeline} is empty, which means no {@link ChannelHandler} is + * present. + */ + default boolean isEmpty() { + return lastContext() == null; + } + /** * Returns the {@link ChannelHandler} with the specified name in this * pipeline. diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java index aa79904c8a..cc6e53e634 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java @@ -50,10 +50,15 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou */ private static final int ADD_COMPLETE = 1; + /** + * {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} is about to be called. + */ + private static final int REMOVE_STARTED = 2; + /** * {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} was called. */ - private static final int REMOVE_COMPLETE = 2; + private static final int REMOVE_COMPLETE = 3; private final int executionMask; private final DefaultChannelPipeline pipeline; @@ -63,10 +68,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou // Lazily instantiated tasks used to trigger events to a handler with different executor. // There is no need to make this volatile as at worse it will just create a few more instances then needed. private Tasks invokeTasks; + private int handlerState = INIT; DefaultChannelHandlerContext next; DefaultChannelHandlerContext prev; - private int handlerState = INIT; DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name, ChannelHandler handler) { @@ -76,6 +81,22 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou this.handler = handler; } + private static void failRemoved(DefaultChannelHandlerContext ctx, ChannelPromise promise) { + promise.setFailure(newRemovedException(ctx, null)); + } + + private void notifyHandlerRemovedAlready() { + notifyHandlerRemovedAlready(null); + } + + private void notifyHandlerRemovedAlready(Throwable cause) { + pipeline().fireExceptionCaught(newRemovedException(this, cause)); + } + + private static ChannelPipelineException newRemovedException(ChannelHandlerContext ctx, Throwable cause) { + return new ChannelPipelineException("Context " + ctx + " already removed", cause); + } + private Tasks invokeTasks() { Tasks tasks = invokeTasks; if (tasks == null) { @@ -126,7 +147,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelRegistered() { - findContextInbound(MASK_CHANNEL_REGISTERED).invokeChannelRegistered(); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_REGISTERED); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelRegistered(); } void invokeChannelRegistered() { @@ -149,7 +175,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelUnregistered() { - findContextInbound(MASK_CHANNEL_UNREGISTERED).invokeChannelUnregistered(); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_UNREGISTERED); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelUnregistered(); } void invokeChannelUnregistered() { @@ -172,7 +203,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelActive() { - findContextInbound(MASK_CHANNEL_ACTIVE).invokeChannelActive(); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_ACTIVE); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelActive(); } void invokeChannelActive() { @@ -195,7 +231,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelInactive() { - findContextInbound(MASK_CHANNEL_INACTIVE).invokeChannelInactive(); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_INACTIVE); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelInactive(); } void invokeChannelInactive() { @@ -226,7 +267,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeExceptionCaught(Throwable cause) { - findContextInbound(MASK_EXCEPTION_CAUGHT).invokeExceptionCaught(cause); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT); + if (ctx == null) { + notifyHandlerRemovedAlready(cause); + return; + } + ctx.invokeExceptionCaught(cause); } void invokeExceptionCaught(final Throwable cause) { @@ -261,7 +307,13 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeUserEventTriggered(Object event) { - findContextInbound(MASK_USER_EVENT_TRIGGERED).invokeUserEventTriggered(event); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_USER_EVENT_TRIGGERED); + if (ctx == null) { + ReferenceCountUtil.release(event); + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeUserEventTriggered(event); } void invokeUserEventTriggered(Object event) { @@ -290,7 +342,13 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelRead(Object msg) { - findContextInbound(MASK_CHANNEL_READ).invokeChannelRead(msg); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_READ); + if (ctx == null) { + ReferenceCountUtil.release(msg); + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelRead(msg); } void invokeChannelRead(Object msg) { @@ -315,7 +373,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelReadComplete() { - findContextInbound(MASK_CHANNEL_READ_COMPLETE).invokeChannelReadComplete(); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_READ_COMPLETE); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelReadComplete(); } void invokeChannelReadComplete() { @@ -339,7 +402,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeChannelWritabilityChanged() { - findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED).invokeChannelWritabilityChanged(); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeChannelWritabilityChanged(); } void invokeChannelWritabilityChanged() { @@ -403,7 +471,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) { - findContextOutbound(MASK_BIND).invokeBind(localAddress, promise); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_BIND); + if (ctx == null) { + failRemoved(this, promise); + return; + } + ctx.invokeBind(localAddress, promise); } private void invokeBind(SocketAddress localAddress, ChannelPromise promise) { @@ -438,7 +511,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { - findContextOutbound(MASK_CONNECT).invokeConnect(remoteAddress, localAddress, promise); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_CONNECT); + if (ctx == null) { + failRemoved(this, promise); + return; + } + ctx.invokeConnect(remoteAddress, localAddress, promise); } private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { @@ -472,7 +550,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeDisconnect(ChannelPromise promise) { - findContextOutbound(MASK_DISCONNECT).invokeDisconnect(promise); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_DISCONNECT); + if (ctx == null) { + failRemoved(this, promise); + return; + } + ctx.invokeDisconnect(promise); } private void invokeDisconnect(ChannelPromise promise) { @@ -500,7 +583,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeClose(ChannelPromise promise) { - findContextOutbound(MASK_CLOSE).invokeClose(promise); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_CLOSE); + if (ctx == null) { + failRemoved(this, promise); + return; + } + ctx.invokeClose(promise); } private void invokeClose(ChannelPromise promise) { @@ -528,7 +616,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeRegister(ChannelPromise promise) { - findContextOutbound(MASK_REGISTER).invokeRegister(promise); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_REGISTER); + if (ctx == null) { + failRemoved(this, promise); + return; + } + ctx.invokeRegister(promise); } private void invokeRegister(ChannelPromise promise) { @@ -556,7 +649,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeDeregister(ChannelPromise promise) { - findContextOutbound(MASK_DEREGISTER).invokeDeregister(promise); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_DEREGISTER); + if (ctx == null) { + failRemoved(this, promise); + return; + } + ctx.invokeDeregister(promise); } private void invokeDeregister(ChannelPromise promise) { @@ -580,7 +678,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeRead() { - findContextOutbound(MASK_READ).invokeRead(); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_READ); + if (ctx != null) { + ctx.invokeRead(); + } } private void invokeRead() { @@ -595,7 +696,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) { notifyHandlerException(t); } else { - findContextInbound(MASK_EXCEPTION_CAUGHT).notifyHandlerException(t); + DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT); + if (ctx == null) { + notifyHandlerRemovedAlready(); + return; + } + ctx.invokeExceptionCaught(t); } } @@ -634,7 +740,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou } private void findAndInvokeFlush() { - findContextOutbound(MASK_FLUSH).invokeFlush(); + DefaultChannelHandlerContext ctx = findContextOutbound(MASK_FLUSH); + if (ctx != null) { + ctx.invokeFlush(); + } } private void invokeFlush() { @@ -673,6 +782,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou if (executor.inEventLoop()) { final DefaultChannelHandlerContext next = findContextOutbound(flush ? (MASK_WRITE | MASK_FLUSH) : MASK_WRITE); + if (next == null) { + ReferenceCountUtil.release(msg); + failRemoved(this, promise); + return; + } if (flush) { next.invokeWriteAndFlush(msg, promise); } else { @@ -796,17 +910,23 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou private DefaultChannelHandlerContext findContextInbound(int mask) { DefaultChannelHandlerContext ctx = this; + if (ctx.next == null) { + return null; + } do { ctx = ctx.next; - } while ((ctx.executionMask & mask) == 0); + } while ((ctx.executionMask & mask) == 0 || ctx.handlerState == REMOVE_STARTED); return ctx; } private DefaultChannelHandlerContext findContextOutbound(int mask) { DefaultChannelHandlerContext ctx = this; + if (ctx.prev == null) { + return null; + } do { ctx = ctx.prev; - } while ((ctx.executionMask & mask) == 0); + } while ((ctx.executionMask & mask) == 0 || ctx.handlerState == REMOVE_STARTED); return ctx; } @@ -815,15 +935,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou return channel().voidPromise(); } - private void setRemoved() { - handlerState = REMOVE_COMPLETE; - } - boolean setAddComplete() { // Ensure we never update when the handlerState is REMOVE_COMPLETE already. // oldState is usually ADD_PENDING but can also be REMOVE_COMPLETE when an EventExecutor is used that is not // exposing ordering guarantees. - if (handlerState != REMOVE_COMPLETE) { + if (handlerState == INIT) { handlerState = ADD_COMPLETE; return true; } @@ -842,11 +958,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou try { // Only call handlerRemoved(...) if we called handlerAdded(...) before. if (handlerState == ADD_COMPLETE) { + handlerState = REMOVE_STARTED; handler().handlerRemoved(this); } } finally { // Mark the handler as removed in any case. - setRemoved(); + handlerState = REMOVE_COMPLETE; } } @@ -855,6 +972,24 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou return handlerState == REMOVE_COMPLETE; } + void remove(boolean relink) { + assert handlerState == REMOVE_COMPLETE; + if (relink) { + DefaultChannelHandlerContext prev = this.prev; + DefaultChannelHandlerContext next = this.next; + // prev and next may be null if the handler was never really added to the pipeline + if (prev != null) { + prev.next = next; + } + if (next != null) { + next.prev = prev; + } + } + + this.prev = null; + this.next = null; + } + @Override public Attribute attr(AttributeKey key) { return channel().attr(key); @@ -931,6 +1066,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou try { decrementPendingOutboundBytes(); DefaultChannelHandlerContext next = findContext(ctx); + if (next == null) { + ReferenceCountUtil.release(msg); + failRemoved(ctx, promise); + return; + } write(next, msg, promise); } finally { recycle(); diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index 51983c199d..a2df83b208 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -48,6 +48,7 @@ public class DefaultChannelPipeline implements ChannelPipeline { private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultChannelPipeline.class); private static final String HEAD_NAME = generateName0(HeadHandler.class); private static final String TAIL_NAME = generateName0(TailHandler.class); + private static final ChannelHandler HEAD_HANDLER = new HeadHandler(); private static final ChannelHandler TAIL_HANDLER = new TailHandler(); @@ -64,6 +65,7 @@ public class DefaultChannelPipeline implements ChannelPipeline { DefaultChannelPipeline.class, MessageSizeEstimator.Handle.class, "estimatorHandle"); private final DefaultChannelHandlerContext head; private final DefaultChannelHandlerContext tail; + private final Channel channel; private final ChannelFuture succeededFuture; private final VoidChannelPromise voidPromise; @@ -102,9 +104,26 @@ public class DefaultChannelPipeline implements ChannelPipeline { } private DefaultChannelHandlerContext newContext(String name, ChannelHandler handler) { + checkMultiplicity(handler); + if (name == null) { + name = generateName(handler); + } return new DefaultChannelHandlerContext(this, name, handler); } + private void checkDuplicateName(String name) { + if (context(name) != null) { + throw new IllegalArgumentException("Duplicate handler name: " + name); + } + } + + private int checkNoSuchElement(int idx, String name) { + if (idx == -1) { + throw new NoSuchElementException(name); + } + return idx; + } + @Override public final Channel channel() { return channel; @@ -117,18 +136,12 @@ public class DefaultChannelPipeline implements ChannelPipeline { @Override public final ChannelPipeline addFirst(String name, ChannelHandler handler) { - checkMultiplicity(handler); - if (name == null) { - name = generateName(handler); - } - DefaultChannelHandlerContext newCtx = newContext(name, handler); EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - if (context(name) != null) { - throw new IllegalArgumentException("Duplicate handler name: " + name); - } + checkDuplicateName(newCtx.name()); + handlers.add(0, newCtx); if (!inEventLoop) { try { @@ -156,18 +169,12 @@ public class DefaultChannelPipeline implements ChannelPipeline { @Override public final ChannelPipeline addLast(String name, ChannelHandler handler) { - checkMultiplicity(handler); - if (name == null) { - name = generateName(handler); - } - DefaultChannelHandlerContext newCtx = newContext(name, handler); EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - if (context(name) != null) { - throw new IllegalArgumentException("Duplicate handler name: " + name); - } + checkDuplicateName(newCtx.name()); + handlers.add(newCtx); if (!inEventLoop) { try { @@ -197,24 +204,13 @@ public class DefaultChannelPipeline implements ChannelPipeline { public final ChannelPipeline addBefore(String baseName, String name, ChannelHandler handler) { final DefaultChannelHandlerContext ctx; - checkMultiplicity(handler); - if (name == null) { - name = generateName(handler); - } - DefaultChannelHandlerContext newCtx = newContext(name, handler); EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - int i = findCtxIdx(context -> context.name().equals(baseName)); + int i = checkNoSuchElement(findCtxIdx(context -> context.name().equals(baseName)), baseName); + checkDuplicateName(newCtx.name()); - if (i == -1) { - throw new NoSuchElementException(baseName); - } - - if (context(name) != null) { - throw new IllegalArgumentException("Duplicate handler name: " + name); - } ctx = handlers.get(i); handlers.add(i, newCtx); if (!inEventLoop) { @@ -244,7 +240,6 @@ public class DefaultChannelPipeline implements ChannelPipeline { public final ChannelPipeline addAfter(String baseName, String name, ChannelHandler handler) { final DefaultChannelHandlerContext ctx; - checkMultiplicity(handler); if (name == null) { name = generateName(handler); } @@ -253,15 +248,10 @@ public class DefaultChannelPipeline implements ChannelPipeline { EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - int i = findCtxIdx(context -> context.name().equals(baseName)); + int i = checkNoSuchElement(findCtxIdx(context -> context.name().equals(baseName)), baseName); - if (i == -1) { - throw new NoSuchElementException(baseName); - } + checkDuplicateName(newCtx.name()); - if (context(name) != null) { - throw new IllegalArgumentException("Duplicate handler name: " + name); - } ctx = handlers.get(i); handlers.add(i + 1, newCtx); if (!inEventLoop) { @@ -377,21 +367,13 @@ public class DefaultChannelPipeline implements ChannelPipeline { EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - int idx = findCtxIdx(context -> context.handler() == handler); - if (idx == -1) { - throw new NoSuchElementException(); - } + int idx = checkNoSuchElement(findCtxIdx(context -> context.handler() == handler), null); ctx = handlers.remove(idx); assert ctx != null; if (!inEventLoop) { - try { - executor.execute(() -> remove0(ctx)); - return this; - } catch (Throwable cause) { - handlers.add(idx, ctx); - throw cause; - } + scheduleRemove(idx, ctx); + return this; } } @@ -405,21 +387,12 @@ public class DefaultChannelPipeline implements ChannelPipeline { EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - int idx = findCtxIdx(context -> context.name().equals(name)); - if (idx == -1) { - throw new NoSuchElementException(); - } + int idx = checkNoSuchElement(findCtxIdx(context -> context.name().equals(name)), name); ctx = handlers.remove(idx); assert ctx != null; if (!inEventLoop) { - try { - executor.execute(() -> remove0(ctx)); - return ctx.handler(); - } catch (Throwable cause) { - handlers.add(idx, ctx); - throw cause; - } + return scheduleRemove(idx, ctx); } } @@ -434,21 +407,13 @@ public class DefaultChannelPipeline implements ChannelPipeline { EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - int idx = findCtxIdx(context -> handlerType.isAssignableFrom(context.handler().getClass())); - if (idx == -1) { - throw new NoSuchElementException(); - } + int idx = checkNoSuchElement(findCtxIdx( + context -> handlerType.isAssignableFrom(context.handler().getClass())), null); ctx = handlers.remove(idx); assert ctx != null; if (!inEventLoop) { - try { - executor.execute(() -> remove0(ctx)); - return (T) ctx.handler(); - } catch (Throwable cause) { - handlers.add(idx, ctx); - throw cause; - } + return scheduleRemove(idx, ctx); } } @@ -483,30 +448,19 @@ public class DefaultChannelPipeline implements ChannelPipeline { assert ctx != null; if (!inEventLoop) { - try { - executor.execute(() -> remove0(ctx)); - return (T) ctx.handler(); - } catch (Throwable cause) { - handlers.add(idx, ctx); - throw cause; - } + return scheduleRemove(idx, ctx); } } remove0(ctx); return (T) ctx.handler(); } - private void unlink(DefaultChannelHandlerContext ctx) { - assert ctx != head && ctx != tail; - DefaultChannelHandlerContext prev = ctx.prev; - DefaultChannelHandlerContext next = ctx.next; - prev.next = next; - next.prev = prev; - } - private void remove0(DefaultChannelHandlerContext ctx) { - unlink(ctx); - callHandlerRemoved0(ctx); + try { + callHandlerRemoved0(ctx); + } finally { + ctx.remove(true); + } } @Override @@ -529,27 +483,17 @@ public class DefaultChannelPipeline implements ChannelPipeline { private ChannelHandler replace( Predicate predicate, String newName, ChannelHandler newHandler) { - checkMultiplicity(newHandler); - - if (newName == null) { - newName = generateName(newHandler); - } DefaultChannelHandlerContext oldCtx; DefaultChannelHandlerContext newCtx = newContext(newName, newHandler); EventExecutor executor = executor(); boolean inEventLoop = executor.inEventLoop(); synchronized (handlers) { - int idx = findCtxIdx(predicate); - if (idx == -1) { - throw new NoSuchElementException(); - } + int idx = checkNoSuchElement(findCtxIdx(predicate), null); oldCtx = handlers.get(idx); assert oldCtx != head && oldCtx != tail && oldCtx != null; - if (!oldCtx.name().equals(newName)) { - if (context(newName) != null) { - throw new IllegalArgumentException("Duplicate handler name: " + newName); - } + if (!oldCtx.name().equals(newCtx.name())) { + checkDuplicateName(newCtx.name()); } DefaultChannelHandlerContext removed = handlers.set(idx, newCtx); assert removed != null; @@ -586,11 +530,15 @@ public class DefaultChannelPipeline implements ChannelPipeline { oldCtx.prev = newCtx; oldCtx.next = newCtx; - // Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked) - // because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and those - // event handlers must be called after handlerAdded(). - callHandlerAdded0(newCtx); - callHandlerRemoved0(oldCtx); + try { + // Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked) + // because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and those + // event handlers must be called after handlerAdded(). + callHandlerAdded0(newCtx); + callHandlerRemoved0(oldCtx); + } finally { + oldCtx.remove(false); + } } private static void checkMultiplicity(ChannelHandler handler) { @@ -615,7 +563,6 @@ public class DefaultChannelPipeline implements ChannelPipeline { handlers.remove(ctx); } - unlink(ctx); ctx.callHandlerRemoved(); removed = true; @@ -623,6 +570,8 @@ public class DefaultChannelPipeline implements ChannelPipeline { if (logger.isWarnEnabled()) { logger.warn("Failed to remove a handler: " + ctx.name(), t2); } + } finally { + ctx.remove(true); } if (removed) { @@ -749,13 +698,7 @@ public class DefaultChannelPipeline implements ChannelPipeline { assert ctx != null; if (!inEventLoop) { - try { - executor.execute(() -> remove0(ctx)); - return ctx.handler(); - } catch (Throwable cause) { - handlers.add(idx, ctx); - throw cause; - } + return scheduleRemove(idx, ctx); } } remove0(ctx); @@ -777,19 +720,24 @@ public class DefaultChannelPipeline implements ChannelPipeline { assert ctx != null; if (!inEventLoop) { - try { - executor.execute(() -> remove0(ctx)); - return ctx.handler(); - } catch (Throwable cause) { - handlers.add(idx, ctx); - throw cause; - } + return scheduleRemove(idx, ctx); } } remove0(ctx); return ctx.handler(); } + @SuppressWarnings("unchecked") + private T scheduleRemove(int idx, DefaultChannelHandlerContext ctx) { + try { + ctx.executor().execute(() -> remove0(ctx)); + return (T) ctx.handler(); + } catch (Throwable cause) { + handlers.add(idx, ctx); + throw cause; + } + } + @Override public ChannelHandler first() { ChannelHandlerContext ctx = firstContext(); @@ -849,32 +797,6 @@ public class DefaultChannelPipeline implements ChannelPipeline { return this; } - /** - * Removes all handlers from the pipeline one by one from tail (exclusive) to head (exclusive) to trigger - * handlerRemoved(). - */ - private void destroy() { - EventExecutor executor = executor(); - if (executor.inEventLoop()) { - destroy0(); - } else { - executor.execute(this::destroy0); - } - } - - private void destroy0() { - assert executor().inEventLoop(); - DefaultChannelHandlerContext ctx = this.tail.prev; - while (ctx != head) { - synchronized (handlers) { - handlers.remove(ctx); - } - remove0(ctx); - - ctx = ctx.prev; - } - } - @Override public final ChannelPipeline fireChannelActive() { head.invokeChannelActive(); @@ -980,7 +902,7 @@ public class DefaultChannelPipeline implements ChannelPipeline { } @Override - public final ChannelFuture close(ChannelPromise promise) { + public ChannelFuture close(ChannelPromise promise) { return tail.close(promise); } @@ -1135,10 +1057,14 @@ public class DefaultChannelPipeline implements ChannelPipeline { private static final class TailHandler implements ChannelHandler { @Override - public void channelRegistered(ChannelHandlerContext ctx) { } + public void channelRegistered(ChannelHandlerContext ctx) { + // Just swallow event + } @Override - public void channelUnregistered(ChannelHandlerContext ctx) { } + public void channelUnregistered(ChannelHandlerContext ctx) { + // Just swallow event + } @Override public void channelActive(ChannelHandlerContext ctx) { @@ -1155,12 +1081,6 @@ public class DefaultChannelPipeline implements ChannelPipeline { ((DefaultChannelPipeline) ctx.pipeline()).onUnhandledChannelWritabilityChanged(); } - @Override - public void handlerAdded(ChannelHandlerContext ctx) { } - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) { } - @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { ((DefaultChannelPipeline) ctx.pipeline()).onUnhandledInboundUserEventTriggered(evt); @@ -1232,15 +1152,5 @@ public class DefaultChannelPipeline implements ChannelPipeline { public void flush(ChannelHandlerContext ctx) { ctx.channel().unsafe().flush(); } - - @Override - public void channelUnregistered(ChannelHandlerContext ctx) { - ctx.fireChannelUnregistered(); - - // Remove all handlers sequentially if channel is closed and unregistered. - if (!ctx.channel().isOpen()) { - ((DefaultChannelPipeline) ctx.pipeline()).destroy(); - } - } } } diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index ffa4bec67a..cb63dd1965 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -782,64 +782,70 @@ public class EmbeddedChannel extends AbstractChannel { return EmbeddedUnsafe.this.remoteAddress(); } + private void mayRunPendingTasks() { + if (!((EmbeddedEventLoop) eventLoop()).running) { + runPendingTasks(); + } + } + @Override public void register(ChannelPromise promise) { EmbeddedUnsafe.this.register(promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void bind(SocketAddress localAddress, ChannelPromise promise) { EmbeddedUnsafe.this.bind(localAddress, promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void disconnect(ChannelPromise promise) { EmbeddedUnsafe.this.disconnect(promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void close(ChannelPromise promise) { EmbeddedUnsafe.this.close(promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void closeForcibly() { EmbeddedUnsafe.this.closeForcibly(); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void deregister(ChannelPromise promise) { EmbeddedUnsafe.this.deregister(promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void beginRead() { EmbeddedUnsafe.this.beginRead(); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void write(Object msg, ChannelPromise promise) { EmbeddedUnsafe.this.write(msg, promise); - runPendingTasks(); + mayRunPendingTasks(); } @Override public void flush() { EmbeddedUnsafe.this.flush(); - runPendingTasks(); + mayRunPendingTasks(); } @Override diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java index 609b266c4f..36fc871472 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java @@ -30,7 +30,7 @@ import java.util.concurrent.TimeUnit; final class EmbeddedEventLoop extends AbstractScheduledEventExecutor implements EventLoop { private final Queue tasks = new ArrayDeque<>(2); - private boolean running; + boolean running; private static EmbeddedChannel cast(Channel channel) { if (channel instanceof EmbeddedChannel) { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 40e8a70ef5..8b832f650a 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -38,6 +38,7 @@ import io.netty.util.concurrent.Future; import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.AfterClass; import org.junit.Test; @@ -50,6 +51,7 @@ import java.util.List; import java.util.NoSuchElementException; import java.util.Queue; import java.util.concurrent.Callable; +import java.util.concurrent.CompletionException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -62,6 +64,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -761,7 +764,7 @@ public class DefaultChannelPipelineTest { pipeline.channel().close().syncUninterruptibly(); } - @Test(timeout = 2000) + @Test(timeout = 5000) public void testAddRemoveHandlerCalled() throws Throwable { ChannelPipeline pipeline = newLocalChannel().pipeline(); @@ -815,6 +818,110 @@ public class DefaultChannelPipelineTest { } } + @Test(timeout = 5000) + public void testOperationsFailWhenRemoved() { + ChannelPipeline pipeline = newLocalChannel().pipeline(); + try { + pipeline.channel().register().syncUninterruptibly(); + + ChannelHandler handler = new ChannelHandler() { }; + pipeline.addFirst(handler); + ChannelHandlerContext ctx = pipeline.context(handler); + pipeline.remove(handler); + + testOperationsFailsOnContext(ctx); + } finally { + pipeline.channel().close().syncUninterruptibly(); + } + } + + @Test(timeout = 5000) + public void testOperationsFailWhenReplaced() { + ChannelPipeline pipeline = newLocalChannel().pipeline(); + try { + pipeline.channel().register().syncUninterruptibly(); + + ChannelHandler handler = new ChannelHandler() { }; + pipeline.addFirst(handler); + ChannelHandlerContext ctx = pipeline.context(handler); + pipeline.replace(handler, null, new ChannelHandler() { }); + + testOperationsFailsOnContext(ctx); + } finally { + pipeline.channel().close().syncUninterruptibly(); + } + } + + private static void testOperationsFailsOnContext(ChannelHandlerContext ctx) { + assertChannelPipelineException(ctx.writeAndFlush("")); + assertChannelPipelineException(ctx.write("")); + assertChannelPipelineException(ctx.bind(new SocketAddress() { })); + assertChannelPipelineException(ctx.close()); + assertChannelPipelineException(ctx.connect(new SocketAddress() { })); + assertChannelPipelineException(ctx.connect(new SocketAddress() { }, new SocketAddress() { })); + assertChannelPipelineException(ctx.deregister()); + assertChannelPipelineException(ctx.disconnect()); + + class ChannelPipelineExceptionValidator implements ChannelHandler { + + private Promise validationPromise = ImmediateEventExecutor.INSTANCE.newPromise(); + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + try { + assertThat(cause, Matchers.instanceOf(ChannelPipelineException.class)); + } catch (Throwable error) { + validationPromise.setFailure(error); + return; + } + validationPromise.setSuccess(null); + } + + void validate() { + validationPromise.syncUninterruptibly(); + validationPromise = ImmediateEventExecutor.INSTANCE.newPromise(); + } + } + + ChannelPipelineExceptionValidator validator = new ChannelPipelineExceptionValidator(); + ctx.pipeline().addLast(validator); + + ctx.fireChannelRead(""); + validator.validate(); + + ctx.fireUserEventTriggered(""); + validator.validate(); + + ctx.fireChannelReadComplete(); + validator.validate(); + + ctx.fireExceptionCaught(new Exception()); + validator.validate(); + + ctx.fireChannelActive(); + validator.validate(); + + ctx.fireChannelRegistered(); + validator.validate(); + + ctx.fireChannelInactive(); + validator.validate(); + + ctx.fireChannelUnregistered(); + validator.validate(); + + ctx.fireChannelWritabilityChanged(); + validator.validate(); + } + + private static void assertChannelPipelineException(ChannelFuture f) { + try { + f.syncUninterruptibly(); + } catch (CompletionException e) { + assertThat(e.getCause(), Matchers.instanceOf(ChannelPipelineException.class)); + } + } + @Test(timeout = 3000) public void testAddReplaceHandlerCalled() throws Throwable { ChannelPipeline pipeline = newLocalChannel().pipeline(); diff --git a/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java b/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java index 152c2dbeb0..54bd1e460e 100644 --- a/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java +++ b/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java @@ -232,7 +232,9 @@ public class PendingWriteQueueTest { ChannelPromise promise = channel.newPromise(); final ChannelPromise promise3 = channel.newPromise(); - promise.addListener((ChannelFutureListener) future -> queue.add(3L, promise3)); + promise.addListener((ChannelFutureListener) future -> { + queue.add(3L, promise3); + }); ChannelPromise promise2 = channel.newPromise(); channel.eventLoop().execute(() -> { @@ -245,8 +247,13 @@ public class PendingWriteQueueTest { assertTrue(promise.isSuccess()); assertTrue(promise2.isDone()); assertTrue(promise2.isSuccess()); + assertFalse(promise3.isDone()); + assertFalse(promise3.isSuccess()); + + channel.eventLoop().execute(queue::removeAndWriteAll); assertTrue(promise3.isDone()); assertTrue(promise3.isSuccess()); + channel.runPendingTasks(); assertTrue(channel.finish()); assertEquals(1L, (long) channel.readOutbound()); assertEquals(2L, (long) channel.readOutbound());