Ensure ChannelHandler.handlerAdded(...) is always called as first method of the handler

Motivation:

If a user adds a ChannelHandler from outside the EventLoop it is possible to get into the situation that handlerAdded(...) is scheduled on the EventLoop and so called after another methods of the ChannelHandler as the EventLoop may already be executing on this point in time.

Modification:

- Ensure we always check if the handlerAdded(...) method was called already and if not add the currently needed call to the EventLoop so it will be picked up after handlerAdded(...) was called. This works as if the handler is added to the ChannelPipeline from outside the EventLoop the actual handlerAdded(...) operation is scheduled on the EventLoop.
- Some cleanup in the DefaultChannelPipeline

Result:

Correctly order of method executions of ChannelHandler.
This commit is contained in:
Norman Maurer 2016-01-15 14:23:41 +01:00
parent a2732c6542
commit af39cb6b12
2 changed files with 337 additions and 38 deletions

View File

@ -20,6 +20,7 @@ import io.netty.util.Attribute;
import io.netty.util.AttributeKey; import io.netty.util.AttributeKey;
import io.netty.util.ResourceLeakHint; import io.netty.util.ResourceLeakHint;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -33,7 +34,23 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
private final boolean outbound; private final boolean outbound;
private final DefaultChannelPipeline pipeline; private final DefaultChannelPipeline pipeline;
private final String name; private final String name;
private boolean removed; private boolean handlerRemoved;
/**
* This is set to {@code true} once the {@link ChannelHandler#handlerAdded(ChannelHandlerContext) method was called.
* We need to keep track of this This will set to true once the
* {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} method was called. We need to keep track of this
* to ensure we will never call another {@link ChannelHandler} method before handlerAdded(...) was called
* to guard againstordering issues. {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} MUST be the first
* method that is called for handler when it becomes a part of the {@link ChannelPipeline} in all cases. Not doing
* so may lead to unexpected side-effects as {@link ChannelHandler} implementationsmay need to do initialization
* steps before a {@link ChannelHandler} can be used.
*
* See <a href="https://github.com/netty/netty/issues/4705">#4705</a>
*
* No need to mark volatile as this will be made visible as next/prev is volatile.
*/
private boolean handlerAdded;
final ChannelHandlerInvoker invoker; final ChannelHandlerInvoker invoker;
private ChannelFuture succeededFuture; private ChannelFuture succeededFuture;
@ -82,15 +99,6 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
return invoker().executor(); return invoker().executor();
} }
@Override
public ChannelHandlerInvoker invoker() {
if (invoker == null) {
return channel().unsafe().invoker();
} else {
return invoker;
}
}
@Override @Override
public String name() { public String name() {
return name; return name;
@ -109,63 +117,63 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override @Override
public ChannelHandlerContext fireChannelRegistered() { public ChannelHandlerContext fireChannelRegistered() {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelRegistered(next); next.invokeChannelRegistered();
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireChannelUnregistered() { public ChannelHandlerContext fireChannelUnregistered() {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelUnregistered(next); next.invokeChannelUnregistered();
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireChannelActive() { public ChannelHandlerContext fireChannelActive() {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelActive(next); next.invokeChannelActive();
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireChannelInactive() { public ChannelHandlerContext fireChannelInactive() {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelInactive(next); next.invokeChannelInactive();
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireExceptionCaught(Throwable cause) { public ChannelHandlerContext fireExceptionCaught(Throwable cause) {
AbstractChannelHandlerContext next = this.next; AbstractChannelHandlerContext next = this.next;
next.invoker().invokeExceptionCaught(next, cause); next.invokeExceptionCaught(cause);
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireUserEventTriggered(Object event) { public ChannelHandlerContext fireUserEventTriggered(Object event) {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeUserEventTriggered(next, event); next.invokeUserEventTriggered(event);
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireChannelRead(Object msg) { public ChannelHandlerContext fireChannelRead(Object msg) {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelRead(next, pipeline.touch(msg, next)); next.invokeChannelRead(pipeline.touch(msg, next));
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireChannelReadComplete() { public ChannelHandlerContext fireChannelReadComplete() {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelReadComplete(next); next.invokeChannelReadComplete();
return this; return this;
} }
@Override @Override
public ChannelHandlerContext fireChannelWritabilityChanged() { public ChannelHandlerContext fireChannelWritabilityChanged() {
AbstractChannelHandlerContext next = findContextInbound(); AbstractChannelHandlerContext next = findContextInbound();
next.invoker().invokeChannelWritabilityChanged(next); next.invokeChannelWritabilityChanged();
return this; return this;
} }
@ -202,7 +210,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override @Override
public ChannelFuture bind(final SocketAddress localAddress, final ChannelPromise promise) { public ChannelFuture bind(final SocketAddress localAddress, final ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeBind(next, localAddress, promise); next.invokeBind(localAddress, promise);
return promise; return promise;
} }
@ -214,7 +222,7 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override @Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeConnect(next, remoteAddress, localAddress, promise); next.invokeConnect(remoteAddress, localAddress, promise);
return promise; return promise;
} }
@ -225,28 +233,28 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
} }
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeDisconnect(next, promise); next.invokeDisconnect(promise);
return promise; return promise;
} }
@Override @Override
public ChannelFuture close(ChannelPromise promise) { public ChannelFuture close(ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeClose(next, promise); next.invokeClose(promise);
return promise; return promise;
} }
@Override @Override
public ChannelFuture deregister(ChannelPromise promise) { public ChannelFuture deregister(ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeDeregister(next, promise); next.invokeDeregister(promise);
return promise; return promise;
} }
@Override @Override
public ChannelHandlerContext read() { public ChannelHandlerContext read() {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeRead(next); next.invokeRead();
return this; return this;
} }
@ -258,23 +266,21 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
@Override @Override
public ChannelFuture write(Object msg, ChannelPromise promise) { public ChannelFuture write(Object msg, ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeWrite(next, pipeline.touch(msg, next), promise); next.invokeWrite(pipeline.touch(msg, next), promise);
return promise; return promise;
} }
@Override @Override
public ChannelHandlerContext flush() { public ChannelHandlerContext flush() {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
next.invoker().invokeFlush(next); next.invokeFlush();
return this; return this;
} }
@Override @Override
public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
AbstractChannelHandlerContext next = findContextOutbound(); AbstractChannelHandlerContext next = findContextOutbound();
ChannelHandlerInvoker invoker = next.invoker(); next.invokeWriteAndFlush(pipeline.touch(msg, next), promise);
invoker.invokeWrite(next, pipeline.touch(msg, next) , promise);
invoker.invokeFlush(next);
return promise; return promise;
} }
@ -329,12 +335,294 @@ abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, R
} }
void setRemoved() { void setRemoved() {
removed = true; handlerRemoved = true;
} }
@Override @Override
public boolean isRemoved() { public boolean isRemoved() {
return removed; return handlerRemoved;
}
final void invokeChannelRegistered() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelRegistered(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelRegistered(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelUnregistered() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelUnregistered(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelUnregistered(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelActive() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelActive(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelActive(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelInactive() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelInactive(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelInactive(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeExceptionCaught(final Throwable cause) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeExceptionCaught(this, cause);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeExceptionCaught(AbstractChannelHandlerContext.this, cause);
}
});
}
}
final void invokeUserEventTriggered(final Object event) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeUserEventTriggered(this, event);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeUserEventTriggered(AbstractChannelHandlerContext.this, event);
}
});
}
}
final void invokeChannelRead(final Object msg) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelRead(this, msg);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelRead(AbstractChannelHandlerContext.this, msg);
}
});
}
}
final void invokeChannelReadComplete() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelReadComplete(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelReadComplete(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeChannelWritabilityChanged() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeChannelWritabilityChanged(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeChannelWritabilityChanged(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeBind(final SocketAddress localAddress, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeBind(this, localAddress, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeBind(AbstractChannelHandlerContext.this, localAddress, promise);
}
});
}
}
final void invokeConnect(final SocketAddress remoteAddress,
final SocketAddress localAddress, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeConnect(this, remoteAddress, localAddress, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeConnect(AbstractChannelHandlerContext.this, remoteAddress, localAddress, promise);
}
});
}
}
final void invokeDisconnect(final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeDisconnect(this, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeDisconnect(AbstractChannelHandlerContext.this, promise);
}
});
}
}
final void invokeClose(final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeClose(this, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeClose(AbstractChannelHandlerContext.this, promise);
}
});
}
}
final void invokeDeregister(final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeDeregister(this, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeDeregister(AbstractChannelHandlerContext.this, promise);
}
});
}
}
final void invokeRead() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeRead(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeRead(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeWrite(final Object msg, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeWrite(this, msg, promise);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise);
}
});
}
}
final void invokeFlush() {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeFlush(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeFlush(AbstractChannelHandlerContext.this);
}
});
}
}
final void invokeWriteAndFlush(final Object msg, final ChannelPromise promise) {
final ChannelHandlerInvoker invoker = invoker();
if (handlerAdded) {
invoker.invokeWrite(this, msg, promise);
invoker.invokeFlush(this);
} else {
invoker.executor().execute(new OneTimeTask() {
@Override
public void run() {
assert handlerAdded;
invoker.invokeWrite(AbstractChannelHandlerContext.this, msg, promise);
invoker.invokeFlush(AbstractChannelHandlerContext.this);
}
});
}
}
@Override
public ChannelHandlerInvoker invoker() {
return invoker == null ? channel().unsafe().invoker() : invoker;
}
final void setHandlerAddedCalled() {
handlerAdded = true;
} }
@Override @Override

View File

@ -74,7 +74,9 @@ final class DefaultChannelPipeline implements ChannelPipeline {
this.channel = channel; this.channel = channel;
tail = new TailContext(this); tail = new TailContext(this);
tail.setHandlerAddedCalled();
head = new HeadContext(this); head = new HeadContext(this);
head.setHandlerAddedCalled();
head.next = tail; head.next = tail;
tail.prev = head; tail.prev = head;
@ -102,7 +104,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
} }
@Override @Override
public ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) { public synchronized ChannelPipeline addFirst(ChannelHandlerInvoker invoker, String name, ChannelHandler handler) {
name = filterName(name, handler); name = filterName(name, handler);
addFirst0(new DefaultChannelHandlerContext(this, invoker, name, handler)); addFirst0(new DefaultChannelHandlerContext(this, invoker, name, handler));
return this; return this;
@ -392,7 +394,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
Future<?> future; Future<?> future;
synchronized (this) { synchronized (this) {
if (!ctx.channel().isRegistered() || ctx.executor().inEventLoop()) { if (!isExecuteLater(ctx)) {
remove0(ctx); remove0(ctx);
return ctx; return ctx;
} else { } else {
@ -474,7 +476,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final AbstractChannelHandlerContext newCtx = final AbstractChannelHandlerContext newCtx =
new DefaultChannelHandlerContext(this, ctx.invoker, newName, newHandler); new DefaultChannelHandlerContext(this, ctx.invoker, newName, newHandler);
if (!newCtx.channel().isRegistered() || newCtx.executor().inEventLoop()) { if (!isExecuteLater(newCtx)) {
replace0(ctx, newCtx); replace0(ctx, newCtx);
return ctx.handler(); return ctx.handler();
} else { } else {
@ -537,7 +539,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
} }
private void callHandlerAdded(final AbstractChannelHandlerContext ctx) { private void callHandlerAdded(final AbstractChannelHandlerContext ctx) {
if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) { if (isExecuteLater(ctx)) {
ctx.executor().execute(new OneTimeTask() { ctx.executor().execute(new OneTimeTask() {
@Override @Override
public void run() { public void run() {
@ -551,7 +553,12 @@ final class DefaultChannelPipeline implements ChannelPipeline {
private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) { private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) {
try { try {
ctx.handler().handlerAdded(ctx); try {
ctx.handler().handlerAdded(ctx);
} finally {
// handlerAdded(...) method was called.
ctx.setHandlerAddedCalled();
}
} catch (Throwable t) { } catch (Throwable t) {
boolean removed = false; boolean removed = false;
try { try {
@ -576,7 +583,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
} }
private void callHandlerRemoved(final AbstractChannelHandlerContext ctx) { private void callHandlerRemoved(final AbstractChannelHandlerContext ctx) {
if (ctx.channel().isRegistered() && !ctx.executor().inEventLoop()) { if (isExecuteLater(ctx)) {
ctx.executor().execute(new OneTimeTask() { ctx.executor().execute(new OneTimeTask() {
@Override @Override
public void run() { public void run() {
@ -599,6 +606,10 @@ final class DefaultChannelPipeline implements ChannelPipeline {
} }
} }
private static boolean isExecuteLater(ChannelHandlerContext ctx) {
return ctx.channel().isRegistered() && !ctx.executor().inEventLoop();
}
/** /**
* Waits for a future to finish. If the task is interrupted, then the current thread will be interrupted. * Waits for a future to finish. If the task is interrupted, then the current thread will be interrupted.
* It is expected that the task performs any appropriate locking. * It is expected that the task performs any appropriate locking.