Correctly run all pending tasks for EmbeddedChannel when the Channel was closed.

Motivation:

When a user called ctx.close() and used the EmbeddedChannel we did not correctly run all pending tasks which means channelInactive was never called.

Modifications:

Ensure we run all pending tasks after all operations that may change the Channel state and are part of the Channel.Unsafe impl.

Result:

Fixes [#6894].
This commit is contained in:
Norman Maurer 2017-06-23 14:34:57 +02:00
parent d9d3d65716
commit 8adb30bbe2
2 changed files with 119 additions and 3 deletions

View File

@ -35,6 +35,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelConfig; import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.DefaultChannelPipeline; import io.netty.channel.DefaultChannelPipeline;
import io.netty.channel.EventLoop; import io.netty.channel.EventLoop;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
@ -701,7 +702,12 @@ public class EmbeddedChannel extends AbstractChannel {
@Override @Override
protected AbstractUnsafe newUnsafe() { protected AbstractUnsafe newUnsafe() {
return new DefaultUnsafe(); return new EmbeddedUnsafe();
}
@Override
public Unsafe unsafe() {
return ((EmbeddedUnsafe) super.unsafe()).wrapped;
} }
@Override @Override
@ -734,7 +740,97 @@ public class EmbeddedChannel extends AbstractChannel {
inboundMessages().add(msg); inboundMessages().add(msg);
} }
private class DefaultUnsafe extends AbstractUnsafe { private final class EmbeddedUnsafe extends AbstractUnsafe {
// Delegates to the EmbeddedUnsafe instance but ensures runPendingTasks() is called after each operation
// that may change the state of the Channel and may schedule tasks for later execution.
final Unsafe wrapped = new Unsafe() {
@Override
public RecvByteBufAllocator.Handle recvBufAllocHandle() {
return EmbeddedUnsafe.this.recvBufAllocHandle();
}
@Override
public SocketAddress localAddress() {
return EmbeddedUnsafe.this.localAddress();
}
@Override
public SocketAddress remoteAddress() {
return EmbeddedUnsafe.this.remoteAddress();
}
@Override
public void register(EventLoop eventLoop, ChannelPromise promise) {
EmbeddedUnsafe.this.register(eventLoop, promise);
runPendingTasks();
}
@Override
public void bind(SocketAddress localAddress, ChannelPromise promise) {
EmbeddedUnsafe.this.bind(localAddress, promise);
runPendingTasks();
}
@Override
public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
runPendingTasks();
}
@Override
public void disconnect(ChannelPromise promise) {
EmbeddedUnsafe.this.disconnect(promise);
runPendingTasks();
}
@Override
public void close(ChannelPromise promise) {
EmbeddedUnsafe.this.close(promise);
runPendingTasks();
}
@Override
public void closeForcibly() {
EmbeddedUnsafe.this.closeForcibly();
runPendingTasks();
}
@Override
public void deregister(ChannelPromise promise) {
EmbeddedUnsafe.this.deregister(promise);
runPendingTasks();
}
@Override
public void beginRead() {
EmbeddedUnsafe.this.beginRead();
runPendingTasks();
}
@Override
public void write(Object msg, ChannelPromise promise) {
EmbeddedUnsafe.this.write(msg, promise);
runPendingTasks();
}
@Override
public void flush() {
EmbeddedUnsafe.this.flush();
runPendingTasks();
}
@Override
public ChannelPromise voidPromise() {
return EmbeddedUnsafe.this.voidPromise();
}
@Override
public ChannelOutboundBuffer outboundBuffer() {
return EmbeddedUnsafe.this.outboundBuffer();
}
};
@Override @Override
public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
safeSetSuccess(promise); safeSetSuccess(promise);
@ -742,7 +838,7 @@ public class EmbeddedChannel extends AbstractChannel {
} }
private final class EmbeddedChannelPipeline extends DefaultChannelPipeline { private final class EmbeddedChannelPipeline extends DefaultChannelPipeline {
public EmbeddedChannelPipeline(EmbeddedChannel channel) { EmbeddedChannelPipeline(EmbeddedChannel channel) {
super(channel); super(channel);
} }

View File

@ -27,6 +27,7 @@ import java.util.ArrayDeque;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -563,6 +564,25 @@ public class EmbeddedChannelTest {
} }
} }
@Test(timeout = 5000)
public void testChannelInactiveFired() throws InterruptedException {
final AtomicBoolean inactive = new AtomicBoolean();
EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.close();
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
inactive.set(true);
}
});
channel.pipeline().fireExceptionCaught(new IllegalStateException());
assertTrue(inactive.get());
}
private static void release(ByteBuf... buffers) { private static void release(ByteBuf... buffers) {
for (ByteBuf buffer : buffers) { for (ByteBuf buffer : buffers) {
if (buffer.refCnt() > 0) { if (buffer.refCnt() > 0) {