From af39cb6b12d9c1a8cf97018dc038fbecc3ea1277 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 15 Jan 2016 14:23:41 +0100 Subject: [PATCH] Ensure ChannelHandler.handlerAdded(...) is always called as first method of the handler Motivation: If a user adds a ChannelHandler from outside the EventLoop it is possible to get into the situation that handlerAdded(...) is scheduled on the EventLoop and so called after another methods of the ChannelHandler as the EventLoop may already be executing on this point in time. Modification: - Ensure we always check if the handlerAdded(...) method was called already and if not add the currently needed call to the EventLoop so it will be picked up after handlerAdded(...) was called. This works as if the handler is added to the ChannelPipeline from outside the EventLoop the actual handlerAdded(...) operation is scheduled on the EventLoop. - Some cleanup in the DefaultChannelPipeline Result: Correctly order of method executions of ChannelHandler. --- .../AbstractChannelHandlerContext.java | 352 ++++++++++++++++-- .../netty/channel/DefaultChannelPipeline.java | 23 +- 2 files changed, 337 insertions(+), 38 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java index f3b4b12c14..e12f0f054e 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -20,6 +20,7 @@ import io.netty.util.Attribute; import io.netty.util.AttributeKey; import io.netty.util.ResourceLeakHint; import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.StringUtil; import java.net.SocketAddress; @@ -33,7 +34,23 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R private final boolean outbound; private final DefaultChannelPipeline pipeline; private final String name; - private boolean removed; + private boolean handlerRemoved; + + /** + * This is set to {@code true} once the {@link ChannelHandler#handlerAdded(ChannelHandlerContext) method was called. + * We need to keep track of this This will set to true once the + * {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} method was called. We need to keep track of this + * to ensure we will never call another {@link ChannelHandler} method before handlerAdded(...) was called + * to guard againstordering issues. {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} MUST be the first + * method that is called for handler when it becomes a part of the {@link ChannelPipeline} in all cases. Not doing + * so may lead to unexpected side-effects as {@link ChannelHandler} implementationsmay need to do initialization + * steps before a {@link ChannelHandler} can be used. + * + * See #4705 + * + * No need to mark volatile as this will be made visible as next/prev is volatile. + */ + private boolean handlerAdded; final ChannelHandlerInvoker invoker; private ChannelFuture succeededFuture; @@ -82,15 +99,6 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R return invoker().executor(); } - @Override - public ChannelHandlerInvoker invoker() { - if (invoker == null) { - return channel().unsafe().invoker(); - } else { - return invoker; - } - } - @Override public String name() { return name; @@ -109,63 +117,63 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelHandlerContext fireChannelRegistered() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelRegistered(next); + next.invokeChannelRegistered(); return this; } @Override public ChannelHandlerContext fireChannelUnregistered() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelUnregistered(next); + next.invokeChannelUnregistered(); return this; } @Override public ChannelHandlerContext fireChannelActive() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelActive(next); + next.invokeChannelActive(); return this; } @Override public ChannelHandlerContext fireChannelInactive() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelInactive(next); + next.invokeChannelInactive(); return this; } @Override public ChannelHandlerContext fireExceptionCaught(Throwable cause) { AbstractChannelHandlerContext next = this.next; - next.invoker().invokeExceptionCaught(next, cause); + next.invokeExceptionCaught(cause); return this; } @Override public ChannelHandlerContext fireUserEventTriggered(Object event) { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeUserEventTriggered(next, event); + next.invokeUserEventTriggered(event); return this; } @Override public ChannelHandlerContext fireChannelRead(Object msg) { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelRead(next, pipeline.touch(msg, next)); + next.invokeChannelRead(pipeline.touch(msg, next)); return this; } @Override public ChannelHandlerContext fireChannelReadComplete() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelReadComplete(next); + next.invokeChannelReadComplete(); return this; } @Override public ChannelHandlerContext fireChannelWritabilityChanged() { AbstractChannelHandlerContext next = findContextInbound(); - next.invoker().invokeChannelWritabilityChanged(next); + next.invokeChannelWritabilityChanged(); return this; } @@ -202,7 +210,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelFuture bind(final SocketAddress localAddress, final ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeBind(next, localAddress, promise); + next.invokeBind(localAddress, promise); return promise; } @@ -214,7 +222,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeConnect(next, remoteAddress, localAddress, promise); + next.invokeConnect(remoteAddress, localAddress, promise); return promise; } @@ -225,28 +233,28 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R } AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeDisconnect(next, promise); + next.invokeDisconnect(promise); return promise; } @Override public ChannelFuture close(ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeClose(next, promise); + next.invokeClose(promise); return promise; } @Override public ChannelFuture deregister(ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeDeregister(next, promise); + next.invokeDeregister(promise); return promise; } @Override public ChannelHandlerContext read() { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeRead(next); + next.invokeRead(); return this; } @@ -258,23 +266,21 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R @Override public ChannelFuture write(Object msg, ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeWrite(next, pipeline.touch(msg, next), promise); + next.invokeWrite(pipeline.touch(msg, next), promise); return promise; } @Override public ChannelHandlerContext flush() { AbstractChannelHandlerContext next = findContextOutbound(); - next.invoker().invokeFlush(next); + next.invokeFlush(); return this; } @Override public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { AbstractChannelHandlerContext next = findContextOutbound(); - ChannelHandlerInvoker invoker = next.invoker(); - invoker.invokeWrite(next, pipeline.touch(msg, next) , promise); - invoker.invokeFlush(next); + next.invokeWriteAndFlush(pipeline.touch(msg, next), promise); return promise; } @@ -329,12 +335,294 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R } void setRemoved() { - removed = true; + handlerRemoved = true; } @Override public boolean isRemoved() { - return removed; + return handlerRemoved; + } + + final void invokeChannelRegistered() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelRegistered(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelRegistered(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelUnregistered() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelUnregistered(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelUnregistered(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelActive() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelActive(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelActive(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelInactive() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelInactive(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelInactive(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeExceptionCaught(final Throwable cause) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeExceptionCaught(this, cause); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeExceptionCaught(AbstractChannelHandlerContext.this, cause); + } + }); + } + } + + final void invokeUserEventTriggered(final Object event) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeUserEventTriggered(this, event); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeUserEventTriggered(AbstractChannelHandlerContext.this, event); + } + }); + } + } + + final void invokeChannelRead(final Object msg) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelRead(this, msg); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelRead(AbstractChannelHandlerContext.this, msg); + } + }); + } + } + + final void invokeChannelReadComplete() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelReadComplete(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelReadComplete(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeChannelWritabilityChanged() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeChannelWritabilityChanged(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeChannelWritabilityChanged(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeBind(final SocketAddress localAddress, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeBind(this, localAddress, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeBind(AbstractChannelHandlerContext.this, localAddress, promise); + } + }); + } + } + + final void invokeConnect(final SocketAddress remoteAddress, + final SocketAddress localAddress, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeConnect(this, remoteAddress, localAddress, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeConnect(AbstractChannelHandlerContext.this, remoteAddress, localAddress, promise); + } + }); + } + } + + final void invokeDisconnect(final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeDisconnect(this, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeDisconnect(AbstractChannelHandlerContext.this, promise); + } + }); + } + } + + final void invokeClose(final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeClose(this, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeClose(AbstractChannelHandlerContext.this, promise); + } + }); + } + } + + final void invokeDeregister(final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeDeregister(this, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeDeregister(AbstractChannelHandlerContext.this, promise); + } + }); + } + } + + final void invokeRead() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeRead(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeRead(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeWrite(final Object msg, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeWrite(this, msg, promise); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise); + } + }); + } + } + + final void invokeFlush() { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeFlush(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeFlush(AbstractChannelHandlerContext.this); + } + }); + } + } + + final void invokeWriteAndFlush(final Object msg, final ChannelPromise promise) { + final ChannelHandlerInvoker invoker = invoker(); + if (handlerAdded) { + invoker.invokeWrite(this, msg, promise); + invoker.invokeFlush(this); + } else { + invoker.executor().execute(new OneTimeTask() { + @Override + public void run() { + assert handlerAdded; + invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise); + invoker.invokeFlush(AbstractChannelHandlerContext.this); + } + }); + } + } + + @Override + public ChannelHandlerInvoker invoker() { + return invoker == null ? channel().unsafe().invoker() : invoker; + } + + final void setHandlerAddedCalled() { + handlerAdded = true; } @Override diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index 921a7e33da..03e8b235df 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -74,7 +74,9 @@ final class DefaultChannelPipeline implements ChannelPipeline { this.channel = channel; tail = new TailContext(this); + tail.setHandlerAddedCalled(); head = new HeadContext(this); + head.setHandlerAddedCalled(); head.next = tail; tail.prev = head; @@ -102,7 +104,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } @Override - public ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) { + public synchronized ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) { name = filterName(name, handler); addFirst0(new DefaultChannelHandlerContext(this, invoker, name, handler)); return this; @@ -392,7 +394,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { Future future; synchronized (this) { - if (!ctx.channel().isRegistered() || ctx.executor().inEventLoop()) { + if (!isExecuteLater(ctx)) { remove0(ctx); return ctx; } else { @@ -474,7 +476,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { final AbstractChannelHandlerContext newCtx = new DefaultChannelHandlerContext(this, ctx.invoker, newName, newHandler); - if (!newCtx.channel().isRegistered() || newCtx.executor().inEventLoop()) { + if (!isExecuteLater(newCtx)) { replace0(ctx, newCtx); return ctx.handler(); } else { @@ -537,7 +539,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } private void callHandlerAdded(final AbstractChannelHandlerContext ctx) { - if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) { + if (isExecuteLater(ctx)) { ctx.executor().execute(new OneTimeTask() { @Override public void run() { @@ -551,7 +553,12 @@ final class DefaultChannelPipeline implements ChannelPipeline { private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) { try { - ctx.handler().handlerAdded(ctx); + try { + ctx.handler().handlerAdded(ctx); + } finally { + // handlerAdded(...) method was called. + ctx.setHandlerAddedCalled(); + } } catch (Throwable t) { boolean removed = false; try { @@ -576,7 +583,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } private void callHandlerRemoved(final AbstractChannelHandlerContext ctx) { - if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) { + if (isExecuteLater(ctx)) { ctx.executor().execute(new OneTimeTask() { @Override public void run() { @@ -599,6 +606,10 @@ final class DefaultChannelPipeline implements ChannelPipeline { } } + private static boolean isExecuteLater(ChannelHandlerContext ctx) { + return ctx.channel().isRegistered() && !ctx.executor().inEventLoop(); + } + /** * Waits for a future to finish. If the task is interrupted, then the current thread will be interrupted. * It is expected that the task performs any appropriate locking.