Ensure ChannelHandler.handlerAdded(...) is always called as first method of the handler

Motivation:

If a user adds a ChannelHandler from outside the EventLoop it is possible to get into the situation that handlerAdded(...) is scheduled on the EventLoop and so called after another methods of the ChannelHandler as the EventLoop may already be executing on this point in time.

Modification:

- Ensure we always check if the handlerAdded(...) method was called already and if not add the currently needed call to the EventLoop so it will be picked up after handlerAdded(...) was called. This works as if the handler is added to the ChannelPipeline from outside the EventLoop the actual handlerAdded(...) operation is scheduled on the EventLoop.
- Some cleanup in the DefaultChannelPipeline

Result:

Correctly order of method executions of ChannelHandler.
This commit is contained in:
Norman Maurer 2016-01-15 14:23:41 +01:00
parent a2732c6542
commit af39cb6b12
2 changed files with 337 additions and 38 deletions

View File

@ -20,6 +20,7 @@ import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.ResourceLeakHint;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.StringUtil;
import java.net.SocketAddress;
@ -33,7 +34,23 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
private final boolean outbound;
private final DefaultChannelPipeline pipeline;
private final String name;
private boolean removed;
private boolean handlerRemoved;
/**
* This is set to {@code true} once the {@link ChannelHandler#handlerAdded(ChannelHandlerContext) method was called.
* We need to keep track of this This will set to true once the
* {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} method was called. We need to keep track of this
* to ensure we will never call another {@link ChannelHandler} method before handlerAdded(...) was called
* to guard againstordering issues. {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} MUST be the first
* method that is called for handler when it becomes a part of the {@link ChannelPipeline} in all cases. Not doing
* so may lead to unexpected side-effects as {@link ChannelHandler} implementationsmay need to do initialization
* steps before a {@link ChannelHandler} can be used.
*
* See <a href="https://github.com/netty/netty/issues/4705">#4705</a>
*
* No need to mark volatile as this will be made visible as next/prev is volatile.
*/
private boolean handlerAdded;
final ChannelHandlerInvoker invoker;
private ChannelFuture succeededFuture;
@ -82,15 +99,6 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
return invoker().executor();
}
@Override
public ChannelHandlerInvoker invoker() {
if (invoker == null) {
return channel().unsafe().invoker();
} else {
return invoker;
}
}
@Override
public String name() {
return name;
@ -109,63 +117,63 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override
public ChannelHandlerContext fireChannelRegistered() {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelRegistered(next);
next.invokeChannelRegistered();
return this;
}
@Override
public ChannelHandlerContext fireChannelUnregistered() {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelUnregistered(next);
next.invokeChannelUnregistered();
return this;
}
@Override
public ChannelHandlerContext fireChannelActive() {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelActive(next);
next.invokeChannelActive();
return this;
}
@Override
public ChannelHandlerContext fireChannelInactive() {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelInactive(next);
next.invokeChannelInactive();
return this;
}
@Override
public ChannelHandlerContext fireExceptionCaught(Throwable cause) {
AbstractChannelHandlerContext next = this.next;
next.invoker().invokeExceptionCaught(next, cause);
next.invokeExceptionCaught(cause);
return this;
}
@Override
public ChannelHandlerContext fireUserEventTriggered(Object event) {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeUserEventTriggered(next, event);
next.invokeUserEventTriggered(event);
return this;
}
@Override
public ChannelHandlerContext fireChannelRead(Object msg) {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelRead(next, pipeline.touch(msg, next));
next.invokeChannelRead(pipeline.touch(msg, next));
return this;
}
@Override
public ChannelHandlerContext fireChannelReadComplete() {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelReadComplete(next);
next.invokeChannelReadComplete();
return this;
}
@Override
public ChannelHandlerContext fireChannelWritabilityChanged() {
AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelWritabilityChanged(next);
next.invokeChannelWritabilityChanged();
return this;
}
@ -202,7 +210,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override
public ChannelFuture bind(final SocketAddress localAddress, final ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeBind(next, localAddress, promise);
next.invokeBind(localAddress, promise);
return promise;
}
@ -214,7 +222,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeConnect(next, remoteAddress, localAddress, promise);
next.invokeConnect(remoteAddress, localAddress, promise);
return promise;
}
@ -225,28 +233,28 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
}
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeDisconnect(next, promise);
next.invokeDisconnect(promise);
return promise;
}
@Override
public ChannelFuture close(ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeClose(next, promise);
next.invokeClose(promise);
return promise;
}
@Override
public ChannelFuture deregister(ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeDeregister(next, promise);
next.invokeDeregister(promise);
return promise;
}
@Override
public ChannelHandlerContext read() {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeRead(next);
next.invokeRead();
return this;
}
@ -258,23 +266,21 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override
public ChannelFuture write(Object msg, ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeWrite(next, pipeline.touch(msg, next), promise);
next.invokeWrite(pipeline.touch(msg, next), promise);
return promise;
}
@Override
public ChannelHandlerContext flush() {
AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeFlush(next);
next.invokeFlush();
return this;
}
@Override
public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound();
ChannelHandlerInvoker invoker = next.invoker();
invoker.invokeWrite(next, pipeline.touch(msg, next) , promise);
invoker.invokeFlush(next);
next.invokeWriteAndFlush(pipeline.touch(msg, next), promise);
return promise;
}
@ -329,12 +335,294 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
}
void setRemoved() {
removed = true;
handlerRemoved = true;
}
@Override
public boolean isRemoved() {
return removed;
return handlerRemoved;
}
final void invokeChannelRegistered() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelRegistered(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelRegistered(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelUnregistered() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelUnregistered(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelUnregistered(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelActive() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelActive(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelActive(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelInactive() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelInactive(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelInactive(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeExceptionCaught(final Throwable cause) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeExceptionCaught(this, cause);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeExceptionCaught(AbstractChannelHandlerContext.this, cause);
}
});
}
}
final void invokeUserEventTriggered(final Object event) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeUserEventTriggered(this, event);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeUserEventTriggered(AbstractChannelHandlerContext.this, event);
}
});
}
}
final void invokeChannelRead(final Object msg) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelRead(this, msg);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelRead(AbstractChannelHandlerContext.this, msg);
}
});
}
}
final void invokeChannelReadComplete() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelReadComplete(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelReadComplete(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelWritabilityChanged() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelWritabilityChanged(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelWritabilityChanged(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeBind(final SocketAddress localAddress, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeBind(this, localAddress, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeBind(AbstractChannelHandlerContext.this, localAddress, promise);
}
});
}
}
final void invokeConnect(final SocketAddress remoteAddress,
final SocketAddress localAddress, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeConnect(this, remoteAddress, localAddress, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeConnect(AbstractChannelHandlerContext.this, remoteAddress, localAddress, promise);
}
});
}
}
final void invokeDisconnect(final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeDisconnect(this, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeDisconnect(AbstractChannelHandlerContext.this, promise);
}
});
}
}
final void invokeClose(final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeClose(this, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeClose(AbstractChannelHandlerContext.this, promise);
}
});
}
}
final void invokeDeregister(final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeDeregister(this, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeDeregister(AbstractChannelHandlerContext.this, promise);
}
});
}
}
final void invokeRead() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeRead(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeRead(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeWrite(final Object msg, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeWrite(this, msg, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise);
}
});
}
}
final void invokeFlush() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeFlush(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeFlush(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeWriteAndFlush(final Object msg, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeWrite(this, msg, promise);
invoker.invokeFlush(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise);
invoker.invokeFlush(AbstractChannelHandlerContext.this);
}
});
}
}
@Override
public ChannelHandlerInvoker invoker() {
return invoker == null ? channel().unsafe().invoker() : invoker;
}
final void setHandlerAddedCalled() {
handlerAdded = true;
}
@Override

View File

@ -74,7 +74,9 @@ final class DefaultChannelPipeline implements ChannelPipeline {
this.channel = channel;
tail = new TailContext(this);
tail.setHandlerAddedCalled();
head = new HeadContext(this);
head.setHandlerAddedCalled();
head.next = tail;
tail.prev = head;
@ -102,7 +104,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
}
@Override
public ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) {
public synchronized ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) {
name = filterName(name, handler);
addFirst0(new DefaultChannelHandlerContext(this, invoker, name, handler));
return this;
@ -392,7 +394,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
Future<?> future;
synchronized (this) {
if (!ctx.channel().isRegistered() || ctx.executor().inEventLoop()) {
if (!isExecuteLater(ctx)) {
remove0(ctx);
return ctx;
} else {
@ -474,7 +476,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final AbstractChannelHandlerContext newCtx =
new DefaultChannelHandlerContext(this, ctx.invoker, newName, newHandler);
if (!newCtx.channel().isRegistered() || newCtx.executor().inEventLoop()) {
if (!isExecuteLater(newCtx)) {
replace0(ctx, newCtx);
return ctx.handler();
} else {
@ -537,7 +539,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
}
private void callHandlerAdded(final AbstractChannelHandlerContext ctx) {
if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) {
if (isExecuteLater(ctx)) {
ctx.executor().execute(new OneTimeTask() {
@Override
public void run() {
@ -551,7 +553,12 @@ final class DefaultChannelPipeline implements ChannelPipeline {
private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) {
try {
ctx.handler().handlerAdded(ctx);
try {
ctx.handler().handlerAdded(ctx);
} finally {
// handlerAdded(...) method was called.
ctx.setHandlerAddedCalled();
}
} catch (Throwable t) {
boolean removed = false;
try {
@ -576,7 +583,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
}
private void callHandlerRemoved(final AbstractChannelHandlerContext ctx) {
if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) {
if (isExecuteLater(ctx)) {
ctx.executor().execute(new OneTimeTask() {
@Override
public void run() {
@ -599,6 +606,10 @@ final class DefaultChannelPipeline implements ChannelPipeline {
}
}
private static boolean isExecuteLater(ChannelHandlerContext ctx) {
return ctx.channel().isRegistered() && !ctx.executor().inEventLoop();
}
/**
* Waits for a future to finish. If the task is interrupted, then the current thread will be interrupted.
* It is expected that the task performs any appropriate locking.