Ensure the DefaultChannelHandlerContext is unlinked once removed (#9970)

Motivation:

At the moment the next / prev references are not set to "null" in the DefaultChannelHandlerContext once the ChannelHandler is removed. This is bad as it basically let users still use the ChannelHandlerContext of a ChannelHandler after it is removed and may produce very suprising behaviour.

Modifications:

- Fail if someone tries to use the ChannelHandlerContext once the ChannelHandler was removed (for outbound operations fail the promise, for inbound fire the error through the ChannelPipeline)
- Fix some handlers to ensure we not use the ChannelHandlerContext after the handler was removed
- Adjust DefaultChannelPipeline / DefaultChannelHandlerContext to fixes races with removal / replacement of handlers

Result:

Cleanup behaviour and make it more predictable for pipeline modifications
This commit is contained in:
Norman Maurer 2020-03-01 08:13:33 +01:00 committed by GitHub
parent 302cb737f3
commit ae0fbb45e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 512 additions and 298 deletions

View File

@ -253,6 +253,14 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
} }
} }
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.fireExceptionCaught(cause);
if (cause instanceof HAProxyProtocolException) {
ctx.close(); // drop connection immediately per spec
}
}
@Override @Override
protected final void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { protected final void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
// determine the specification version // determine the specification version
@ -325,7 +333,6 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) { private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) {
finished = true; finished = true;
ctx.close(); // drop connection immediately per spec
HAProxyProtocolException ppex; HAProxyProtocolException ppex;
if (errMsg != null && e != null) { if (errMsg != null && e != null) {
ppex = new HAProxyProtocolException(errMsg, e); ppex = new HAProxyProtocolException(errMsg, e);

View File

@ -203,8 +203,8 @@ public class HttpClientUpgradeHandler extends HttpObjectAggregator {
// NOTE: not releasing the response since we're letting it propagate to the // NOTE: not releasing the response since we're letting it propagate to the
// next handler. // next handler.
ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED); ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED);
removeThisHandler(ctx);
ctx.fireChannelRead(msg); ctx.fireChannelRead(msg);
removeThisHandler(ctx);
return; return;
} }
} }

View File

@ -266,13 +266,26 @@ public class HttpObjectAggregator
}); });
} }
} else if (oversized instanceof HttpResponse) { } else if (oversized instanceof HttpResponse) {
ctx.close(); throw new ResponseTooLargeException("Response entity too large: " + oversized);
throw new TooLongFrameException("Response entity too large: " + oversized);
} else { } else {
throw new IllegalStateException(); throw new IllegalStateException();
} }
} }
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
super.exceptionCaught(ctx, cause);
if (cause instanceof ResponseTooLargeException) {
ctx.close();
}
}
private static final class ResponseTooLargeException extends TooLongFrameException {
ResponseTooLargeException(String message) {
super(message);
}
}
private abstract static class AggregatedFullHttpMessage implements FullHttpMessage { private abstract static class AggregatedFullHttpMessage implements FullHttpMessage {
protected final HttpMessage message; protected final HttpMessage message;
private final ByteBuf content; private final ByteBuf content;

View File

@ -328,13 +328,13 @@ public class HttpServerUpgradeHandler extends HttpObjectAggregator {
sourceCodec.upgradeFrom(ctx); sourceCodec.upgradeFrom(ctx);
upgradeCodec.upgradeTo(ctx, request); upgradeCodec.upgradeTo(ctx, request);
// Remove this handler from the pipeline.
ctx.pipeline().remove(HttpServerUpgradeHandler.this);
// Notify that the upgrade has occurred. Retain the event to offset // Notify that the upgrade has occurred. Retain the event to offset
// the release() in the finally block. // the release() in the finally block.
ctx.fireUserEventTriggered(event.retain()); ctx.fireUserEventTriggered(event.retain());
// Remove this handler from the pipeline.
ctx.pipeline().remove(HttpServerUpgradeHandler.this);
// Add the listener last to avoid firing upgrade logic after // Add the listener last to avoid firing upgrade logic after
// the channel is already closed since the listener may fire // the channel is already closed since the listener may fire
// immediately if the write failed eagerly. // immediately if the write failed eagerly.

View File

@ -300,9 +300,9 @@ public abstract class WebSocketServerHandshaker {
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Remove ourself and fail the handshake promise. // Remove ourself and fail the handshake promise.
ctx.pipeline().remove(this);
promise.tryFailure(cause); promise.tryFailure(cause);
ctx.fireExceptionCaught(cause); ctx.fireExceptionCaught(cause);
ctx.pipeline().remove(this);
} }
@Override @Override

View File

@ -45,7 +45,6 @@ import static io.netty.handler.codec.http.HttpVersion.*;
class WebSocketServerProtocolHandshakeHandler implements ChannelHandler { class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
private final WebSocketServerProtocolConfig serverConfig; private final WebSocketServerProtocolConfig serverConfig;
private ChannelHandlerContext ctx;
private ChannelPromise handshakePromise; private ChannelPromise handshakePromise;
WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) { WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
@ -54,7 +53,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) { public void handlerAdded(ChannelHandlerContext ctx) {
this.ctx = ctx;
handshakePromise = ctx.newPromise(); handshakePromise = ctx.newPromise();
} }
@ -86,7 +84,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
// //
// See https://github.com/netty/netty/issues/9471. // See https://github.com/netty/netty/issues/9471.
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
ctx.pipeline().remove(this);
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
handshakeFuture.addListener((ChannelFutureListener) future -> { handshakeFuture.addListener((ChannelFutureListener) future -> {
@ -102,8 +99,9 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
new WebSocketServerProtocolHandler.HandshakeComplete( new WebSocketServerProtocolHandler.HandshakeComplete(
req.uri(), req.headers(), handshaker.selectedSubprotocol())); req.uri(), req.headers(), handshaker.selectedSubprotocol()));
} }
ctx.pipeline().remove(this);
}); });
applyHandshakeTimeout(); applyHandshakeTimeout(ctx);
} }
} finally { } finally {
req.release(); req.release();
@ -132,7 +130,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
return protocol + "://" + host + path; return protocol + "://" + host + path;
} }
private void applyHandshakeTimeout() { private void applyHandshakeTimeout(ChannelHandlerContext ctx) {
final ChannelPromise localHandshakePromise = handshakePromise; final ChannelPromise localHandshakePromise = handshakePromise;
final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis(); final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {

View File

@ -120,8 +120,9 @@ public class WebSocketClientExtensionHandler implements ChannelHandler {
ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder); ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
} }
} }
ctx.fireChannelRead(msg);
ctx.pipeline().remove(ctx.name()); ctx.pipeline().remove(this);
return;
} }
} }

View File

@ -66,7 +66,7 @@ public class WebSocketServerProtocolHandlerTest {
} }
@Test @Test
public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() { public void testWebSocketServerProtocolHandshakeHandlerRemovedAfterHandshake() {
EmbeddedChannel ch = createChannel(new MockOutboundHandler()); EmbeddedChannel ch = createChannel(new MockOutboundHandler());
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class); ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
ch.pipeline().addLast(new ChannelHandler() { ch.pipeline().addLast(new ChannelHandler() {
@ -74,7 +74,7 @@ public class WebSocketServerProtocolHandlerTest {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
// We should have removed the handler already. // We should have removed the handler already.
assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class)); ctx.executor().execute(() -> ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
} }
} }
}); });

View File

@ -87,9 +87,9 @@ public final class CleartextHttp2ServerUpgradeHandler extends ByteToMessageDecod
.remove(httpServerUpgradeHandler); .remove(httpServerUpgradeHandler);
ctx.pipeline().addAfter(ctx.name(), null, http2ServerHandler); ctx.pipeline().addAfter(ctx.name(), null, http2ServerHandler);
ctx.pipeline().remove(this);
ctx.fireUserEventTriggered(PriorKnowledgeUpgradeEvent.INSTANCE); ctx.fireUserEventTriggered(PriorKnowledgeUpgradeEvent.INSTANCE);
ctx.pipeline().remove(this);
} }
} }

View File

@ -154,10 +154,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
} }
}; };
private static final byte STATE_INIT = 0;
private static final byte STATE_CALLING_CHILD_DECODE = 1;
private static final byte STATE_HANDLER_REMOVED_PENDING = 2;
ByteBuf cumulation; ByteBuf cumulation;
private Cumulator cumulator = MERGE_CUMULATOR; private Cumulator cumulator = MERGE_CUMULATOR;
private boolean singleDecode; private boolean singleDecode;
@ -169,15 +165,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
*/ */
private boolean firedChannelRead; private boolean firedChannelRead;
/**
* A bitmask where the bits are defined as
* <ul>
* <li>{@link #STATE_INIT}</li>
* <li>{@link #STATE_CALLING_CHILD_DECODE}</li>
* <li>{@link #STATE_HANDLER_REMOVED_PENDING}</li>
* </ul>
*/
private byte decodeState = STATE_INIT;
private int discardAfterReads = 16; private int discardAfterReads = 16;
private int numReads; private int numReads;
private ByteToMessageDecoderContext context; private ByteToMessageDecoderContext context;
@ -257,10 +244,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
@Override @Override
public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
if (decodeState == STATE_CALLING_CHILD_DECODE) {
decodeState = STATE_HANDLER_REMOVED_PENDING;
return;
}
ByteBuf buf = cumulation; ByteBuf buf = cumulation;
if (buf != null) { if (buf != null) {
// Directly set this to null so we are sure we not access it in any other method here anymore. // Directly set this to null so we are sure we not access it in any other method here anymore.
@ -469,16 +452,7 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
*/ */
final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in) final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in)
throws Exception { throws Exception {
decodeState = STATE_CALLING_CHILD_DECODE; decode(ctx, in);
try {
decode(ctx, in);
} finally {
boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING;
decodeState = STATE_INIT;
if (removePending) {
handlerRemoved(ctx);
}
}
} }
/** /**

View File

@ -97,8 +97,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)"); System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)");
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
} }
}); });

View File

@ -99,8 +99,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)"); System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)");
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
} }
}); });

View File

@ -98,8 +98,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)"); System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)");
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
} }
}); });

View File

@ -45,8 +45,8 @@ public final class SocksServerHandler extends SimpleChannelInboundHandler<SocksM
Socks4CommandRequest socksV4CmdRequest = (Socks4CommandRequest) socksRequest; Socks4CommandRequest socksV4CmdRequest = (Socks4CommandRequest) socksRequest;
if (socksV4CmdRequest.type() == Socks4CommandType.CONNECT) { if (socksV4CmdRequest.type() == Socks4CommandType.CONNECT) {
ctx.pipeline().addLast(new SocksServerConnectHandler()); ctx.pipeline().addLast(new SocksServerConnectHandler());
ctx.pipeline().remove(this);
ctx.fireChannelRead(socksRequest); ctx.fireChannelRead(socksRequest);
ctx.pipeline().remove(this);
} else { } else {
ctx.close(); ctx.close();
} }
@ -65,8 +65,8 @@ public final class SocksServerHandler extends SimpleChannelInboundHandler<SocksM
Socks5CommandRequest socks5CmdRequest = (Socks5CommandRequest) socksRequest; Socks5CommandRequest socks5CmdRequest = (Socks5CommandRequest) socksRequest;
if (socks5CmdRequest.type() == Socks5CommandType.CONNECT) { if (socks5CmdRequest.type() == Socks5CommandType.CONNECT) {
ctx.pipeline().addLast(new SocksServerConnectHandler()); ctx.pipeline().addLast(new SocksServerConnectHandler());
ctx.pipeline().remove(this);
ctx.fireChannelRead(socksRequest); ctx.fireChannelRead(socksRequest);
ctx.pipeline().remove(this);
} else { } else {
ctx.close(); ctx.close();
} }

View File

@ -39,44 +39,52 @@ public abstract class AbstractRemoteAddressFilter<T extends SocketAddress> imple
@Override @Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception { public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
handleNewChannel(ctx); handleNewChannel(ctx, true);
ctx.fireChannelRegistered();
} }
@Override @Override
public void channelActive(ChannelHandlerContext ctx) throws Exception { public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (!handleNewChannel(ctx)) { if (!handleNewChannel(ctx, false)) {
throw new IllegalStateException("cannot determine to accept or reject a channel: " + ctx.channel()); throw new IllegalStateException("cannot determine to accept or reject a channel: " + ctx.channel());
} else {
ctx.fireChannelActive();
} }
} }
private boolean handleNewChannel(ChannelHandlerContext ctx) throws Exception { private boolean handleNewChannel(ChannelHandlerContext ctx, boolean register) throws Exception {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
T remoteAddress = (T) ctx.channel().remoteAddress(); T remoteAddress = (T) ctx.channel().remoteAddress();
boolean remove = false;
try {
// If the remote address is not available yet, defer the decision.
if (remoteAddress == null) {
return false;
}
// If the remote address is not available yet, defer the decision. if (accept(ctx, remoteAddress)) {
if (remoteAddress == null) { channelAccepted(ctx, remoteAddress);
return false; remove = true;
}
// No need to keep this handler in the pipeline anymore because the decision is going to be made now.
// Also, this will prevent the subsequent events from being handled by this handler.
ctx.pipeline().remove(this);
if (accept(ctx, remoteAddress)) {
channelAccepted(ctx, remoteAddress);
} else {
ChannelFuture rejectedFuture = channelRejected(ctx, remoteAddress);
if (rejectedFuture != null) {
rejectedFuture.addListener(ChannelFutureListener.CLOSE);
} else { } else {
ctx.close(); ChannelFuture rejectedFuture = channelRejected(ctx, remoteAddress);
if (rejectedFuture != null && !rejectedFuture.isDone()) {
rejectedFuture.addListener(ChannelFutureListener.CLOSE);
} else {
ctx.close();
}
}
return true;
} finally {
if (!ctx.isRemoved()) {
if (register) {
ctx.fireChannelRegistered();
} else {
ctx.fireChannelActive();
}
if (remove) {
// No need to keep this handler in the pipeline anymore because the decision is going to be made
// now. Also, this will prevent the subsequent events from being handled by this handler.
ctx.pipeline().remove(this);
}
} }
} }
return true;
} }
/** /**

View File

@ -97,13 +97,16 @@ public abstract class ApplicationProtocolNegotiationHandler implements ChannelHa
} catch (Throwable cause) { } catch (Throwable cause) {
exceptionCaught(ctx, cause); exceptionCaught(ctx, cause);
} finally { } finally {
ctx.fireUserEventTriggered(evt);
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();
if (pipeline.context(this) != null) { if (pipeline.context(this) != null) {
pipeline.remove(this); pipeline.remove(this);
} }
} }
} else {
ctx.fireUserEventTriggered(evt);
} }
ctx.fireUserEventTriggered(evt);
} }
/** /**

View File

@ -63,7 +63,7 @@ public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder {
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes()); in.skipBytes(in.readableBytes());
ctx.fireUserEventTriggered(new SniCompletionEvent(e)); ctx.fireUserEventTriggered(new SniCompletionEvent(e));
SslUtils.handleHandshakeFailure(ctx, e, true); ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(e));
throw e; throw e;
} }
if (len == SslUtils.NOT_ENOUGH_DATA) { if (len == SslUtils.NOT_ENOUGH_DATA) {

View File

@ -1091,7 +1091,8 @@ public class SslHandler extends ByteToMessageDecoder {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug( logger.debug(
"{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " + "{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " +
"while writing close_notify in response to the peer's close_notify", ctx.channel(), cause); "while writing close_notify in response to the peer's close_notify",
ctx.channel(), cause);
} }
// Close the connection explicitly just in case the transport // Close the connection explicitly just in case the transport
@ -1101,6 +1102,11 @@ public class SslHandler extends ByteToMessageDecoder {
} }
} else { } else {
ctx.fireExceptionCaught(cause); ctx.fireExceptionCaught(cause);
if (cause instanceof SSLException ||
((cause instanceof DecoderException) && cause.getCause() instanceof SSLException)) {
ctx.close();
}
} }
} }
@ -2057,9 +2063,13 @@ public class SslHandler extends ByteToMessageDecoder {
} }
final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis; final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis;
if (closeNotifyReadTimeout <= 0) { if (closeNotifyReadTimeout <= 0) {
// Trigger the close in all cases to make sure the promise is notified if (ctx.channel().isActive()) {
// See https://github.com/netty/netty/issues/2358 // Trigger the close in all cases to make sure the promise is notified
addCloseListener(ctx.close(ctx.newPromise()), promise); // See https://github.com/netty/netty/issues/2358
addCloseListener(ctx.close(ctx.newPromise()), promise);
} else {
promise.trySuccess();
}
} else { } else {
final ScheduledFuture<?> closeNotifyReadTimeoutFuture; final ScheduledFuture<?> closeNotifyReadTimeoutFuture;
@ -2083,7 +2093,11 @@ public class SslHandler extends ByteToMessageDecoder {
if (closeNotifyReadTimeoutFuture != null) { if (closeNotifyReadTimeoutFuture != null) {
closeNotifyReadTimeoutFuture.cancel(false); closeNotifyReadTimeoutFuture.cancel(false);
} }
addCloseListener(ctx.close(ctx.newPromise()), promise); if (ctx.channel().isActive()) {
addCloseListener(ctx.close(ctx.newPromise()), promise);
} else {
promise.trySuccess();
}
}); });
} }
}); });

View File

@ -346,7 +346,6 @@ final class SslUtils {
if (notify) { if (notify) {
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
} }
ctx.close();
} }
/** /**

View File

@ -48,15 +48,13 @@ public abstract class OcspClientHandler implements ChannelHandler {
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
ctx.fireUserEventTriggered(evt);
if (evt instanceof SslHandshakeCompletionEvent) { if (evt instanceof SslHandshakeCompletionEvent) {
ctx.pipeline().remove(this);
SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
if (event.isSuccess() && !verify(ctx, engine)) { if (event.isSuccess() && !verify(ctx, engine)) {
throw new SSLHandshakeException("Bad OCSP response"); throw new SSLHandshakeException("Bad OCSP response");
} }
ctx.pipeline().remove(this);
} }
ctx.fireUserEventTriggered(evt);
} }
} }

View File

@ -442,10 +442,12 @@ public class SslHandlerTest {
sslHandler.setHandshakeTimeoutMillis(1000); sslHandler.setHandshakeTimeoutMillis(1000);
ch.pipeline().addFirst(sslHandler); ch.pipeline().addFirst(sslHandler);
sslHandler.handshakeFuture().addListener((FutureListener<Channel>) future -> { sslHandler.handshakeFuture().addListener((FutureListener<Channel>) future -> {
ch.pipeline().remove(sslHandler); ch.eventLoop().execute(() -> {
ch.pipeline().remove(sslHandler);
// Schedule the close so removal has time to propagate exception if any. // Schedule the close so removal has time to propagate exception if any.
ch.eventLoop().execute(ch::close); ch.eventLoop().execute(ch::close);
});
}); });
ch.pipeline().addLast(new ChannelHandler() { ch.pipeline().addLast(new ChannelHandler() {
@ -457,6 +459,7 @@ public class SslHandlerTest {
if (cause instanceof IllegalReferenceCountException) { if (cause instanceof IllegalReferenceCountException) {
promise.setFailure(cause); promise.setFailure(cause);
} }
cause.printStackTrace();
} }
@Override @Override

View File

@ -85,8 +85,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
ChannelPipeline pipeline = ctx.pipeline(); ChannelPipeline pipeline = ctx.pipeline();
ChannelHandlerContext thisCtx = pipeline.context(this); ChannelHandlerContext thisCtx = pipeline.context(this);
pipeline.addAfter(thisCtx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted.")); pipeline.addAfter(thisCtx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength)); pipeline.addAfter(thisCtx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg)); ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
} }
}); });

View File

@ -36,6 +36,7 @@ import java.net.SocketAddress;
import java.net.SocketException; import java.net.SocketException;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.nio.channels.NotYetConnectedException; import java.nio.channels.NotYetConnectedException;
import java.util.NoSuchElementException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionException;
@ -775,12 +776,25 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
pipeline.fireChannelInactive(); pipeline.fireChannelInactive();
} }
// Some transports like local and AIO does not allow the deregistration of // Some transports like local and AIO does not allow the deregistration of
// an open channel. Their doDeregister() calls close(). Consequently, // an open channel. Their doDeregister() calls close(). Consequently,
// close() calls deregister() again - no need to fire channelUnregistered, so check // close() calls deregister() again - no need to fire channelUnregistered, so check
// if it was registered. // if it was registered.
if (registered) { if (registered) {
registered = false; registered = false;
pipeline.fireChannelUnregistered(); pipeline.fireChannelUnregistered();
if (!isOpen()) {
// Remove all handlers from the ChannelPipeline. This is needed to ensure
// handlerRemoved(...) is called and so resources are released.
while (!pipeline.isEmpty()) {
try {
pipeline.removeLast();
} catch (NoSuchElementException ignore) {
// try again as there may be a race when someone outside the EventLoop removes
// handlers concurrently as well.
}
}
}
} }
safeSetSuccess(promise); safeSetSuccess(promise);
} }

View File

@ -437,6 +437,14 @@ public interface ChannelPipeline
*/ */
ChannelHandlerContext lastContext(); ChannelHandlerContext lastContext();
/**
* Returns {@code true} if this {@link ChannelPipeline} is empty, which means no {@link ChannelHandler} is
* present.
*/
default boolean isEmpty() {
return lastContext() == null;
}
/** /**
* Returns the {@link ChannelHandler} with the specified name in this * Returns the {@link ChannelHandler} with the specified name in this
* pipeline. * pipeline.

View File

@ -50,10 +50,15 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
*/ */
private static final int ADD_COMPLETE = 1; private static final int ADD_COMPLETE = 1;
/**
* {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} is about to be called.
*/
private static final int REMOVE_STARTED = 2;
/** /**
* {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} was called. * {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} was called.
*/ */
private static final int REMOVE_COMPLETE = 2; private static final int REMOVE_COMPLETE = 3;
private final int executionMask; private final int executionMask;
private final DefaultChannelPipeline pipeline; private final DefaultChannelPipeline pipeline;
@ -63,10 +68,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
// Lazily instantiated tasks used to trigger events to a handler with different executor. // Lazily instantiated tasks used to trigger events to a handler with different executor.
// There is no need to make this volatile as at worse it will just create a few more instances then needed. // There is no need to make this volatile as at worse it will just create a few more instances then needed.
private Tasks invokeTasks; private Tasks invokeTasks;
private int handlerState = INIT;
DefaultChannelHandlerContext next; DefaultChannelHandlerContext next;
DefaultChannelHandlerContext prev; DefaultChannelHandlerContext prev;
private int handlerState = INIT;
DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name, DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name,
ChannelHandler handler) { ChannelHandler handler) {
@ -76,6 +81,22 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
this.handler = handler; this.handler = handler;
} }
private static void failRemoved(DefaultChannelHandlerContext ctx, ChannelPromise promise) {
promise.setFailure(newRemovedException(ctx, null));
}
private void notifyHandlerRemovedAlready() {
notifyHandlerRemovedAlready(null);
}
private void notifyHandlerRemovedAlready(Throwable cause) {
pipeline().fireExceptionCaught(newRemovedException(this, cause));
}
private static ChannelPipelineException newRemovedException(ChannelHandlerContext ctx, Throwable cause) {
return new ChannelPipelineException("Context " + ctx + " already removed", cause);
}
private Tasks invokeTasks() { private Tasks invokeTasks() {
Tasks tasks = invokeTasks; Tasks tasks = invokeTasks;
if (tasks == null) { if (tasks == null) {
@ -126,7 +147,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelRegistered() { private void findAndInvokeChannelRegistered() {
findContextInbound(MASK_CHANNEL_REGISTERED).invokeChannelRegistered(); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_REGISTERED);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelRegistered();
} }
void invokeChannelRegistered() { void invokeChannelRegistered() {
@ -149,7 +175,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelUnregistered() { private void findAndInvokeChannelUnregistered() {
findContextInbound(MASK_CHANNEL_UNREGISTERED).invokeChannelUnregistered(); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_UNREGISTERED);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelUnregistered();
} }
void invokeChannelUnregistered() { void invokeChannelUnregistered() {
@ -172,7 +203,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelActive() { private void findAndInvokeChannelActive() {
findContextInbound(MASK_CHANNEL_ACTIVE).invokeChannelActive(); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_ACTIVE);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelActive();
} }
void invokeChannelActive() { void invokeChannelActive() {
@ -195,7 +231,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelInactive() { private void findAndInvokeChannelInactive() {
findContextInbound(MASK_CHANNEL_INACTIVE).invokeChannelInactive(); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_INACTIVE);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelInactive();
} }
void invokeChannelInactive() { void invokeChannelInactive() {
@ -226,7 +267,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeExceptionCaught(Throwable cause) { private void findAndInvokeExceptionCaught(Throwable cause) {
findContextInbound(MASK_EXCEPTION_CAUGHT).invokeExceptionCaught(cause); DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (ctx == null) {
notifyHandlerRemovedAlready(cause);
return;
}
ctx.invokeExceptionCaught(cause);
} }
void invokeExceptionCaught(final Throwable cause) { void invokeExceptionCaught(final Throwable cause) {
@ -261,7 +307,13 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeUserEventTriggered(Object event) { private void findAndInvokeUserEventTriggered(Object event) {
findContextInbound(MASK_USER_EVENT_TRIGGERED).invokeUserEventTriggered(event); DefaultChannelHandlerContext ctx = findContextInbound(MASK_USER_EVENT_TRIGGERED);
if (ctx == null) {
ReferenceCountUtil.release(event);
notifyHandlerRemovedAlready();
return;
}
ctx.invokeUserEventTriggered(event);
} }
void invokeUserEventTriggered(Object event) { void invokeUserEventTriggered(Object event) {
@ -290,7 +342,13 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelRead(Object msg) { private void findAndInvokeChannelRead(Object msg) {
findContextInbound(MASK_CHANNEL_READ).invokeChannelRead(msg); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_READ);
if (ctx == null) {
ReferenceCountUtil.release(msg);
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelRead(msg);
} }
void invokeChannelRead(Object msg) { void invokeChannelRead(Object msg) {
@ -315,7 +373,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelReadComplete() { private void findAndInvokeChannelReadComplete() {
findContextInbound(MASK_CHANNEL_READ_COMPLETE).invokeChannelReadComplete(); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_READ_COMPLETE);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelReadComplete();
} }
void invokeChannelReadComplete() { void invokeChannelReadComplete() {
@ -339,7 +402,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeChannelWritabilityChanged() { private void findAndInvokeChannelWritabilityChanged() {
findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED).invokeChannelWritabilityChanged(); DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelWritabilityChanged();
} }
void invokeChannelWritabilityChanged() { void invokeChannelWritabilityChanged() {
@ -403,7 +471,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) { private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) {
findContextOutbound(MASK_BIND).invokeBind(localAddress, promise); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_BIND);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeBind(localAddress, promise);
} }
private void invokeBind(SocketAddress localAddress, ChannelPromise promise) { private void invokeBind(SocketAddress localAddress, ChannelPromise promise) {
@ -438,7 +511,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
findContextOutbound(MASK_CONNECT).invokeConnect(remoteAddress, localAddress, promise); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_CONNECT);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeConnect(remoteAddress, localAddress, promise);
} }
private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
@ -472,7 +550,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeDisconnect(ChannelPromise promise) { private void findAndInvokeDisconnect(ChannelPromise promise) {
findContextOutbound(MASK_DISCONNECT).invokeDisconnect(promise); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_DISCONNECT);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeDisconnect(promise);
} }
private void invokeDisconnect(ChannelPromise promise) { private void invokeDisconnect(ChannelPromise promise) {
@ -500,7 +583,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeClose(ChannelPromise promise) { private void findAndInvokeClose(ChannelPromise promise) {
findContextOutbound(MASK_CLOSE).invokeClose(promise); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_CLOSE);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeClose(promise);
} }
private void invokeClose(ChannelPromise promise) { private void invokeClose(ChannelPromise promise) {
@ -528,7 +616,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeRegister(ChannelPromise promise) { private void findAndInvokeRegister(ChannelPromise promise) {
findContextOutbound(MASK_REGISTER).invokeRegister(promise); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_REGISTER);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeRegister(promise);
} }
private void invokeRegister(ChannelPromise promise) { private void invokeRegister(ChannelPromise promise) {
@ -556,7 +649,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeDeregister(ChannelPromise promise) { private void findAndInvokeDeregister(ChannelPromise promise) {
findContextOutbound(MASK_DEREGISTER).invokeDeregister(promise); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_DEREGISTER);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeDeregister(promise);
} }
private void invokeDeregister(ChannelPromise promise) { private void invokeDeregister(ChannelPromise promise) {
@ -580,7 +678,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeRead() { private void findAndInvokeRead() {
findContextOutbound(MASK_READ).invokeRead(); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_READ);
if (ctx != null) {
ctx.invokeRead();
}
} }
private void invokeRead() { private void invokeRead() {
@ -595,7 +696,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) { if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) {
notifyHandlerException(t); notifyHandlerException(t);
} else { } else {
findContextInbound(MASK_EXCEPTION_CAUGHT).notifyHandlerException(t); DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeExceptionCaught(t);
} }
} }
@ -634,7 +740,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
} }
private void findAndInvokeFlush() { private void findAndInvokeFlush() {
findContextOutbound(MASK_FLUSH).invokeFlush(); DefaultChannelHandlerContext ctx = findContextOutbound(MASK_FLUSH);
if (ctx != null) {
ctx.invokeFlush();
}
} }
private void invokeFlush() { private void invokeFlush() {
@ -673,6 +782,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
if (executor.inEventLoop()) { if (executor.inEventLoop()) {
final DefaultChannelHandlerContext next = findContextOutbound(flush ? final DefaultChannelHandlerContext next = findContextOutbound(flush ?
(MASK_WRITE | MASK_FLUSH) : MASK_WRITE); (MASK_WRITE | MASK_FLUSH) : MASK_WRITE);
if (next == null) {
ReferenceCountUtil.release(msg);
failRemoved(this, promise);
return;
}
if (flush) { if (flush) {
next.invokeWriteAndFlush(msg, promise); next.invokeWriteAndFlush(msg, promise);
} else { } else {
@ -796,17 +910,23 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
private DefaultChannelHandlerContext findContextInbound(int mask) { private DefaultChannelHandlerContext findContextInbound(int mask) {
DefaultChannelHandlerContext ctx = this; DefaultChannelHandlerContext ctx = this;
if (ctx.next == null) {
return null;
}
do { do {
ctx = ctx.next; ctx = ctx.next;
} while ((ctx.executionMask & mask) == 0); } while ((ctx.executionMask & mask) == 0 || ctx.handlerState == REMOVE_STARTED);
return ctx; return ctx;
} }
private DefaultChannelHandlerContext findContextOutbound(int mask) { private DefaultChannelHandlerContext findContextOutbound(int mask) {
DefaultChannelHandlerContext ctx = this; DefaultChannelHandlerContext ctx = this;
if (ctx.prev == null) {
return null;
}
do { do {
ctx = ctx.prev; ctx = ctx.prev;
} while ((ctx.executionMask & mask) == 0); } while ((ctx.executionMask & mask) == 0 || ctx.handlerState == REMOVE_STARTED);
return ctx; return ctx;
} }
@ -815,15 +935,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return channel().voidPromise(); return channel().voidPromise();
} }
private void setRemoved() {
handlerState = REMOVE_COMPLETE;
}
boolean setAddComplete() { boolean setAddComplete() {
// Ensure we never update when the handlerState is REMOVE_COMPLETE already. // Ensure we never update when the handlerState is REMOVE_COMPLETE already.
// oldState is usually ADD_PENDING but can also be REMOVE_COMPLETE when an EventExecutor is used that is not // oldState is usually ADD_PENDING but can also be REMOVE_COMPLETE when an EventExecutor is used that is not
// exposing ordering guarantees. // exposing ordering guarantees.
if (handlerState != REMOVE_COMPLETE) { if (handlerState == INIT) {
handlerState = ADD_COMPLETE; handlerState = ADD_COMPLETE;
return true; return true;
} }
@ -842,11 +958,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try { try {
// Only call handlerRemoved(...) if we called handlerAdded(...) before. // Only call handlerRemoved(...) if we called handlerAdded(...) before.
if (handlerState == ADD_COMPLETE) { if (handlerState == ADD_COMPLETE) {
handlerState = REMOVE_STARTED;
handler().handlerRemoved(this); handler().handlerRemoved(this);
} }
} finally { } finally {
// Mark the handler as removed in any case. // Mark the handler as removed in any case.
setRemoved(); handlerState = REMOVE_COMPLETE;
} }
} }
@ -855,6 +972,24 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return handlerState == REMOVE_COMPLETE; return handlerState == REMOVE_COMPLETE;
} }
void remove(boolean relink) {
assert handlerState == REMOVE_COMPLETE;
if (relink) {
DefaultChannelHandlerContext prev = this.prev;
DefaultChannelHandlerContext next = this.next;
// prev and next may be null if the handler was never really added to the pipeline
if (prev != null) {
prev.next = next;
}
if (next != null) {
next.prev = prev;
}
}
this.prev = null;
this.next = null;
}
@Override @Override
public <T> Attribute<T> attr(AttributeKey<T> key) { public <T> Attribute<T> attr(AttributeKey<T> key) {
return channel().attr(key); return channel().attr(key);
@ -931,6 +1066,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try { try {
decrementPendingOutboundBytes(); decrementPendingOutboundBytes();
DefaultChannelHandlerContext next = findContext(ctx); DefaultChannelHandlerContext next = findContext(ctx);
if (next == null) {
ReferenceCountUtil.release(msg);
failRemoved(ctx, promise);
return;
}
write(next, msg, promise); write(next, msg, promise);
} finally { } finally {
recycle(); recycle();

View File

@ -48,6 +48,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultChannelPipeline.class); private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultChannelPipeline.class);
private static final String HEAD_NAME = generateName0(HeadHandler.class); private static final String HEAD_NAME = generateName0(HeadHandler.class);
private static final String TAIL_NAME = generateName0(TailHandler.class); private static final String TAIL_NAME = generateName0(TailHandler.class);
private static final ChannelHandler HEAD_HANDLER = new HeadHandler(); private static final ChannelHandler HEAD_HANDLER = new HeadHandler();
private static final ChannelHandler TAIL_HANDLER = new TailHandler(); private static final ChannelHandler TAIL_HANDLER = new TailHandler();
@ -64,6 +65,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
DefaultChannelPipeline.class, MessageSizeEstimator.Handle.class, "estimatorHandle"); DefaultChannelPipeline.class, MessageSizeEstimator.Handle.class, "estimatorHandle");
private final DefaultChannelHandlerContext head; private final DefaultChannelHandlerContext head;
private final DefaultChannelHandlerContext tail; private final DefaultChannelHandlerContext tail;
private final Channel channel; private final Channel channel;
private final ChannelFuture succeededFuture; private final ChannelFuture succeededFuture;
private final VoidChannelPromise voidPromise; private final VoidChannelPromise voidPromise;
@ -102,9 +104,26 @@ public class DefaultChannelPipeline implements ChannelPipeline {
} }
private DefaultChannelHandlerContext newContext(String name, ChannelHandler handler) { private DefaultChannelHandlerContext newContext(String name, ChannelHandler handler) {
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
return new DefaultChannelHandlerContext(this, name, handler); return new DefaultChannelHandlerContext(this, name, handler);
} }
private void checkDuplicateName(String name) {
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
}
private int checkNoSuchElement(int idx, String name) {
if (idx == -1) {
throw new NoSuchElementException(name);
}
return idx;
}
@Override @Override
public final Channel channel() { public final Channel channel() {
return channel; return channel;
@ -117,18 +136,12 @@ public class DefaultChannelPipeline implements ChannelPipeline {
@Override @Override
public final ChannelPipeline addFirst(String name, ChannelHandler handler) { public final ChannelPipeline addFirst(String name, ChannelHandler handler) {
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
DefaultChannelHandlerContext newCtx = newContext(name, handler); DefaultChannelHandlerContext newCtx = newContext(name, handler);
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
if (context(name) != null) { checkDuplicateName(newCtx.name());
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
handlers.add(0, newCtx); handlers.add(0, newCtx);
if (!inEventLoop) { if (!inEventLoop) {
try { try {
@ -156,18 +169,12 @@ public class DefaultChannelPipeline implements ChannelPipeline {
@Override @Override
public final ChannelPipeline addLast(String name, ChannelHandler handler) { public final ChannelPipeline addLast(String name, ChannelHandler handler) {
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
DefaultChannelHandlerContext newCtx = newContext(name, handler); DefaultChannelHandlerContext newCtx = newContext(name, handler);
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
if (context(name) != null) { checkDuplicateName(newCtx.name());
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
handlers.add(newCtx); handlers.add(newCtx);
if (!inEventLoop) { if (!inEventLoop) {
try { try {
@ -197,24 +204,13 @@ public class DefaultChannelPipeline implements ChannelPipeline {
public final ChannelPipeline addBefore(String baseName, String name, ChannelHandler handler) { public final ChannelPipeline addBefore(String baseName, String name, ChannelHandler handler) {
final DefaultChannelHandlerContext ctx; final DefaultChannelHandlerContext ctx;
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
DefaultChannelHandlerContext newCtx = newContext(name, handler); DefaultChannelHandlerContext newCtx = newContext(name, handler);
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
int i = findCtxIdx(context -> context.name().equals(baseName)); int i = checkNoSuchElement(findCtxIdx(context -> context.name().equals(baseName)), baseName);
checkDuplicateName(newCtx.name());
if (i == -1) {
throw new NoSuchElementException(baseName);
}
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
ctx = handlers.get(i); ctx = handlers.get(i);
handlers.add(i, newCtx); handlers.add(i, newCtx);
if (!inEventLoop) { if (!inEventLoop) {
@ -244,7 +240,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
public final ChannelPipeline addAfter(String baseName, String name, ChannelHandler handler) { public final ChannelPipeline addAfter(String baseName, String name, ChannelHandler handler) {
final DefaultChannelHandlerContext ctx; final DefaultChannelHandlerContext ctx;
checkMultiplicity(handler);
if (name == null) { if (name == null) {
name = generateName(handler); name = generateName(handler);
} }
@ -253,15 +248,10 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
int i = findCtxIdx(context -> context.name().equals(baseName)); int i = checkNoSuchElement(findCtxIdx(context -> context.name().equals(baseName)), baseName);
if (i == -1) { checkDuplicateName(newCtx.name());
throw new NoSuchElementException(baseName);
}
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
ctx = handlers.get(i); ctx = handlers.get(i);
handlers.add(i + 1, newCtx); handlers.add(i + 1, newCtx);
if (!inEventLoop) { if (!inEventLoop) {
@ -377,21 +367,13 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
int idx = findCtxIdx(context -> context.handler() == handler); int idx = checkNoSuchElement(findCtxIdx(context -> context.handler() == handler), null);
if (idx == -1) {
throw new NoSuchElementException();
}
ctx = handlers.remove(idx); ctx = handlers.remove(idx);
assert ctx != null; assert ctx != null;
if (!inEventLoop) { if (!inEventLoop) {
try { scheduleRemove(idx, ctx);
executor.execute(() -> remove0(ctx)); return this;
return this;
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
} }
} }
@ -405,21 +387,12 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
int idx = findCtxIdx(context -> context.name().equals(name)); int idx = checkNoSuchElement(findCtxIdx(context -> context.name().equals(name)), name);
if (idx == -1) {
throw new NoSuchElementException();
}
ctx = handlers.remove(idx); ctx = handlers.remove(idx);
assert ctx != null; assert ctx != null;
if (!inEventLoop) { if (!inEventLoop) {
try { return scheduleRemove(idx, ctx);
executor.execute(() -> remove0(ctx));
return ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
} }
} }
@ -434,21 +407,13 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
int idx = findCtxIdx(context -> handlerType.isAssignableFrom(context.handler().getClass())); int idx = checkNoSuchElement(findCtxIdx(
if (idx == -1) { context -> handlerType.isAssignableFrom(context.handler().getClass())), null);
throw new NoSuchElementException();
}
ctx = handlers.remove(idx); ctx = handlers.remove(idx);
assert ctx != null; assert ctx != null;
if (!inEventLoop) { if (!inEventLoop) {
try { return scheduleRemove(idx, ctx);
executor.execute(() -> remove0(ctx));
return (T) ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
} }
} }
@ -483,30 +448,19 @@ public class DefaultChannelPipeline implements ChannelPipeline {
assert ctx != null; assert ctx != null;
if (!inEventLoop) { if (!inEventLoop) {
try { return scheduleRemove(idx, ctx);
executor.execute(() -> remove0(ctx));
return (T) ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
} }
} }
remove0(ctx); remove0(ctx);
return (T) ctx.handler(); return (T) ctx.handler();
} }
private void unlink(DefaultChannelHandlerContext ctx) {
assert ctx != head && ctx != tail;
DefaultChannelHandlerContext prev = ctx.prev;
DefaultChannelHandlerContext next = ctx.next;
prev.next = next;
next.prev = prev;
}
private void remove0(DefaultChannelHandlerContext ctx) { private void remove0(DefaultChannelHandlerContext ctx) {
unlink(ctx); try {
callHandlerRemoved0(ctx); callHandlerRemoved0(ctx);
} finally {
ctx.remove(true);
}
} }
@Override @Override
@ -529,27 +483,17 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private ChannelHandler replace( private ChannelHandler replace(
Predicate<DefaultChannelHandlerContext> predicate, String newName, ChannelHandler newHandler) { Predicate<DefaultChannelHandlerContext> predicate, String newName, ChannelHandler newHandler) {
checkMultiplicity(newHandler);
if (newName == null) {
newName = generateName(newHandler);
}
DefaultChannelHandlerContext oldCtx; DefaultChannelHandlerContext oldCtx;
DefaultChannelHandlerContext newCtx = newContext(newName, newHandler); DefaultChannelHandlerContext newCtx = newContext(newName, newHandler);
EventExecutor executor = executor(); EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop(); boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) { synchronized (handlers) {
int idx = findCtxIdx(predicate); int idx = checkNoSuchElement(findCtxIdx(predicate), null);
if (idx == -1) {
throw new NoSuchElementException();
}
oldCtx = handlers.get(idx); oldCtx = handlers.get(idx);
assert oldCtx != head && oldCtx != tail && oldCtx != null; assert oldCtx != head && oldCtx != tail && oldCtx != null;
if (!oldCtx.name().equals(newName)) { if (!oldCtx.name().equals(newCtx.name())) {
if (context(newName) != null) { checkDuplicateName(newCtx.name());
throw new IllegalArgumentException("Duplicate handler name: " + newName);
}
} }
DefaultChannelHandlerContext removed = handlers.set(idx, newCtx); DefaultChannelHandlerContext removed = handlers.set(idx, newCtx);
assert removed != null; assert removed != null;
@ -586,11 +530,15 @@ public class DefaultChannelPipeline implements ChannelPipeline {
oldCtx.prev = newCtx; oldCtx.prev = newCtx;
oldCtx.next = newCtx; oldCtx.next = newCtx;
// Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked) try {
// because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and those // Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked)
// event handlers must be called after handlerAdded(). // because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and those
callHandlerAdded0(newCtx); // event handlers must be called after handlerAdded().
callHandlerRemoved0(oldCtx); callHandlerAdded0(newCtx);
callHandlerRemoved0(oldCtx);
} finally {
oldCtx.remove(false);
}
} }
private static void checkMultiplicity(ChannelHandler handler) { private static void checkMultiplicity(ChannelHandler handler) {
@ -615,7 +563,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
handlers.remove(ctx); handlers.remove(ctx);
} }
unlink(ctx);
ctx.callHandlerRemoved(); ctx.callHandlerRemoved();
removed = true; removed = true;
@ -623,6 +570,8 @@ public class DefaultChannelPipeline implements ChannelPipeline {
if (logger.isWarnEnabled()) { if (logger.isWarnEnabled()) {
logger.warn("Failed to remove a handler: " + ctx.name(), t2); logger.warn("Failed to remove a handler: " + ctx.name(), t2);
} }
} finally {
ctx.remove(true);
} }
if (removed) { if (removed) {
@ -749,13 +698,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
assert ctx != null; assert ctx != null;
if (!inEventLoop) { if (!inEventLoop) {
try { return scheduleRemove(idx, ctx);
executor.execute(() -> remove0(ctx));
return ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
} }
} }
remove0(ctx); remove0(ctx);
@ -777,19 +720,24 @@ public class DefaultChannelPipeline implements ChannelPipeline {
assert ctx != null; assert ctx != null;
if (!inEventLoop) { if (!inEventLoop) {
try { return scheduleRemove(idx, ctx);
executor.execute(() -> remove0(ctx));
return ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
} }
} }
remove0(ctx); remove0(ctx);
return ctx.handler(); return ctx.handler();
} }
@SuppressWarnings("unchecked")
private <T extends ChannelHandler> T scheduleRemove(int idx, DefaultChannelHandlerContext ctx) {
try {
ctx.executor().execute(() -> remove0(ctx));
return (T) ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
}
@Override @Override
public ChannelHandler first() { public ChannelHandler first() {
ChannelHandlerContext ctx = firstContext(); ChannelHandlerContext ctx = firstContext();
@ -849,32 +797,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
return this; return this;
} }
/**
* Removes all handlers from the pipeline one by one from tail (exclusive) to head (exclusive) to trigger
* handlerRemoved().
*/
private void destroy() {
EventExecutor executor = executor();
if (executor.inEventLoop()) {
destroy0();
} else {
executor.execute(this::destroy0);
}
}
private void destroy0() {
assert executor().inEventLoop();
DefaultChannelHandlerContext ctx = this.tail.prev;
while (ctx != head) {
synchronized (handlers) {
handlers.remove(ctx);
}
remove0(ctx);
ctx = ctx.prev;
}
}
@Override @Override
public final ChannelPipeline fireChannelActive() { public final ChannelPipeline fireChannelActive() {
head.invokeChannelActive(); head.invokeChannelActive();
@ -980,7 +902,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
} }
@Override @Override
public final ChannelFuture close(ChannelPromise promise) { public ChannelFuture close(ChannelPromise promise) {
return tail.close(promise); return tail.close(promise);
} }
@ -1135,10 +1057,14 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private static final class TailHandler implements ChannelHandler { private static final class TailHandler implements ChannelHandler {
@Override @Override
public void channelRegistered(ChannelHandlerContext ctx) { } public void channelRegistered(ChannelHandlerContext ctx) {
// Just swallow event
}
@Override @Override
public void channelUnregistered(ChannelHandlerContext ctx) { } public void channelUnregistered(ChannelHandlerContext ctx) {
// Just swallow event
}
@Override @Override
public void channelActive(ChannelHandlerContext ctx) { public void channelActive(ChannelHandlerContext ctx) {
@ -1155,12 +1081,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
((DefaultChannelPipeline) ctx.pipeline()).onUnhandledChannelWritabilityChanged(); ((DefaultChannelPipeline) ctx.pipeline()).onUnhandledChannelWritabilityChanged();
} }
@Override
public void handlerAdded(ChannelHandlerContext ctx) { }
@Override
public void handlerRemoved(ChannelHandlerContext ctx) { }
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
((DefaultChannelPipeline) ctx.pipeline()).onUnhandledInboundUserEventTriggered(evt); ((DefaultChannelPipeline) ctx.pipeline()).onUnhandledInboundUserEventTriggered(evt);
@ -1232,15 +1152,5 @@ public class DefaultChannelPipeline implements ChannelPipeline {
public void flush(ChannelHandlerContext ctx) { public void flush(ChannelHandlerContext ctx) {
ctx.channel().unsafe().flush(); ctx.channel().unsafe().flush();
} }
@Override
public void channelUnregistered(ChannelHandlerContext ctx) {
ctx.fireChannelUnregistered();
// Remove all handlers sequentially if channel is closed and unregistered.
if (!ctx.channel().isOpen()) {
((DefaultChannelPipeline) ctx.pipeline()).destroy();
}
}
} }
} }

View File

@ -782,64 +782,70 @@ public class EmbeddedChannel extends AbstractChannel {
return EmbeddedUnsafe.this.remoteAddress(); return EmbeddedUnsafe.this.remoteAddress();
} }
private void mayRunPendingTasks() {
if (!((EmbeddedEventLoop) eventLoop()).running) {
runPendingTasks();
}
}
@Override @Override
public void register(ChannelPromise promise) { public void register(ChannelPromise promise) {
EmbeddedUnsafe.this.register(promise); EmbeddedUnsafe.this.register(promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void bind(SocketAddress localAddress, ChannelPromise promise) { public void bind(SocketAddress localAddress, ChannelPromise promise) {
EmbeddedUnsafe.this.bind(localAddress, promise); EmbeddedUnsafe.this.bind(localAddress, promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise); EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void disconnect(ChannelPromise promise) { public void disconnect(ChannelPromise promise) {
EmbeddedUnsafe.this.disconnect(promise); EmbeddedUnsafe.this.disconnect(promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void close(ChannelPromise promise) { public void close(ChannelPromise promise) {
EmbeddedUnsafe.this.close(promise); EmbeddedUnsafe.this.close(promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void closeForcibly() { public void closeForcibly() {
EmbeddedUnsafe.this.closeForcibly(); EmbeddedUnsafe.this.closeForcibly();
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void deregister(ChannelPromise promise) { public void deregister(ChannelPromise promise) {
EmbeddedUnsafe.this.deregister(promise); EmbeddedUnsafe.this.deregister(promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void beginRead() { public void beginRead() {
EmbeddedUnsafe.this.beginRead(); EmbeddedUnsafe.this.beginRead();
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void write(Object msg, ChannelPromise promise) { public void write(Object msg, ChannelPromise promise) {
EmbeddedUnsafe.this.write(msg, promise); EmbeddedUnsafe.this.write(msg, promise);
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override
public void flush() { public void flush() {
EmbeddedUnsafe.this.flush(); EmbeddedUnsafe.this.flush();
runPendingTasks(); mayRunPendingTasks();
} }
@Override @Override

View File

@ -30,7 +30,7 @@ import java.util.concurrent.TimeUnit;
final class EmbeddedEventLoop extends AbstractScheduledEventExecutor implements EventLoop { final class EmbeddedEventLoop extends AbstractScheduledEventExecutor implements EventLoop {
private final Queue<Runnable> tasks = new ArrayDeque<>(2); private final Queue<Runnable> tasks = new ArrayDeque<>(2);
private boolean running; boolean running;
private static EmbeddedChannel cast(Channel channel) { private static EmbeddedChannel cast(Channel channel) {
if (channel instanceof EmbeddedChannel) { if (channel instanceof EmbeddedChannel) {

View File

@ -38,6 +38,7 @@ import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture; import io.netty.util.concurrent.ScheduledFuture;
import org.hamcrest.Matchers;
import org.junit.After; import org.junit.After;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.Test; import org.junit.Test;
@ -50,6 +51,7 @@ import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -62,6 +64,7 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -761,7 +764,7 @@ public class DefaultChannelPipelineTest {
pipeline.channel().close().syncUninterruptibly(); pipeline.channel().close().syncUninterruptibly();
} }
@Test(timeout = 2000) @Test(timeout = 5000)
public void testAddRemoveHandlerCalled() throws Throwable { public void testAddRemoveHandlerCalled() throws Throwable {
ChannelPipeline pipeline = newLocalChannel().pipeline(); ChannelPipeline pipeline = newLocalChannel().pipeline();
@ -815,6 +818,110 @@ public class DefaultChannelPipelineTest {
} }
} }
@Test(timeout = 5000)
public void testOperationsFailWhenRemoved() {
ChannelPipeline pipeline = newLocalChannel().pipeline();
try {
pipeline.channel().register().syncUninterruptibly();
ChannelHandler handler = new ChannelHandler() { };
pipeline.addFirst(handler);
ChannelHandlerContext ctx = pipeline.context(handler);
pipeline.remove(handler);
testOperationsFailsOnContext(ctx);
} finally {
pipeline.channel().close().syncUninterruptibly();
}
}
@Test(timeout = 5000)
public void testOperationsFailWhenReplaced() {
ChannelPipeline pipeline = newLocalChannel().pipeline();
try {
pipeline.channel().register().syncUninterruptibly();
ChannelHandler handler = new ChannelHandler() { };
pipeline.addFirst(handler);
ChannelHandlerContext ctx = pipeline.context(handler);
pipeline.replace(handler, null, new ChannelHandler() { });
testOperationsFailsOnContext(ctx);
} finally {
pipeline.channel().close().syncUninterruptibly();
}
}
private static void testOperationsFailsOnContext(ChannelHandlerContext ctx) {
assertChannelPipelineException(ctx.writeAndFlush(""));
assertChannelPipelineException(ctx.write(""));
assertChannelPipelineException(ctx.bind(new SocketAddress() { }));
assertChannelPipelineException(ctx.close());
assertChannelPipelineException(ctx.connect(new SocketAddress() { }));
assertChannelPipelineException(ctx.connect(new SocketAddress() { }, new SocketAddress() { }));
assertChannelPipelineException(ctx.deregister());
assertChannelPipelineException(ctx.disconnect());
class ChannelPipelineExceptionValidator implements ChannelHandler {
private Promise<Void> validationPromise = ImmediateEventExecutor.INSTANCE.newPromise();
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
try {
assertThat(cause, Matchers.instanceOf(ChannelPipelineException.class));
} catch (Throwable error) {
validationPromise.setFailure(error);
return;
}
validationPromise.setSuccess(null);
}
void validate() {
validationPromise.syncUninterruptibly();
validationPromise = ImmediateEventExecutor.INSTANCE.newPromise();
}
}
ChannelPipelineExceptionValidator validator = new ChannelPipelineExceptionValidator();
ctx.pipeline().addLast(validator);
ctx.fireChannelRead("");
validator.validate();
ctx.fireUserEventTriggered("");
validator.validate();
ctx.fireChannelReadComplete();
validator.validate();
ctx.fireExceptionCaught(new Exception());
validator.validate();
ctx.fireChannelActive();
validator.validate();
ctx.fireChannelRegistered();
validator.validate();
ctx.fireChannelInactive();
validator.validate();
ctx.fireChannelUnregistered();
validator.validate();
ctx.fireChannelWritabilityChanged();
validator.validate();
}
private static void assertChannelPipelineException(ChannelFuture f) {
try {
f.syncUninterruptibly();
} catch (CompletionException e) {
assertThat(e.getCause(), Matchers.instanceOf(ChannelPipelineException.class));
}
}
@Test(timeout = 3000) @Test(timeout = 3000)
public void testAddReplaceHandlerCalled() throws Throwable { public void testAddReplaceHandlerCalled() throws Throwable {
ChannelPipeline pipeline = newLocalChannel().pipeline(); ChannelPipeline pipeline = newLocalChannel().pipeline();

View File

@ -232,7 +232,9 @@ public class PendingWriteQueueTest {
ChannelPromise promise = channel.newPromise(); ChannelPromise promise = channel.newPromise();
final ChannelPromise promise3 = channel.newPromise(); final ChannelPromise promise3 = channel.newPromise();
promise.addListener((ChannelFutureListener) future -> queue.add(3L, promise3)); promise.addListener((ChannelFutureListener) future -> {
queue.add(3L, promise3);
});
ChannelPromise promise2 = channel.newPromise(); ChannelPromise promise2 = channel.newPromise();
channel.eventLoop().execute(() -> { channel.eventLoop().execute(() -> {
@ -245,8 +247,13 @@ public class PendingWriteQueueTest {
assertTrue(promise.isSuccess()); assertTrue(promise.isSuccess());
assertTrue(promise2.isDone()); assertTrue(promise2.isDone());
assertTrue(promise2.isSuccess()); assertTrue(promise2.isSuccess());
assertFalse(promise3.isDone());
assertFalse(promise3.isSuccess());
channel.eventLoop().execute(queue::removeAndWriteAll);
assertTrue(promise3.isDone()); assertTrue(promise3.isDone());
assertTrue(promise3.isSuccess()); assertTrue(promise3.isSuccess());
channel.runPendingTasks();
assertTrue(channel.finish()); assertTrue(channel.finish());
assertEquals(1L, (long) channel.readOutbound()); assertEquals(1L, (long) channel.readOutbound());
assertEquals(2L, (long) channel.readOutbound()); assertEquals(2L, (long) channel.readOutbound());