From fe66f33f4219fa1d40fee2e90eecc8ea4d8ae0fb Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Tue, 12 Mar 2013 07:19:31 +0100 Subject: [PATCH] Fix issue where the bytes/messages are forwarded to the wrong handler --- .../channel/DefaultChannelHandlerContext.java | 103 +++- .../netty/channel/DefaultChannelPipeline.java | 9 +- .../channel/DefaultChannelPipelineTest.java | 503 ++++++++++++++++++ 3 files changed, 593 insertions(+), 22 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java index fa10b1caa9..60f52152c9 100755 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java @@ -221,25 +221,94 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements outMsgBuf = null; } - void forwardBufferContent() { + void forwardBufferContent(final DefaultChannelHandlerContext forwardPrev, + final DefaultChannelHandlerContext forwardNext) { + boolean flush = false; + boolean inboundBufferUpdated = false; if (hasOutboundByteBuffer() && outboundByteBuffer().isReadable()) { - nextOutboundByteBuffer().writeBytes(outboundByteBuffer()); - flush(); + ByteBuf forwardPrevBuf; + if (forwardPrev.hasOutboundByteBuffer()) { + forwardPrevBuf = forwardPrev.outboundByteBuffer(); + } else { + forwardPrevBuf = forwardPrev.nextOutboundByteBuffer(); + } + forwardPrevBuf.writeBytes(outboundByteBuffer()); + flush = true; } if (hasOutboundMessageBuffer() && !outboundMessageBuffer().isEmpty()) { - if (outboundMessageBuffer().drainTo(nextOutboundMessageBuffer()) > 0) { - flush(); + MessageBuf forwardPrevBuf; + if (forwardPrev.hasOutboundMessageBuffer()) { + forwardPrevBuf = forwardPrev.outboundMessageBuffer(); + } else { + forwardPrevBuf = forwardPrev.nextOutboundMessageBuffer(); + } + if (outboundMessageBuffer().drainTo(forwardPrevBuf) > 0) { + flush = true; } } if (hasInboundByteBuffer() && inboundByteBuffer().isReadable()) { - nextInboundByteBuffer().writeBytes(inboundByteBuffer()); - fireInboundBufferUpdated(); + ByteBuf forwardNextBuf; + if (forwardNext.hasInboundByteBuffer()) { + forwardNextBuf = forwardNext.inboundByteBuffer(); + } else { + forwardNextBuf = forwardNext.nextInboundByteBuffer(); + } + forwardNextBuf.writeBytes(inboundByteBuffer()); + inboundBufferUpdated = true; } if (hasInboundMessageBuffer() && !inboundMessageBuffer().isEmpty()) { - if (inboundMessageBuffer().drainTo(nextInboundMessageBuffer()) > 0) { - fireInboundBufferUpdated(); + MessageBuf forwardNextBuf; + if (forwardNext.hasInboundMessageBuffer()) { + forwardNextBuf = forwardNext.inboundMessageBuffer(); + } else { + forwardNextBuf = forwardNext.nextInboundMessageBuffer(); + } + if (inboundMessageBuffer().drainTo(forwardNextBuf) > 0) { + inboundBufferUpdated = true; } } + if (flush) { + EventExecutor executor = executor(); + Thread currentThread = Thread.currentThread(); + if (executor.inEventLoop(currentThread)) { + invokePrevFlush(newPromise(), currentThread, findContextOutboundInclusive(forwardPrev)); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + invokePrevFlush(newPromise(), Thread.currentThread(), + findContextOutboundInclusive(forwardPrev)); + } + }); + } + } + if (inboundBufferUpdated) { + EventExecutor executor = executor(); + if (executor.inEventLoop()) { + fireInboundBufferUpdated0(findContextInboundInclusive(forwardNext)); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + fireInboundBufferUpdated0(findContextInboundInclusive(forwardNext)); + } + }); + } + } + } + + private static DefaultChannelHandlerContext findContextOutboundInclusive(DefaultChannelHandlerContext ctx) { + if (ctx.handler() instanceof ChannelOperationHandler) { + return ctx; + } + return ctx.findContextOutbound(); + } + + private static DefaultChannelHandlerContext findContextInboundInclusive(DefaultChannelHandlerContext ctx) { + if (ctx.handler() instanceof ChannelStateHandler) { + return ctx; + } + return ctx.findContextInbound(); } void clearBuffer() { @@ -889,14 +958,14 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements public ChannelHandlerContext fireInboundBufferUpdated() { EventExecutor executor = executor(); if (executor.inEventLoop()) { - fireInboundBufferUpdated0(); + fireInboundBufferUpdated0(findContextInbound()); } else { Runnable task = fireInboundBufferUpdated0Task; if (task == null) { fireInboundBufferUpdated0Task = task = new Runnable() { @Override public void run() { - fireInboundBufferUpdated0(); + fireInboundBufferUpdated0(findContextInbound()); } }; } @@ -905,8 +974,7 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements return this; } - private void fireInboundBufferUpdated0() { - final DefaultChannelHandlerContext next = findContextInbound(); + private void fireInboundBufferUpdated0(final DefaultChannelHandlerContext next) { if (!pipeline.isInboundShutdown()) { next.fillInboundBridge(); // This comparison is safe because this method is always executed from the executor. @@ -926,7 +994,7 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements next.invokeInboundBufferUpdated(); } else { // Pipeline changed since the task was submitted; try again. - fireInboundBufferUpdated0(); + fireInboundBufferUpdated0(next); } } }; @@ -1265,12 +1333,12 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements EventExecutor executor = executor(); Thread currentThread = Thread.currentThread(); if (executor.inEventLoop(currentThread)) { - invokePrevFlush(promise, currentThread); + invokePrevFlush(promise, currentThread, findContextOutbound()); } else { executor.execute(new Runnable() { @Override public void run() { - invokePrevFlush(promise, Thread.currentThread()); + invokePrevFlush(promise, Thread.currentThread(), findContextOutbound()); } }); } @@ -1278,8 +1346,7 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements return promise; } - private void invokePrevFlush(ChannelPromise promise, Thread currentThread) { - DefaultChannelHandlerContext prev = findContextOutbound(); + private void invokePrevFlush(ChannelPromise promise, Thread currentThread, DefaultChannelHandlerContext prev) { if (pipeline.isOutboundShutdown()) { promise.setFailure(new ChannelPipelineException( "Unable to flush as outbound buffer of next handler was freed already")); diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index 8dd090d311..4faed5641b 100755 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -473,7 +473,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { next.prev = prev; name2ctx.remove(ctx.name()); - callAfterRemove(ctx, forward); + callAfterRemove(ctx, prev, next, forward); } @Override @@ -592,7 +592,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { ChannelPipelineException addException = null; boolean removed = false; try { - callAfterRemove(ctx, forward); + callAfterRemove(ctx, newCtx, newCtx, forward); removed = true; } catch (ChannelPipelineException e) { removeException = e; @@ -676,7 +676,8 @@ final class DefaultChannelPipeline implements ChannelPipeline { } } - private void callAfterRemove(final DefaultChannelHandlerContext ctx, boolean forward) { + private void callAfterRemove(final DefaultChannelHandlerContext ctx, DefaultChannelHandlerContext ctxPrev, + DefaultChannelHandlerContext ctxNext, boolean forward) { final ChannelHandler handler = ctx.handler(); // Notify the complete removal. @@ -689,7 +690,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { } if (forward) { - ctx.forwardBufferContent(); + ctx.forwardBufferContent(ctxPrev, ctxNext); } else { ctx.clearBuffer(); } diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index 0f4f8dbf5e..0ccd27de53 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -17,12 +17,15 @@ package io.netty.channel; import io.netty.buffer.ByteBuf; +import io.netty.buffer.MessageBuf; import io.netty.buffer.ReferenceCounted; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalEventLoopGroup; import org.junit.Test; +import java.net.SocketAddress; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -306,6 +309,382 @@ public class DefaultChannelPipelineTest { verifyContextNumber(pipeline, 8); } + @Test + public void testRemoveAndForwardInboundByte() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelInboundByteHandlerImpl handler1 = new ChannelInboundByteHandlerImpl(); + final ChannelInboundByteHandlerImpl handler2 = new ChannelInboundByteHandlerImpl(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).inboundByteBuffer().writeLong(8); + assertEquals(8, pipeline.context(handler1).inboundByteBuffer().readableBytes()); + assertEquals(0, pipeline.context(handler2).inboundByteBuffer().readableBytes()); + pipeline.removeAndForward(handler1); + assertEquals(8, pipeline.context(handler2).inboundByteBuffer().readableBytes()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler2.updated); + } + + @Test + public void testReplaceAndForwardInboundByte() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelInboundByteHandlerImpl handler1 = new ChannelInboundByteHandlerImpl(); + final ChannelInboundByteHandlerImpl handler2 = new ChannelInboundByteHandlerImpl(); + pipeline.addLast("handler1", handler1); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).inboundByteBuffer().writeLong(8); + assertEquals(8, pipeline.context(handler1).inboundByteBuffer().readableBytes()); + pipeline.replaceAndForward(handler1, "handler2", handler2); + assertEquals(8, pipeline.context(handler2).inboundByteBuffer().readableBytes()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler2.updated); + + } + + @Test + public void testRemoveAndForwardOutboundByte() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelOutboundByteHandlerImpl handler1 = new ChannelOutboundByteHandlerImpl(); + final ChannelOutboundByteHandlerImpl handler2 = new ChannelOutboundByteHandlerImpl(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler2).outboundByteBuffer().writeLong(8); + assertEquals(8, pipeline.context(handler2).outboundByteBuffer().readableBytes()); + assertEquals(0, pipeline.context(handler1).outboundByteBuffer().readableBytes()); + pipeline.removeAndForward(handler2); + assertEquals(8, pipeline.context(handler1).outboundByteBuffer().readableBytes()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler1.flushed); + + } + + @Test + public void testReplaceAndForwardOutboundByte() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelOutboundByteHandlerImpl handler1 = new ChannelOutboundByteHandlerImpl(); + final ChannelOutboundByteHandlerImpl handler2 = new ChannelOutboundByteHandlerImpl(); + pipeline.addLast("handler1", handler1); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).outboundByteBuffer().writeLong(8); + assertEquals(8, pipeline.context(handler1).outboundByteBuffer().readableBytes()); + pipeline.replaceAndForward(handler1, "handler2", handler2); + assertEquals(8, pipeline.context(handler2).outboundByteBuffer().readableBytes()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler2.flushed); + } + + @Test + public void testReplaceAndForwardDuplexByte() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ByteHandlerImpl handler1 = new ByteHandlerImpl(); + final ByteHandlerImpl handler2 = new ByteHandlerImpl(); + pipeline.addLast("handler1", handler1); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).outboundByteBuffer().writeLong(8); + pipeline.context(handler1).inboundByteBuffer().writeLong(8); + + assertEquals(8, pipeline.context(handler1).outboundByteBuffer().readableBytes()); + assertEquals(8, pipeline.context(handler1).inboundByteBuffer().readableBytes()); + + pipeline.replaceAndForward(handler1, "handler2", handler2); + assertEquals(8, pipeline.context(handler2).outboundByteBuffer().readableBytes()); + assertEquals(8, pipeline.context(handler2).inboundByteBuffer().readableBytes()); + + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(((ChannelInboundByteHandlerImpl)handler2.stateHandler()).updated); + assertTrue(((ChannelOutboundByteHandlerImpl)handler2.operationHandler()).flushed); + } + + @Test + public void testRemoveAndForwardDuplexByte() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelOutboundByteHandlerImpl handler1 = new ChannelOutboundByteHandlerImpl(); + final ByteHandlerImpl handler2 = new ByteHandlerImpl(); + final ChannelInboundByteHandlerImpl handler3 = new ChannelInboundByteHandlerImpl(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + pipeline.addLast("handler3", handler3); + + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler2).outboundByteBuffer().writeLong(8); + pipeline.context(handler2).inboundByteBuffer().writeLong(8); + + assertEquals(8, pipeline.context(handler2).outboundByteBuffer().readableBytes()); + assertEquals(8, pipeline.context(handler2).inboundByteBuffer().readableBytes()); + + assertEquals(0, pipeline.context(handler1).outboundByteBuffer().readableBytes()); + assertEquals(0, pipeline.context(handler3).inboundByteBuffer().readableBytes()); + + pipeline.removeAndForward(handler2); + assertEquals(8, pipeline.context(handler1).outboundByteBuffer().readableBytes()); + assertEquals(8, pipeline.context(handler3).inboundByteBuffer().readableBytes()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler1.flushed); + assertTrue(handler3.updated); + + } + + @Test + public void testRemoveAndForwardInboundMessage() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelInboundMessageHandlerImpl handler1 = new ChannelInboundMessageHandlerImpl(); + final ChannelInboundMessageHandlerImpl handler2 = new ChannelInboundMessageHandlerImpl(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).inboundMessageBuffer().add(new Object()); + assertEquals(1, pipeline.context(handler1).inboundMessageBuffer().size()); + assertEquals(0, pipeline.context(handler2).inboundMessageBuffer().size()); + pipeline.removeAndForward(handler1); + assertEquals(1, pipeline.context(handler2).inboundMessageBuffer().size()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler2.updated); + } + + @Test + public void testReplaceAndForwardInboundMessage() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelInboundMessageHandlerImpl handler1 = new ChannelInboundMessageHandlerImpl(); + final ChannelInboundMessageHandlerImpl handler2 = new ChannelInboundMessageHandlerImpl(); + pipeline.addLast("handler1", handler1); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).inboundMessageBuffer().add(new Object()); + assertEquals(1, pipeline.context(handler1).inboundMessageBuffer().size()); + pipeline.replaceAndForward(handler1, "handler2", handler2); + assertEquals(1, pipeline.context(handler2).inboundMessageBuffer().size()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler2.updated); + } + + @Test + public void testRemoveAndForwardOutboundMessage() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelOutboundMessageHandlerImpl handler1 = new ChannelOutboundMessageHandlerImpl(); + final ChannelOutboundMessageHandlerImpl handler2 = new ChannelOutboundMessageHandlerImpl(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler2).outboundMessageBuffer().add(new Object()); + assertEquals(1, pipeline.context(handler2).outboundMessageBuffer().size()); + assertEquals(0, pipeline.context(handler1).outboundMessageBuffer().size()); + pipeline.removeAndForward(handler2); + assertEquals(1, pipeline.context(handler1).outboundMessageBuffer().size()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler1.flushed); + } + + @Test + public void testReplaceAndForwardOutboundMessage() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelOutboundMessageHandlerImpl handler1 = new ChannelOutboundMessageHandlerImpl(); + final ChannelOutboundMessageHandlerImpl handler2 = new ChannelOutboundMessageHandlerImpl(); + pipeline.addLast("handler1", handler1); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).outboundMessageBuffer().add(new Object()); + assertEquals(1, pipeline.context(handler1).outboundMessageBuffer().size()); + pipeline.replaceAndForward(handler1, "handler2", handler2); + assertEquals(1, pipeline.context(handler2).outboundMessageBuffer().size()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler2.flushed); + } + + @Test + public void testReplaceAndForwardDuplexMessage() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final MessageHandlerImpl handler1 = new MessageHandlerImpl(); + final MessageHandlerImpl handler2 = new MessageHandlerImpl(); + pipeline.addLast("handler1", handler1); + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler1).outboundMessageBuffer().add(new Object()); + pipeline.context(handler1).inboundMessageBuffer().add(new Object()); + + assertEquals(1, pipeline.context(handler1).outboundMessageBuffer().size()); + assertEquals(1, pipeline.context(handler1).inboundMessageBuffer().size()); + + pipeline.replaceAndForward(handler1, "handler2", handler2); + assertEquals(1, pipeline.context(handler2).outboundMessageBuffer().size()); + assertEquals(1, pipeline.context(handler2).inboundMessageBuffer().size()); + + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(((ChannelInboundMessageHandlerImpl)handler2.stateHandler()).updated); + assertTrue(((ChannelOutboundMessageHandlerImpl)handler2.operationHandler()).flushed); + + } + + @Test + public void testRemoveAndForwardDuplexMessage() throws Exception { + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + + final ChannelOutboundMessageHandlerImpl handler1 = new ChannelOutboundMessageHandlerImpl(); + final MessageHandlerImpl handler2 = new MessageHandlerImpl(); + final ChannelInboundMessageHandlerImpl handler3 = new ChannelInboundMessageHandlerImpl(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + pipeline.addLast("handler3", handler3); + + final CountDownLatch latch = new CountDownLatch(1); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.context(handler2).outboundMessageBuffer().add(new Object()); + pipeline.context(handler2).inboundMessageBuffer().add(new Object()); + + assertEquals(1, pipeline.context(handler2).outboundMessageBuffer().size()); + assertEquals(1, pipeline.context(handler2).inboundMessageBuffer().size()); + + assertEquals(0, pipeline.context(handler1).outboundMessageBuffer().size()); + assertEquals(0, pipeline.context(handler3).inboundMessageBuffer().size()); + + pipeline.removeAndForward(handler2); + assertEquals(1, pipeline.context(handler1).outboundMessageBuffer().size()); + assertEquals(1, pipeline.context(handler3).inboundMessageBuffer().size()); + latch.countDown(); + } + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(handler1.flushed); + assertTrue(handler3.updated); + + } private static int next(DefaultChannelHandlerContext ctx) { DefaultChannelHandlerContext next = ctx.next; if (next == null) { @@ -360,4 +739,128 @@ public class DefaultChannelPipelineTest { ctx.fireInboundBufferUpdated(); } } + + private static final class ChannelInboundByteHandlerImpl extends ChannelInboundByteHandlerAdapter { + boolean updated; + + @Override + protected void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) throws Exception { + updated = true; + } + } + + private static final class ChannelOutboundByteHandlerImpl extends ChannelOutboundByteHandlerAdapter { + boolean flushed; + + @Override + protected void flush(ChannelHandlerContext ctx, ByteBuf in, ChannelPromise promise) throws Exception { + promise.setSuccess(); + flushed = true; + } + } + + private static final class ChannelInboundMessageHandlerImpl extends ChannelStateHandlerAdapter implements ChannelInboundMessageHandler { + boolean updated; + @Override + public MessageBuf newInboundBuffer(ChannelHandlerContext ctx) throws Exception { + return Unpooled.messageBuffer(); + } + + @Override + public void freeInboundBuffer(ChannelHandlerContext ctx) throws Exception { + ctx.inboundMessageBuffer().release(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx) throws Exception { + updated = true; + } + } + + private static final class ChannelOutboundMessageHandlerImpl extends ChannelOperationHandlerAdapter implements ChannelOutboundMessageHandler { + boolean flushed; + @Override + public MessageBuf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception { + return Unpooled.messageBuffer(); + } + + @Override + public void freeOutboundBuffer(ChannelHandlerContext ctx) throws Exception { + ctx.outboundMessageBuffer().release(); + } + + @Override + public void flush(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + promise.setSuccess(); + flushed = true; + } + } + + private static final class ByteHandlerImpl extends CombinedChannelDuplexHandler implements ChannelInboundByteHandler, ChannelOutboundByteHandler{ + ByteHandlerImpl() { + super(new ChannelInboundByteHandlerImpl(), new ChannelOutboundByteHandlerImpl()); + } + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) throws Exception { + return ((ChannelInboundByteHandler) stateHandler()).newInboundBuffer(ctx); + } + + @Override + public void discardInboundReadBytes(ChannelHandlerContext ctx) throws Exception { + ((ChannelInboundByteHandler) stateHandler()).discardInboundReadBytes(ctx); + } + + @Override + public void freeInboundBuffer(ChannelHandlerContext ctx) throws Exception { + ((ChannelInboundHandler) stateHandler()).freeInboundBuffer(ctx); + } + + @Override + public ByteBuf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception { + return ((ChannelOutboundByteHandler) operationHandler()).newOutboundBuffer(ctx); + } + + @Override + public void discardOutboundReadBytes(ChannelHandlerContext ctx) throws Exception { + ((ChannelOutboundByteHandler) operationHandler()).discardOutboundReadBytes(ctx); + } + + @Override + public void freeOutboundBuffer(ChannelHandlerContext ctx) throws Exception { + ((ChannelOutboundHandler) operationHandler()).freeOutboundBuffer(ctx); + } + } + + private static final class MessageHandlerImpl extends CombinedChannelDuplexHandler + implements ChannelInboundMessageHandler, ChannelOutboundMessageHandler { + MessageHandlerImpl() { + super(new ChannelInboundMessageHandlerImpl(), new ChannelOutboundMessageHandlerImpl()); + } + + @SuppressWarnings("unchecked") + @Override + public MessageBuf newInboundBuffer(ChannelHandlerContext ctx) throws Exception { + return ((ChannelInboundMessageHandler) stateHandler()).newInboundBuffer(ctx); + } + + @SuppressWarnings("unchecked") + @Override + public void freeInboundBuffer(ChannelHandlerContext ctx) throws Exception { + ((ChannelInboundHandler) stateHandler()).freeInboundBuffer(ctx); + } + + @SuppressWarnings("unchecked") + @Override + public MessageBuf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception { + return ((ChannelOutboundMessageHandler) operationHandler()).newOutboundBuffer(ctx); + } + + + @SuppressWarnings("unchecked") + @Override + public void freeOutboundBuffer(ChannelHandlerContext ctx) throws Exception { + ((ChannelOutboundHandler) operationHandler()).freeOutboundBuffer(ctx); + } + } }