diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index f3b4b12c14..e12f0f054e 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -20,6 +20,7 @@ import io.netty.util.Attribute; import io.netty.util.AttributeKey; import io.netty.util.ResourceLeakHint; import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.StringUtil; import java.net.SocketAddress; @@ -33,7 +34,23 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R private final boolean outbound; private final DefaultChannelPipeline pipeline; private final String name; - private boolean removed; + private boolean handlerRemoved; + + /** + * This is set to {@code true} once the {@link ChannelHandler#handlerAdded(ChannelHandlerContext) method was called. + * We need to keep track of this This will set to true once the + * {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} method was called. We need to keep track of this + * to ensure we will never call another {@link ChannelHandler} method before handlerAdded(...) was called + * to guard againstordering issues. {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} MUST be the first + * method that is called for handler when it becomes a part of the {@link ChannelPipeline} in all cases. Not doing + * so may lead to unexpected side-effects as {@link ChannelHandler} implementationsmay need to do initialization + * steps before a {@link ChannelHandler} can be used. + * + * See #4705 + * + * No need to mark volatile as this will be made visible as next/prev is volatile. + */ + private boolean handlerAdded; final ChannelHandlerInvoker invoker; private ChannelFuture succeededFuture; @@ -82,15 +99,6 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R return invoker().executor(); } - @Override - public ChannelHandlerInvoker invoker() { - if (invoker == null) { - return channel().unsafe().invoker(); - } else { - return invoker; - } - } - @Override public String name() { return name; @@ -109,63 +117,63 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelHandlerContext fireChannelRegistered() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelRegistered(next); + next.invokeChannelRegistered(); return this; } @Override public ChannelHandlerContext fireChannelUnregistered() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelUnregistered(next); + next.invokeChannelUnregistered(); return this; } @Override public ChannelHandlerContext fireChannelActive() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelActive(next); + next.invokeChannelActive(); return this; } @Override public ChannelHandlerContext fireChannelInactive() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelInactive(next); + next.invokeChannelInactive(); return this; } @Override public ChannelHandlerContext fireExceptionCaught(Throwable cause) { AbstractChannelHandlerContext next = this.next; - next.invoker().invokeExceptionCaught(next, cause); + next.invokeExceptionCaught(cause); return this; } @Override public ChannelHandlerContext fireUserEventTriggered(Object event) { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeUserEventTriggered(next, event); + next.invokeUserEventTriggered(event); return this; } @Override public ChannelHandlerContext fireChannelRead(Object msg) { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelRead(next, pipeline.touch(msg, next)); + next.invokeChannelRead(pipeline.touch(msg, next)); return this; } @Override public ChannelHandlerContext fireChannelReadComplete() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelReadComplete(next); + next.invokeChannelReadComplete(); return this; } @Override public ChannelHandlerContext fireChannelWritabilityChanged() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelWritabilityChanged(next); + next.invokeChannelWritabilityChanged(); return this; } @@ -202,7 +210,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelFuture bind(final SocketAddress localAddress, final ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeBind(next, localAddress, promise); + next.invokeBind(localAddress, promise); return promise; } @@ -214,7 +222,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeConnect(next, remoteAddress, localAddress, promise); + next.invokeConnect(remoteAddress, localAddress, promise); return promise; } @@ -225,28 +233,28 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R } AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeDisconnect(next, promise); + next.invokeDisconnect(promise); return promise; } @Override public ChannelFuture close(ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeClose(next, promise); + next.invokeClose(promise); return promise; } @Override public ChannelFuture deregister(ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeDeregister(next, promise); + next.invokeDeregister(promise); return promise; } @Override public ChannelHandlerContext read() { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeRead(next); + next.invokeRead(); return this; } @@ -258,23 +266,21 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelFuture write(Object msg, ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeWrite(next, pipeline.touch(msg, next), promise); + next.invokeWrite(pipeline.touch(msg, next), promise); return promise; } @Override public ChannelHandlerContext flush() { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeFlush(next); + next.invokeFlush(); return this; } @Override public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - ChannelHandlerInvoker invoker = next.invoker(); - invoker.invokeWrite(next, pipeline.touch(msg, next) , promise); - invoker.invokeFlush(next); + next.invokeWriteAndFlush(pipeline.touch(msg, next), promise); return promise; } @@ -329,12 +335,294 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R } void setRemoved() { - removed = true; + handlerRemoved = true; } @Override public boolean isRemoved() { - return removed; + return handlerRemoved; + } + + final void invokeChannelRegistered() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelRegistered(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelRegistered(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelUnregistered() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelUnregistered(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelUnregistered(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelActive() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelActive(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelActive(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelInactive() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelInactive(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelInactive(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeExceptionCaught(final Throwable cause) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeExceptionCaught(this, cause); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeExceptionCaught(AbstractChannelHandlerContext.this, cause); + } + }); + } + } + + final void invokeUserEventTriggered(final Object event) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeUserEventTriggered(this, event); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeUserEventTriggered(AbstractChannelHandlerContext.this, event); + } + }); + } + } + + final void invokeChannelRead(final Object msg) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelRead(this, msg); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelRead(AbstractChannelHandlerContext.this, msg); + } + }); + } + } + + final void invokeChannelReadComplete() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelReadComplete(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelReadComplete(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelWritabilityChanged() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelWritabilityChanged(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelWritabilityChanged(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeBind(final SocketAddress localAddress, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeBind(this, localAddress, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeBind(AbstractChannelHandlerContext.this, localAddress, promise); + } + }); + } + } + + final void invokeConnect(final SocketAddress remoteAddress, + final SocketAddress localAddress, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeConnect(this, remoteAddress, localAddress, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeConnect(AbstractChannelHandlerContext.this, remoteAddress, localAddress, promise); + } + }); + } + } + + final void invokeDisconnect(final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeDisconnect(this, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeDisconnect(AbstractChannelHandlerContext.this, promise); + } + }); + } + } + + final void invokeClose(final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeClose(this, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeClose(AbstractChannelHandlerContext.this, promise); + } + }); + } + } + + final void invokeDeregister(final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeDeregister(this, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeDeregister(AbstractChannelHandlerContext.this, promise); + } + }); + } + } + + final void invokeRead() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeRead(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeRead(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeWrite(final Object msg, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeWrite(this, msg, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise); + } + }); + } + } + + final void invokeFlush() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeFlush(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeFlush(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeWriteAndFlush(final Object msg, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeWrite(this, msg, promise); + invoker.invokeFlush(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise); + invoker.invokeFlush(AbstractChannelHandlerContext.this); + } + }); + } + } + + @Override + public ChannelHandlerInvoker invoker() { + return invoker == null ? channel().unsafe().invoker() : invoker; + } + + final void setHandlerAddedCalled() { + handlerAdded = true; } @Override diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index 921a7e33da..03e8b235df 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -74,7 +74,9 @@ final class DefaultChannelPipeline implements ChannelPipeline { this.channel = channel; tail = new TailContext(this); + tail.setHandlerAddedCalled(); head = new HeadContext(this); + head.setHandlerAddedCalled(); head.next = tail; tail.prev = head; @@ -102,7 +104,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } @Override - public ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) { + public synchronized ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) { name = filterName(name, handler); addFirst0(new DefaultChannelHandlerContext(this, invoker, name, handler)); return this; @@ -392,7 +394,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { Future future; synchronized (this) { - if (!ctx.channel().isRegistered() || ctx.executor().inEventLoop()) { + if (!isExecuteLater(ctx)) { remove0(ctx); return ctx; } else { @@ -474,7 +476,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { final AbstractChannelHandlerContext newCtx = new DefaultChannelHandlerContext(this, ctx.invoker, newName, newHandler); - if (!newCtx.channel().isRegistered() || newCtx.executor().inEventLoop()) { + if (!isExecuteLater(newCtx)) { replace0(ctx, newCtx); return ctx.handler(); } else { @@ -537,7 +539,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } private void callHandlerAdded(final AbstractChannelHandlerContext ctx) { - if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) { + if (isExecuteLater(ctx)) { ctx.executor().execute(new OneTimeTask() { @Override public void run() { @@ -551,7 +553,12 @@ final class DefaultChannelPipeline implements ChannelPipeline { private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) { try { - ctx.handler().handlerAdded(ctx); + try { + ctx.handler().handlerAdded(ctx); + } finally { + // handlerAdded(...) method was called. + ctx.setHandlerAddedCalled(); + } } catch (Throwable t) { boolean removed = false; try { @@ -576,7 +583,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } private void callHandlerRemoved(final AbstractChannelHandlerContext ctx) { - if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) { + if (isExecuteLater(ctx)) { ctx.executor().execute(new OneTimeTask() { @Override public void run() { @@ -599,6 +606,10 @@ final class DefaultChannelPipeline implements ChannelPipeline { } } + private static boolean isExecuteLater(ChannelHandlerContext ctx) { + return ctx.channel().isRegistered() && !ctx.executor().inEventLoop(); + } + /** * Waits for a future to finish. If the task is interrupted, then the current thread will be interrupted. * It is expected that the task performs any appropriate locking.