Ensure the DefaultChannelHandlerContext is unlinked once removed (#9970)

Motivation:

At the moment the next / prev references are not set to "null" in the DefaultChannelHandlerContext once the ChannelHandler is removed. This is bad as it basically let users still use the ChannelHandlerContext of a ChannelHandler after it is removed and may produce very suprising behaviour.

Modifications:

- Fail if someone tries to use the ChannelHandlerContext once the ChannelHandler was removed (for outbound operations fail the promise, for inbound fire the error through the ChannelPipeline)
- Fix some handlers to ensure we not use the ChannelHandlerContext after the handler was removed
- Adjust DefaultChannelPipeline / DefaultChannelHandlerContext to fixes races with removal / replacement of handlers

Result:

Cleanup behaviour and make it more predictable for pipeline modifications
This commit is contained in:
Norman Maurer 2020-03-01 08:13:33 +01:00 committed by GitHub
parent 302cb737f3
commit ae0fbb45e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 512 additions and 298 deletions

View File

@ -253,6 +253,14 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.fireExceptionCaught(cause);
if (cause instanceof HAProxyProtocolException) {
ctx.close(); // drop connection immediately per spec
}
}
@Override
protected final void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
// determine the specification version
@ -325,7 +333,6 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) {
finished = true;
ctx.close(); // drop connection immediately per spec
HAProxyProtocolException ppex;
if (errMsg != null && e != null) {
ppex = new HAProxyProtocolException(errMsg, e);

View File

@ -203,8 +203,8 @@ public class HttpClientUpgradeHandler extends HttpObjectAggregator {
// NOTE: not releasing the response since we're letting it propagate to the
// next handler.
ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED);
removeThisHandler(ctx);
ctx.fireChannelRead(msg);
removeThisHandler(ctx);
return;
}
}

View File

@ -266,13 +266,26 @@ public class HttpObjectAggregator
});
}
} else if (oversized instanceof HttpResponse) {
ctx.close();
throw new TooLongFrameException("Response entity too large: " + oversized);
throw new ResponseTooLargeException("Response entity too large: " + oversized);
} else {
throw new IllegalStateException();
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
super.exceptionCaught(ctx, cause);
if (cause instanceof ResponseTooLargeException) {
ctx.close();
}
}
private static final class ResponseTooLargeException extends TooLongFrameException {
ResponseTooLargeException(String message) {
super(message);
}
}
private abstract static class AggregatedFullHttpMessage implements FullHttpMessage {
protected final HttpMessage message;
private final ByteBuf content;

View File

@ -328,13 +328,13 @@ public class HttpServerUpgradeHandler extends HttpObjectAggregator {
sourceCodec.upgradeFrom(ctx);
upgradeCodec.upgradeTo(ctx, request);
// Remove this handler from the pipeline.
ctx.pipeline().remove(HttpServerUpgradeHandler.this);
// Notify that the upgrade has occurred. Retain the event to offset
// the release() in the finally block.
ctx.fireUserEventTriggered(event.retain());
// Remove this handler from the pipeline.
ctx.pipeline().remove(HttpServerUpgradeHandler.this);
// Add the listener last to avoid firing upgrade logic after
// the channel is already closed since the listener may fire
// immediately if the write failed eagerly.

View File

@ -300,9 +300,9 @@ public abstract class WebSocketServerHandshaker {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// Remove ourself and fail the handshake promise.
ctx.pipeline().remove(this);
promise.tryFailure(cause);
ctx.fireExceptionCaught(cause);
ctx.pipeline().remove(this);
}
@Override

View File

@ -45,7 +45,6 @@ import static io.netty.handler.codec.http.HttpVersion.*;
class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
private final WebSocketServerProtocolConfig serverConfig;
private ChannelHandlerContext ctx;
private ChannelPromise handshakePromise;
WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
@ -54,7 +53,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
this.ctx = ctx;
handshakePromise = ctx.newPromise();
}
@ -86,7 +84,6 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
//
// See https://github.com/netty/netty/issues/9471.
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
ctx.pipeline().remove(this);
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
handshakeFuture.addListener((ChannelFutureListener) future -> {
@ -102,8 +99,9 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
new WebSocketServerProtocolHandler.HandshakeComplete(
req.uri(), req.headers(), handshaker.selectedSubprotocol()));
}
ctx.pipeline().remove(this);
});
applyHandshakeTimeout();
applyHandshakeTimeout(ctx);
}
} finally {
req.release();
@ -132,7 +130,7 @@ class WebSocketServerProtocolHandshakeHandler implements ChannelHandler {
return protocol + "://" + host + path;
}
private void applyHandshakeTimeout() {
private void applyHandshakeTimeout(ChannelHandlerContext ctx) {
final ChannelPromise localHandshakePromise = handshakePromise;
final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {

View File

@ -120,8 +120,9 @@ public class WebSocketClientExtensionHandler implements ChannelHandler {
ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
}
}
ctx.pipeline().remove(ctx.name());
ctx.fireChannelRead(msg);
ctx.pipeline().remove(this);
return;
}
}

View File

@ -66,7 +66,7 @@ public class WebSocketServerProtocolHandlerTest {
}
@Test
public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() {
public void testWebSocketServerProtocolHandshakeHandlerRemovedAfterHandshake() {
EmbeddedChannel ch = createChannel(new MockOutboundHandler());
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
ch.pipeline().addLast(new ChannelHandler() {
@ -74,7 +74,7 @@ public class WebSocketServerProtocolHandlerTest {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
// We should have removed the handler already.
assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
ctx.executor().execute(() -> ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
}
}
});

View File

@ -87,9 +87,9 @@ public final class CleartextHttp2ServerUpgradeHandler extends ByteToMessageDecod
.remove(httpServerUpgradeHandler);
ctx.pipeline().addAfter(ctx.name(), null, http2ServerHandler);
ctx.pipeline().remove(this);
ctx.fireUserEventTriggered(PriorKnowledgeUpgradeEvent.INSTANCE);
ctx.pipeline().remove(this);
}
}

View File

@ -154,10 +154,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
}
};
private static final byte STATE_INIT = 0;
private static final byte STATE_CALLING_CHILD_DECODE = 1;
private static final byte STATE_HANDLER_REMOVED_PENDING = 2;
ByteBuf cumulation;
private Cumulator cumulator = MERGE_CUMULATOR;
private boolean singleDecode;
@ -169,15 +165,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
*/
private boolean firedChannelRead;
/**
* A bitmask where the bits are defined as
* <ul>
* <li>{@link #STATE_INIT}</li>
* <li>{@link #STATE_CALLING_CHILD_DECODE}</li>
* <li>{@link #STATE_HANDLER_REMOVED_PENDING}</li>
* </ul>
*/
private byte decodeState = STATE_INIT;
private int discardAfterReads = 16;
private int numReads;
private ByteToMessageDecoderContext context;
@ -257,10 +244,6 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
@Override
public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
if (decodeState == STATE_CALLING_CHILD_DECODE) {
decodeState = STATE_HANDLER_REMOVED_PENDING;
return;
}
ByteBuf buf = cumulation;
if (buf != null) {
// Directly set this to null so we are sure we not access it in any other method here anymore.
@ -469,16 +452,7 @@ public abstract class ByteToMessageDecoder extends ChannelHandlerAdapter {
*/
final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in)
throws Exception {
decodeState = STATE_CALLING_CHILD_DECODE;
try {
decode(ctx, in);
} finally {
boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING;
decodeState = STATE_INIT;
if (removePending) {
handlerRemoved(ctx);
}
}
}
/**

View File

@ -97,8 +97,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)");
ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength));
pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
}
});

View File

@ -99,8 +99,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)");
ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength));
pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
}
});

View File

@ -98,8 +98,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
System.err.println("Directly talking: " + msg.protocolVersion() + " (no upgrade was attempted)");
ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength));
pipeline.addAfter(ctx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
}
});

View File

@ -45,8 +45,8 @@ public final class SocksServerHandler extends SimpleChannelInboundHandler<SocksM
Socks4CommandRequest socksV4CmdRequest = (Socks4CommandRequest) socksRequest;
if (socksV4CmdRequest.type() == Socks4CommandType.CONNECT) {
ctx.pipeline().addLast(new SocksServerConnectHandler());
ctx.pipeline().remove(this);
ctx.fireChannelRead(socksRequest);
ctx.pipeline().remove(this);
} else {
ctx.close();
}
@ -65,8 +65,8 @@ public final class SocksServerHandler extends SimpleChannelInboundHandler<SocksM
Socks5CommandRequest socks5CmdRequest = (Socks5CommandRequest) socksRequest;
if (socks5CmdRequest.type() == Socks5CommandType.CONNECT) {
ctx.pipeline().addLast(new SocksServerConnectHandler());
ctx.pipeline().remove(this);
ctx.fireChannelRead(socksRequest);
ctx.pipeline().remove(this);
} else {
ctx.close();
}

View File

@ -39,44 +39,52 @@ public abstract class AbstractRemoteAddressFilter<T extends SocketAddress> imple
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
handleNewChannel(ctx);
ctx.fireChannelRegistered();
handleNewChannel(ctx, true);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (!handleNewChannel(ctx)) {
if (!handleNewChannel(ctx, false)) {
throw new IllegalStateException("cannot determine to accept or reject a channel: " + ctx.channel());
} else {
ctx.fireChannelActive();
}
}
private boolean handleNewChannel(ChannelHandlerContext ctx) throws Exception {
private boolean handleNewChannel(ChannelHandlerContext ctx, boolean register) throws Exception {
@SuppressWarnings("unchecked")
T remoteAddress = (T) ctx.channel().remoteAddress();
boolean remove = false;
try {
// If the remote address is not available yet, defer the decision.
if (remoteAddress == null) {
return false;
}
// No need to keep this handler in the pipeline anymore because the decision is going to be made now.
// Also, this will prevent the subsequent events from being handled by this handler.
ctx.pipeline().remove(this);
if (accept(ctx, remoteAddress)) {
channelAccepted(ctx, remoteAddress);
remove = true;
} else {
ChannelFuture rejectedFuture = channelRejected(ctx, remoteAddress);
if (rejectedFuture != null) {
if (rejectedFuture != null && !rejectedFuture.isDone()) {
rejectedFuture.addListener(ChannelFutureListener.CLOSE);
} else {
ctx.close();
}
}
return true;
} finally {
if (!ctx.isRemoved()) {
if (register) {
ctx.fireChannelRegistered();
} else {
ctx.fireChannelActive();
}
if (remove) {
// No need to keep this handler in the pipeline anymore because the decision is going to be made
// now. Also, this will prevent the subsequent events from being handled by this handler.
ctx.pipeline().remove(this);
}
}
}
}
/**

View File

@ -97,14 +97,17 @@ public abstract class ApplicationProtocolNegotiationHandler implements ChannelHa
} catch (Throwable cause) {
exceptionCaught(ctx, cause);
} finally {
ctx.fireUserEventTriggered(evt);
ChannelPipeline pipeline = ctx.pipeline();
if (pipeline.context(this) != null) {
pipeline.remove(this);
}
}
}
} else {
ctx.fireUserEventTriggered(evt);
}
}
/**
* Invoked on successful initial SSL/TLS handshake. Implement this method to configure your pipeline

View File

@ -63,7 +63,7 @@ public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder {
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes());
ctx.fireUserEventTriggered(new SniCompletionEvent(e));
SslUtils.handleHandshakeFailure(ctx, e, true);
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(e));
throw e;
}
if (len == SslUtils.NOT_ENOUGH_DATA) {

View File

@ -1091,7 +1091,8 @@ public class SslHandler extends ByteToMessageDecoder {
if (logger.isDebugEnabled()) {
logger.debug(
"{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " +
"while writing close_notify in response to the peer's close_notify", ctx.channel(), cause);
"while writing close_notify in response to the peer's close_notify",
ctx.channel(), cause);
}
// Close the connection explicitly just in case the transport
@ -1101,6 +1102,11 @@ public class SslHandler extends ByteToMessageDecoder {
}
} else {
ctx.fireExceptionCaught(cause);
if (cause instanceof SSLException ||
((cause instanceof DecoderException) && cause.getCause() instanceof SSLException)) {
ctx.close();
}
}
}
@ -2057,9 +2063,13 @@ public class SslHandler extends ByteToMessageDecoder {
}
final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis;
if (closeNotifyReadTimeout <= 0) {
if (ctx.channel().isActive()) {
// Trigger the close in all cases to make sure the promise is notified
// See https://github.com/netty/netty/issues/2358
addCloseListener(ctx.close(ctx.newPromise()), promise);
} else {
promise.trySuccess();
}
} else {
final ScheduledFuture<?> closeNotifyReadTimeoutFuture;
@ -2083,7 +2093,11 @@ public class SslHandler extends ByteToMessageDecoder {
if (closeNotifyReadTimeoutFuture != null) {
closeNotifyReadTimeoutFuture.cancel(false);
}
if (ctx.channel().isActive()) {
addCloseListener(ctx.close(ctx.newPromise()), promise);
} else {
promise.trySuccess();
}
});
}
});

View File

@ -346,7 +346,6 @@ final class SslUtils {
if (notify) {
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
}
ctx.close();
}
/**

View File

@ -48,15 +48,13 @@ public abstract class OcspClientHandler implements ChannelHandler {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
ctx.fireUserEventTriggered(evt);
if (evt instanceof SslHandshakeCompletionEvent) {
ctx.pipeline().remove(this);
SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
if (event.isSuccess() && !verify(ctx, engine)) {
throw new SSLHandshakeException("Bad OCSP response");
}
ctx.pipeline().remove(this);
}
ctx.fireUserEventTriggered(evt);
}
}

View File

@ -442,11 +442,13 @@ public class SslHandlerTest {
sslHandler.setHandshakeTimeoutMillis(1000);
ch.pipeline().addFirst(sslHandler);
sslHandler.handshakeFuture().addListener((FutureListener<Channel>) future -> {
ch.eventLoop().execute(() -> {
ch.pipeline().remove(sslHandler);
// Schedule the close so removal has time to propagate exception if any.
ch.eventLoop().execute(ch::close);
});
});
ch.pipeline().addLast(new ChannelHandler() {
@Override
@ -457,6 +459,7 @@ public class SslHandlerTest {
if (cause instanceof IllegalReferenceCountException) {
promise.setFailure(cause);
}
cause.printStackTrace();
}
@Override

View File

@ -85,8 +85,9 @@ public class Http2ServerInitializer extends ChannelInitializer<SocketChannel> {
ChannelPipeline pipeline = ctx.pipeline();
ChannelHandlerContext thisCtx = pipeline.context(this);
pipeline.addAfter(thisCtx.name(), null, new HelloWorldHttp1Handler("Direct. No Upgrade Attempted."));
pipeline.replace(this, null, new HttpObjectAggregator(maxHttpContentLength));
pipeline.addAfter(thisCtx.name(), null, new HttpObjectAggregator(maxHttpContentLength));
ctx.fireChannelRead(ReferenceCountUtil.retain(msg));
pipeline.remove(this);
}
});

View File

@ -36,6 +36,7 @@ import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NotYetConnectedException;
import java.util.NoSuchElementException;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
@ -781,6 +782,19 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
if (registered) {
registered = false;
pipeline.fireChannelUnregistered();
if (!isOpen()) {
// Remove all handlers from the ChannelPipeline. This is needed to ensure
// handlerRemoved(...) is called and so resources are released.
while (!pipeline.isEmpty()) {
try {
pipeline.removeLast();
} catch (NoSuchElementException ignore) {
// try again as there may be a race when someone outside the EventLoop removes
// handlers concurrently as well.
}
}
}
}
safeSetSuccess(promise);
}

View File

@ -437,6 +437,14 @@ public interface ChannelPipeline
*/
ChannelHandlerContext lastContext();
/**
* Returns {@code true} if this {@link ChannelPipeline} is empty, which means no {@link ChannelHandler} is
* present.
*/
default boolean isEmpty() {
return lastContext() == null;
}
/**
* Returns the {@link ChannelHandler} with the specified name in this
* pipeline.

View File

@ -50,10 +50,15 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
*/
private static final int ADD_COMPLETE = 1;
/**
* {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} is about to be called.
*/
private static final int REMOVE_STARTED = 2;
/**
* {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} was called.
*/
private static final int REMOVE_COMPLETE = 2;
private static final int REMOVE_COMPLETE = 3;
private final int executionMask;
private final DefaultChannelPipeline pipeline;
@ -63,10 +68,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
// Lazily instantiated tasks used to trigger events to a handler with different executor.
// There is no need to make this volatile as at worse it will just create a few more instances then needed.
private Tasks invokeTasks;
private int handlerState = INIT;
DefaultChannelHandlerContext next;
DefaultChannelHandlerContext prev;
private int handlerState = INIT;
DefaultChannelHandlerContext(DefaultChannelPipeline pipeline, String name,
ChannelHandler handler) {
@ -76,6 +81,22 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
this.handler = handler;
}
private static void failRemoved(DefaultChannelHandlerContext ctx, ChannelPromise promise) {
promise.setFailure(newRemovedException(ctx, null));
}
private void notifyHandlerRemovedAlready() {
notifyHandlerRemovedAlready(null);
}
private void notifyHandlerRemovedAlready(Throwable cause) {
pipeline().fireExceptionCaught(newRemovedException(this, cause));
}
private static ChannelPipelineException newRemovedException(ChannelHandlerContext ctx, Throwable cause) {
return new ChannelPipelineException("Context " + ctx + " already removed", cause);
}
private Tasks invokeTasks() {
Tasks tasks = invokeTasks;
if (tasks == null) {
@ -126,7 +147,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelRegistered() {
findContextInbound(MASK_CHANNEL_REGISTERED).invokeChannelRegistered();
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_REGISTERED);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelRegistered();
}
void invokeChannelRegistered() {
@ -149,7 +175,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelUnregistered() {
findContextInbound(MASK_CHANNEL_UNREGISTERED).invokeChannelUnregistered();
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_UNREGISTERED);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelUnregistered();
}
void invokeChannelUnregistered() {
@ -172,7 +203,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelActive() {
findContextInbound(MASK_CHANNEL_ACTIVE).invokeChannelActive();
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_ACTIVE);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelActive();
}
void invokeChannelActive() {
@ -195,7 +231,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelInactive() {
findContextInbound(MASK_CHANNEL_INACTIVE).invokeChannelInactive();
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_INACTIVE);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelInactive();
}
void invokeChannelInactive() {
@ -226,7 +267,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeExceptionCaught(Throwable cause) {
findContextInbound(MASK_EXCEPTION_CAUGHT).invokeExceptionCaught(cause);
DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (ctx == null) {
notifyHandlerRemovedAlready(cause);
return;
}
ctx.invokeExceptionCaught(cause);
}
void invokeExceptionCaught(final Throwable cause) {
@ -261,7 +307,13 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeUserEventTriggered(Object event) {
findContextInbound(MASK_USER_EVENT_TRIGGERED).invokeUserEventTriggered(event);
DefaultChannelHandlerContext ctx = findContextInbound(MASK_USER_EVENT_TRIGGERED);
if (ctx == null) {
ReferenceCountUtil.release(event);
notifyHandlerRemovedAlready();
return;
}
ctx.invokeUserEventTriggered(event);
}
void invokeUserEventTriggered(Object event) {
@ -290,7 +342,13 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelRead(Object msg) {
findContextInbound(MASK_CHANNEL_READ).invokeChannelRead(msg);
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_READ);
if (ctx == null) {
ReferenceCountUtil.release(msg);
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelRead(msg);
}
void invokeChannelRead(Object msg) {
@ -315,7 +373,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelReadComplete() {
findContextInbound(MASK_CHANNEL_READ_COMPLETE).invokeChannelReadComplete();
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_READ_COMPLETE);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelReadComplete();
}
void invokeChannelReadComplete() {
@ -339,7 +402,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeChannelWritabilityChanged() {
findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED).invokeChannelWritabilityChanged();
DefaultChannelHandlerContext ctx = findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeChannelWritabilityChanged();
}
void invokeChannelWritabilityChanged() {
@ -403,7 +471,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeBind(SocketAddress localAddress, ChannelPromise promise) {
findContextOutbound(MASK_BIND).invokeBind(localAddress, promise);
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_BIND);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeBind(localAddress, promise);
}
private void invokeBind(SocketAddress localAddress, ChannelPromise promise) {
@ -438,7 +511,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
findContextOutbound(MASK_CONNECT).invokeConnect(remoteAddress, localAddress, promise);
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_CONNECT);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeConnect(remoteAddress, localAddress, promise);
}
private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
@ -472,7 +550,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeDisconnect(ChannelPromise promise) {
findContextOutbound(MASK_DISCONNECT).invokeDisconnect(promise);
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_DISCONNECT);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeDisconnect(promise);
}
private void invokeDisconnect(ChannelPromise promise) {
@ -500,7 +583,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeClose(ChannelPromise promise) {
findContextOutbound(MASK_CLOSE).invokeClose(promise);
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_CLOSE);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeClose(promise);
}
private void invokeClose(ChannelPromise promise) {
@ -528,7 +616,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeRegister(ChannelPromise promise) {
findContextOutbound(MASK_REGISTER).invokeRegister(promise);
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_REGISTER);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeRegister(promise);
}
private void invokeRegister(ChannelPromise promise) {
@ -556,7 +649,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeDeregister(ChannelPromise promise) {
findContextOutbound(MASK_DEREGISTER).invokeDeregister(promise);
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_DEREGISTER);
if (ctx == null) {
failRemoved(this, promise);
return;
}
ctx.invokeDeregister(promise);
}
private void invokeDeregister(ChannelPromise promise) {
@ -580,7 +678,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeRead() {
findContextOutbound(MASK_READ).invokeRead();
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_READ);
if (ctx != null) {
ctx.invokeRead();
}
}
private void invokeRead() {
@ -595,7 +696,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
if ((executionMask & MASK_EXCEPTION_CAUGHT) != 0) {
notifyHandlerException(t);
} else {
findContextInbound(MASK_EXCEPTION_CAUGHT).notifyHandlerException(t);
DefaultChannelHandlerContext ctx = findContextInbound(MASK_EXCEPTION_CAUGHT);
if (ctx == null) {
notifyHandlerRemovedAlready();
return;
}
ctx.invokeExceptionCaught(t);
}
}
@ -634,7 +740,10 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
}
private void findAndInvokeFlush() {
findContextOutbound(MASK_FLUSH).invokeFlush();
DefaultChannelHandlerContext ctx = findContextOutbound(MASK_FLUSH);
if (ctx != null) {
ctx.invokeFlush();
}
}
private void invokeFlush() {
@ -673,6 +782,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
if (executor.inEventLoop()) {
final DefaultChannelHandlerContext next = findContextOutbound(flush ?
(MASK_WRITE | MASK_FLUSH) : MASK_WRITE);
if (next == null) {
ReferenceCountUtil.release(msg);
failRemoved(this, promise);
return;
}
if (flush) {
next.invokeWriteAndFlush(msg, promise);
} else {
@ -796,17 +910,23 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
private DefaultChannelHandlerContext findContextInbound(int mask) {
DefaultChannelHandlerContext ctx = this;
if (ctx.next == null) {
return null;
}
do {
ctx = ctx.next;
} while ((ctx.executionMask & mask) == 0);
} while ((ctx.executionMask & mask) == 0 || ctx.handlerState == REMOVE_STARTED);
return ctx;
}
private DefaultChannelHandlerContext findContextOutbound(int mask) {
DefaultChannelHandlerContext ctx = this;
if (ctx.prev == null) {
return null;
}
do {
ctx = ctx.prev;
} while ((ctx.executionMask & mask) == 0);
} while ((ctx.executionMask & mask) == 0 || ctx.handlerState == REMOVE_STARTED);
return ctx;
}
@ -815,15 +935,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return channel().voidPromise();
}
private void setRemoved() {
handlerState = REMOVE_COMPLETE;
}
boolean setAddComplete() {
// Ensure we never update when the handlerState is REMOVE_COMPLETE already.
// oldState is usually ADD_PENDING but can also be REMOVE_COMPLETE when an EventExecutor is used that is not
// exposing ordering guarantees.
if (handlerState != REMOVE_COMPLETE) {
if (handlerState == INIT) {
handlerState = ADD_COMPLETE;
return true;
}
@ -842,11 +958,12 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
// Only call handlerRemoved(...) if we called handlerAdded(...) before.
if (handlerState == ADD_COMPLETE) {
handlerState = REMOVE_STARTED;
handler().handlerRemoved(this);
}
} finally {
// Mark the handler as removed in any case.
setRemoved();
handlerState = REMOVE_COMPLETE;
}
}
@ -855,6 +972,24 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
return handlerState == REMOVE_COMPLETE;
}
void remove(boolean relink) {
assert handlerState == REMOVE_COMPLETE;
if (relink) {
DefaultChannelHandlerContext prev = this.prev;
DefaultChannelHandlerContext next = this.next;
// prev and next may be null if the handler was never really added to the pipeline
if (prev != null) {
prev.next = next;
}
if (next != null) {
next.prev = prev;
}
}
this.prev = null;
this.next = null;
}
@Override
public <T> Attribute<T> attr(AttributeKey<T> key) {
return channel().attr(key);
@ -931,6 +1066,11 @@ final class DefaultChannelHandlerContext implements ChannelHandlerContext, Resou
try {
decrementPendingOutboundBytes();
DefaultChannelHandlerContext next = findContext(ctx);
if (next == null) {
ReferenceCountUtil.release(msg);
failRemoved(ctx, promise);
return;
}
write(next, msg, promise);
} finally {
recycle();

View File

@ -48,6 +48,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultChannelPipeline.class);
private static final String HEAD_NAME = generateName0(HeadHandler.class);
private static final String TAIL_NAME = generateName0(TailHandler.class);
private static final ChannelHandler HEAD_HANDLER = new HeadHandler();
private static final ChannelHandler TAIL_HANDLER = new TailHandler();
@ -64,6 +65,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
DefaultChannelPipeline.class, MessageSizeEstimator.Handle.class, "estimatorHandle");
private final DefaultChannelHandlerContext head;
private final DefaultChannelHandlerContext tail;
private final Channel channel;
private final ChannelFuture succeededFuture;
private final VoidChannelPromise voidPromise;
@ -102,9 +104,26 @@ public class DefaultChannelPipeline implements ChannelPipeline {
}
private DefaultChannelHandlerContext newContext(String name, ChannelHandler handler) {
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
return new DefaultChannelHandlerContext(this, name, handler);
}
private void checkDuplicateName(String name) {
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
}
private int checkNoSuchElement(int idx, String name) {
if (idx == -1) {
throw new NoSuchElementException(name);
}
return idx;
}
@Override
public final Channel channel() {
return channel;
@ -117,18 +136,12 @@ public class DefaultChannelPipeline implements ChannelPipeline {
@Override
public final ChannelPipeline addFirst(String name, ChannelHandler handler) {
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
DefaultChannelHandlerContext newCtx = newContext(name, handler);
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
checkDuplicateName(newCtx.name());
handlers.add(0, newCtx);
if (!inEventLoop) {
try {
@ -156,18 +169,12 @@ public class DefaultChannelPipeline implements ChannelPipeline {
@Override
public final ChannelPipeline addLast(String name, ChannelHandler handler) {
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
DefaultChannelHandlerContext newCtx = newContext(name, handler);
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
checkDuplicateName(newCtx.name());
handlers.add(newCtx);
if (!inEventLoop) {
try {
@ -197,24 +204,13 @@ public class DefaultChannelPipeline implements ChannelPipeline {
public final ChannelPipeline addBefore(String baseName, String name, ChannelHandler handler) {
final DefaultChannelHandlerContext ctx;
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
DefaultChannelHandlerContext newCtx = newContext(name, handler);
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
int i = findCtxIdx(context -> context.name().equals(baseName));
int i = checkNoSuchElement(findCtxIdx(context -> context.name().equals(baseName)), baseName);
checkDuplicateName(newCtx.name());
if (i == -1) {
throw new NoSuchElementException(baseName);
}
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
ctx = handlers.get(i);
handlers.add(i, newCtx);
if (!inEventLoop) {
@ -244,7 +240,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
public final ChannelPipeline addAfter(String baseName, String name, ChannelHandler handler) {
final DefaultChannelHandlerContext ctx;
checkMultiplicity(handler);
if (name == null) {
name = generateName(handler);
}
@ -253,15 +248,10 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
int i = findCtxIdx(context -> context.name().equals(baseName));
int i = checkNoSuchElement(findCtxIdx(context -> context.name().equals(baseName)), baseName);
if (i == -1) {
throw new NoSuchElementException(baseName);
}
checkDuplicateName(newCtx.name());
if (context(name) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + name);
}
ctx = handlers.get(i);
handlers.add(i + 1, newCtx);
if (!inEventLoop) {
@ -377,21 +367,13 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
int idx = findCtxIdx(context -> context.handler() == handler);
if (idx == -1) {
throw new NoSuchElementException();
}
int idx = checkNoSuchElement(findCtxIdx(context -> context.handler() == handler), null);
ctx = handlers.remove(idx);
assert ctx != null;
if (!inEventLoop) {
try {
executor.execute(() -> remove0(ctx));
scheduleRemove(idx, ctx);
return this;
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
}
}
@ -405,21 +387,12 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
int idx = findCtxIdx(context -> context.name().equals(name));
if (idx == -1) {
throw new NoSuchElementException();
}
int idx = checkNoSuchElement(findCtxIdx(context -> context.name().equals(name)), name);
ctx = handlers.remove(idx);
assert ctx != null;
if (!inEventLoop) {
try {
executor.execute(() -> remove0(ctx));
return ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
return scheduleRemove(idx, ctx);
}
}
@ -434,21 +407,13 @@ public class DefaultChannelPipeline implements ChannelPipeline {
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
int idx = findCtxIdx(context -> handlerType.isAssignableFrom(context.handler().getClass()));
if (idx == -1) {
throw new NoSuchElementException();
}
int idx = checkNoSuchElement(findCtxIdx(
context -> handlerType.isAssignableFrom(context.handler().getClass())), null);
ctx = handlers.remove(idx);
assert ctx != null;
if (!inEventLoop) {
try {
executor.execute(() -> remove0(ctx));
return (T) ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
return scheduleRemove(idx, ctx);
}
}
@ -483,30 +448,19 @@ public class DefaultChannelPipeline implements ChannelPipeline {
assert ctx != null;
if (!inEventLoop) {
try {
executor.execute(() -> remove0(ctx));
return (T) ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
return scheduleRemove(idx, ctx);
}
}
remove0(ctx);
return (T) ctx.handler();
}
private void unlink(DefaultChannelHandlerContext ctx) {
assert ctx != head && ctx != tail;
DefaultChannelHandlerContext prev = ctx.prev;
DefaultChannelHandlerContext next = ctx.next;
prev.next = next;
next.prev = prev;
}
private void remove0(DefaultChannelHandlerContext ctx) {
unlink(ctx);
try {
callHandlerRemoved0(ctx);
} finally {
ctx.remove(true);
}
}
@Override
@ -529,27 +483,17 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private ChannelHandler replace(
Predicate<DefaultChannelHandlerContext> predicate, String newName, ChannelHandler newHandler) {
checkMultiplicity(newHandler);
if (newName == null) {
newName = generateName(newHandler);
}
DefaultChannelHandlerContext oldCtx;
DefaultChannelHandlerContext newCtx = newContext(newName, newHandler);
EventExecutor executor = executor();
boolean inEventLoop = executor.inEventLoop();
synchronized (handlers) {
int idx = findCtxIdx(predicate);
if (idx == -1) {
throw new NoSuchElementException();
}
int idx = checkNoSuchElement(findCtxIdx(predicate), null);
oldCtx = handlers.get(idx);
assert oldCtx != head && oldCtx != tail && oldCtx != null;
if (!oldCtx.name().equals(newName)) {
if (context(newName) != null) {
throw new IllegalArgumentException("Duplicate handler name: " + newName);
}
if (!oldCtx.name().equals(newCtx.name())) {
checkDuplicateName(newCtx.name());
}
DefaultChannelHandlerContext removed = handlers.set(idx, newCtx);
assert removed != null;
@ -586,11 +530,15 @@ public class DefaultChannelPipeline implements ChannelPipeline {
oldCtx.prev = newCtx;
oldCtx.next = newCtx;
try {
// Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked)
// because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and those
// event handlers must be called after handlerAdded().
callHandlerAdded0(newCtx);
callHandlerRemoved0(oldCtx);
} finally {
oldCtx.remove(false);
}
}
private static void checkMultiplicity(ChannelHandler handler) {
@ -615,7 +563,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
handlers.remove(ctx);
}
unlink(ctx);
ctx.callHandlerRemoved();
removed = true;
@ -623,6 +570,8 @@ public class DefaultChannelPipeline implements ChannelPipeline {
if (logger.isWarnEnabled()) {
logger.warn("Failed to remove a handler: " + ctx.name(), t2);
}
} finally {
ctx.remove(true);
}
if (removed) {
@ -749,13 +698,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
assert ctx != null;
if (!inEventLoop) {
try {
executor.execute(() -> remove0(ctx));
return ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
return scheduleRemove(idx, ctx);
}
}
remove0(ctx);
@ -777,19 +720,24 @@ public class DefaultChannelPipeline implements ChannelPipeline {
assert ctx != null;
if (!inEventLoop) {
try {
executor.execute(() -> remove0(ctx));
return ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
return scheduleRemove(idx, ctx);
}
}
remove0(ctx);
return ctx.handler();
}
@SuppressWarnings("unchecked")
private <T extends ChannelHandler> T scheduleRemove(int idx, DefaultChannelHandlerContext ctx) {
try {
ctx.executor().execute(() -> remove0(ctx));
return (T) ctx.handler();
} catch (Throwable cause) {
handlers.add(idx, ctx);
throw cause;
}
}
@Override
public ChannelHandler first() {
ChannelHandlerContext ctx = firstContext();
@ -849,32 +797,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
return this;
}
/**
* Removes all handlers from the pipeline one by one from tail (exclusive) to head (exclusive) to trigger
* handlerRemoved().
*/
private void destroy() {
EventExecutor executor = executor();
if (executor.inEventLoop()) {
destroy0();
} else {
executor.execute(this::destroy0);
}
}
private void destroy0() {
assert executor().inEventLoop();
DefaultChannelHandlerContext ctx = this.tail.prev;
while (ctx != head) {
synchronized (handlers) {
handlers.remove(ctx);
}
remove0(ctx);
ctx = ctx.prev;
}
}
@Override
public final ChannelPipeline fireChannelActive() {
head.invokeChannelActive();
@ -980,7 +902,7 @@ public class DefaultChannelPipeline implements ChannelPipeline {
}
@Override
public final ChannelFuture close(ChannelPromise promise) {
public ChannelFuture close(ChannelPromise promise) {
return tail.close(promise);
}
@ -1135,10 +1057,14 @@ public class DefaultChannelPipeline implements ChannelPipeline {
private static final class TailHandler implements ChannelHandler {
@Override
public void channelRegistered(ChannelHandlerContext ctx) { }
public void channelRegistered(ChannelHandlerContext ctx) {
// Just swallow event
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) { }
public void channelUnregistered(ChannelHandlerContext ctx) {
// Just swallow event
}
@Override
public void channelActive(ChannelHandlerContext ctx) {
@ -1155,12 +1081,6 @@ public class DefaultChannelPipeline implements ChannelPipeline {
((DefaultChannelPipeline) ctx.pipeline()).onUnhandledChannelWritabilityChanged();
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) { }
@Override
public void handlerRemoved(ChannelHandlerContext ctx) { }
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
((DefaultChannelPipeline) ctx.pipeline()).onUnhandledInboundUserEventTriggered(evt);
@ -1232,15 +1152,5 @@ public class DefaultChannelPipeline implements ChannelPipeline {
public void flush(ChannelHandlerContext ctx) {
ctx.channel().unsafe().flush();
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) {
ctx.fireChannelUnregistered();
// Remove all handlers sequentially if channel is closed and unregistered.
if (!ctx.channel().isOpen()) {
((DefaultChannelPipeline) ctx.pipeline()).destroy();
}
}
}
}

View File

@ -782,64 +782,70 @@ public class EmbeddedChannel extends AbstractChannel {
return EmbeddedUnsafe.this.remoteAddress();
}
private void mayRunPendingTasks() {
if (!((EmbeddedEventLoop) eventLoop()).running) {
runPendingTasks();
}
}
@Override
public void register(ChannelPromise promise) {
EmbeddedUnsafe.this.register(promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void bind(SocketAddress localAddress, ChannelPromise promise) {
EmbeddedUnsafe.this.bind(localAddress, promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void disconnect(ChannelPromise promise) {
EmbeddedUnsafe.this.disconnect(promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void close(ChannelPromise promise) {
EmbeddedUnsafe.this.close(promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void closeForcibly() {
EmbeddedUnsafe.this.closeForcibly();
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void deregister(ChannelPromise promise) {
EmbeddedUnsafe.this.deregister(promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void beginRead() {
EmbeddedUnsafe.this.beginRead();
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void write(Object msg, ChannelPromise promise) {
EmbeddedUnsafe.this.write(msg, promise);
runPendingTasks();
mayRunPendingTasks();
}
@Override
public void flush() {
EmbeddedUnsafe.this.flush();
runPendingTasks();
mayRunPendingTasks();
}
@Override

View File

@ -30,7 +30,7 @@ import java.util.concurrent.TimeUnit;
final class EmbeddedEventLoop extends AbstractScheduledEventExecutor implements EventLoop {
private final Queue<Runnable> tasks = new ArrayDeque<>(2);
private boolean running;
boolean running;
private static EmbeddedChannel cast(Channel channel) {
if (channel instanceof EmbeddedChannel) {

View File

@ -38,6 +38,7 @@ import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Test;
@ -50,6 +51,7 @@ import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@ -62,6 +64,7 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -761,7 +764,7 @@ public class DefaultChannelPipelineTest {
pipeline.channel().close().syncUninterruptibly();
}
@Test(timeout = 2000)
@Test(timeout = 5000)
public void testAddRemoveHandlerCalled() throws Throwable {
ChannelPipeline pipeline = newLocalChannel().pipeline();
@ -815,6 +818,110 @@ public class DefaultChannelPipelineTest {
}
}
@Test(timeout = 5000)
public void testOperationsFailWhenRemoved() {
ChannelPipeline pipeline = newLocalChannel().pipeline();
try {
pipeline.channel().register().syncUninterruptibly();
ChannelHandler handler = new ChannelHandler() { };
pipeline.addFirst(handler);
ChannelHandlerContext ctx = pipeline.context(handler);
pipeline.remove(handler);
testOperationsFailsOnContext(ctx);
} finally {
pipeline.channel().close().syncUninterruptibly();
}
}
@Test(timeout = 5000)
public void testOperationsFailWhenReplaced() {
ChannelPipeline pipeline = newLocalChannel().pipeline();
try {
pipeline.channel().register().syncUninterruptibly();
ChannelHandler handler = new ChannelHandler() { };
pipeline.addFirst(handler);
ChannelHandlerContext ctx = pipeline.context(handler);
pipeline.replace(handler, null, new ChannelHandler() { });
testOperationsFailsOnContext(ctx);
} finally {
pipeline.channel().close().syncUninterruptibly();
}
}
private static void testOperationsFailsOnContext(ChannelHandlerContext ctx) {
assertChannelPipelineException(ctx.writeAndFlush(""));
assertChannelPipelineException(ctx.write(""));
assertChannelPipelineException(ctx.bind(new SocketAddress() { }));
assertChannelPipelineException(ctx.close());
assertChannelPipelineException(ctx.connect(new SocketAddress() { }));
assertChannelPipelineException(ctx.connect(new SocketAddress() { }, new SocketAddress() { }));
assertChannelPipelineException(ctx.deregister());
assertChannelPipelineException(ctx.disconnect());
class ChannelPipelineExceptionValidator implements ChannelHandler {
private Promise<Void> validationPromise = ImmediateEventExecutor.INSTANCE.newPromise();
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
try {
assertThat(cause, Matchers.instanceOf(ChannelPipelineException.class));
} catch (Throwable error) {
validationPromise.setFailure(error);
return;
}
validationPromise.setSuccess(null);
}
void validate() {
validationPromise.syncUninterruptibly();
validationPromise = ImmediateEventExecutor.INSTANCE.newPromise();
}
}
ChannelPipelineExceptionValidator validator = new ChannelPipelineExceptionValidator();
ctx.pipeline().addLast(validator);
ctx.fireChannelRead("");
validator.validate();
ctx.fireUserEventTriggered("");
validator.validate();
ctx.fireChannelReadComplete();
validator.validate();
ctx.fireExceptionCaught(new Exception());
validator.validate();
ctx.fireChannelActive();
validator.validate();
ctx.fireChannelRegistered();
validator.validate();
ctx.fireChannelInactive();
validator.validate();
ctx.fireChannelUnregistered();
validator.validate();
ctx.fireChannelWritabilityChanged();
validator.validate();
}
private static void assertChannelPipelineException(ChannelFuture f) {
try {
f.syncUninterruptibly();
} catch (CompletionException e) {
assertThat(e.getCause(), Matchers.instanceOf(ChannelPipelineException.class));
}
}
@Test(timeout = 3000)
public void testAddReplaceHandlerCalled() throws Throwable {
ChannelPipeline pipeline = newLocalChannel().pipeline();

View File

@ -232,7 +232,9 @@ public class PendingWriteQueueTest {
ChannelPromise promise = channel.newPromise();
final ChannelPromise promise3 = channel.newPromise();
promise.addListener((ChannelFutureListener) future -> queue.add(3L, promise3));
promise.addListener((ChannelFutureListener) future -> {
queue.add(3L, promise3);
});
ChannelPromise promise2 = channel.newPromise();
channel.eventLoop().execute(() -> {
@ -245,8 +247,13 @@ public class PendingWriteQueueTest {
assertTrue(promise.isSuccess());
assertTrue(promise2.isDone());
assertTrue(promise2.isSuccess());
assertFalse(promise3.isDone());
assertFalse(promise3.isSuccess());
channel.eventLoop().execute(queue::removeAndWriteAll);
assertTrue(promise3.isDone());
assertTrue(promise3.isSuccess());
channel.runPendingTasks();
assertTrue(channel.finish());
assertEquals(1L, (long) channel.readOutbound());
assertEquals(2L, (long) channel.readOutbound());