From c1afe3d8c39badbd51fda030a5c17aa9b62cac7d Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Sun, 3 Jun 2012 19:39:35 -0700 Subject: [PATCH] Exchanging messages between two handlers is now thread safe - (not byte buffers yet) --- .../local/LocalTransportThreadModelTest.java | 181 ++++++++++++++++-- 1 file changed, 168 insertions(+), 13 deletions(-) diff --git a/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java b/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java index b4b3cb87cf..cc94088773 100644 --- a/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java +++ b/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java @@ -6,7 +6,6 @@ import io.netty.channel.ChannelBufferHolder; import io.netty.channel.ChannelBufferHolders; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerAdapter; -import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerContext; import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.ChannelInitializer; @@ -19,6 +18,7 @@ import io.netty.util.internal.QueueFactory; import java.util.HashSet; import java.util.Queue; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -185,7 +185,7 @@ public class LocalTransportThreadModelTest { } } - @Test + @Test(timeout = 50000) public void testConcurrentMessageBufferAccess() throws Throwable { EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l")); EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1")); @@ -199,10 +199,13 @@ public class LocalTransportThreadModelTest { l.register(ch).sync().channel().connect(ADDR).sync(); - final int COUNT = 10485760; + final int COUNT = 1048576 * 4; for (int i = 0; i < COUNT;) { + Queue buf = ch.pipeline().inboundMessageBuffer(); + // Thread-safe bridge must be returned. + Assert.assertTrue(buf instanceof BlockingQueue); for (int j = 0; i < COUNT && j < COUNT / 8; j ++) { - ch.pipeline().inboundMessageBuffer().add(Integer.valueOf(i ++)); + buf.add(Integer.valueOf(i ++)); if (h1.exception.get() != null) { throw h1.exception.get(); } @@ -215,6 +218,52 @@ public class LocalTransportThreadModelTest { } ch.pipeline().fireInboundBufferUpdated(); } + + while (h1.inCnt < COUNT || h2.inCnt < COUNT || h3.inCnt < COUNT) { + if (h1.exception.get() != null) { + throw h1.exception.get(); + } + if (h2.exception.get() != null) { + throw h2.exception.get(); + } + if (h3.exception.get() != null) { + throw h3.exception.get(); + } + + Thread.sleep(10); + } + + for (int i = 0; i < COUNT;) { + Queue buf = ch.pipeline().outboundMessageBuffer(); + for (int j = 0; i < COUNT && j < COUNT / 8; j ++) { + buf.add(Integer.valueOf(i ++)); + if (h1.exception.get() != null) { + throw h1.exception.get(); + } + if (h2.exception.get() != null) { + throw h2.exception.get(); + } + if (h3.exception.get() != null) { + throw h3.exception.get(); + } + } + ch.pipeline().flush(); + } + + while (h1.outCnt < COUNT || h2.outCnt < COUNT || h3.outCnt < COUNT) { + if (h1.exception.get() != null) { + throw h1.exception.get(); + } + if (h2.exception.get() != null) { + throw h2.exception.get(); + } + if (h3.exception.get() != null) { + throw h3.exception.get(); + } + + Thread.sleep(10); + } + } private static class ThreadNameAuditor extends ChannelHandlerAdapter { @@ -262,16 +311,78 @@ public class LocalTransportThreadModelTest { } } - private static class MessageForwarder extends ChannelInboundMessageHandlerAdapter { + private static class MessageForwarder extends ChannelHandlerAdapter { private final AtomicReference exception = new AtomicReference(); - private int counter; + private volatile int inCnt; + private volatile int outCnt; + private volatile Thread t; @Override - public void messageReceived(ChannelInboundHandlerContext ctx, - Object msg) throws Exception { - Assert.assertEquals(counter ++, msg); - ctx.nextInboundMessageBuffer().add(msg); + public ChannelBufferHolder newInboundBuffer( + ChannelInboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.messageBuffer(); + } + + @Override + public ChannelBufferHolder newOutboundBuffer( + ChannelOutboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.messageBuffer(); + } + + @Override + public void inboundBufferUpdated( + ChannelInboundHandlerContext ctx) throws Exception { + Thread t = this.t; + if (t == null) { + this.t = Thread.currentThread(); + } else { + Assert.assertSame(t, Thread.currentThread()); + } + + Queue in = ctx.inbound().messageBuffer(); + Queue out = ctx.nextInboundMessageBuffer(); + + // Ensure the bridge buffer is returned. + Assert.assertTrue(out instanceof BlockingQueue); + + for (;;) { + Object msg = in.poll(); + if (msg == null) { + break; + } + + int expected = inCnt ++; + Assert.assertEquals(expected, msg); + out.add(msg); + } + ctx.fireInboundBufferUpdated(); + } + + @Override + public void flush(ChannelOutboundHandlerContext ctx, + ChannelFuture future) throws Exception { + Assert.assertSame(t, Thread.currentThread()); + + Queue in = ctx.outbound().messageBuffer(); + Queue out = ctx.nextOutboundMessageBuffer(); + + // Ensure the bridge buffer is returned. + if (ctx.pipeline().first() != this) { + Assert.assertTrue(out instanceof BlockingQueue); + } + + for (;;) { + Object msg = in.poll(); + if (msg == null) { + break; + } + + int expected = outCnt ++; + Assert.assertEquals(expected, msg); + out.add(msg); + } + ctx.flush(future); } @Override @@ -284,10 +395,12 @@ public class LocalTransportThreadModelTest { } } - private static class MessageDiscarder extends ChannelInboundHandlerAdapter { + private static class MessageDiscarder extends ChannelHandlerAdapter { private final AtomicReference exception = new AtomicReference(); - private int counter; + private volatile int inCnt; + private volatile int outCnt; + private volatile Thread t; @Override public ChannelBufferHolder newInboundBuffer( @@ -295,15 +408,55 @@ public class LocalTransportThreadModelTest { return ChannelBufferHolders.messageBuffer(); } + @Override + public ChannelBufferHolder newOutboundBuffer( + ChannelOutboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.messageBuffer(); + } + @Override public void inboundBufferUpdated( ChannelInboundHandlerContext ctx) throws Exception { + Thread t = this.t; + if (t == null) { + this.t = Thread.currentThread(); + } else { + Assert.assertSame(t, Thread.currentThread()); + } + Queue in = ctx.inbound().messageBuffer(); for (;;) { Object msg = in.poll(); - Assert.assertEquals(counter ++, msg); + if (msg == null) { + break; + } + int expected = inCnt ++; + Assert.assertEquals(expected, msg); } + } + @Override + public void flush(ChannelOutboundHandlerContext ctx, + ChannelFuture future) throws Exception { + Assert.assertSame(t, Thread.currentThread()); + + Queue in = ctx.outbound().messageBuffer(); + Queue out = ctx.nextOutboundMessageBuffer(); + + // Ensure the bridge buffer is returned. + Assert.assertTrue(out instanceof BlockingQueue); + + for (;;) { + Object msg = in.poll(); + if (msg == null) { + break; + } + + int expected = outCnt ++; + Assert.assertEquals(expected, msg); + out.add(msg); + } + ctx.flush(future); } @Override @@ -317,6 +470,7 @@ public class LocalTransportThreadModelTest { } private static class PrefixThreadFactory implements ThreadFactory { + private final String prefix; private final AtomicInteger id = new AtomicInteger(); @@ -328,6 +482,7 @@ public class LocalTransportThreadModelTest { public Thread newThread(Runnable r) { Thread t = new Thread(r); t.setName(prefix + '-' + id.incrementAndGet()); + t.setDaemon(true); return t; } }