Exchanging messages between two handlers is now thread safe

- (not byte buffers yet)
This commit is contained in:
Trustin Lee 2012-06-03 19:39:35 -07:00
parent bde9b6aa2a
commit c1afe3d8c3

View File

@ -6,7 +6,6 @@ import io.netty.channel.ChannelBufferHolder;
import io.netty.channel.ChannelBufferHolders; import io.netty.channel.ChannelBufferHolders;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInboundHandlerContext; import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
@ -19,6 +18,7 @@ import io.netty.util.internal.QueueFactory;
import java.util.HashSet; import java.util.HashSet;
import java.util.Queue; import java.util.Queue;
import java.util.Set; import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -185,7 +185,7 @@ public class LocalTransportThreadModelTest {
} }
} }
@Test @Test(timeout = 50000)
public void testConcurrentMessageBufferAccess() throws Throwable { public void testConcurrentMessageBufferAccess() throws Throwable {
EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l")); EventLoop l = new LocalEventLoop(4, new PrefixThreadFactory("l"));
EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1")); EventExecutor e1 = new DefaultEventExecutor(4, new PrefixThreadFactory("e1"));
@ -199,10 +199,13 @@ public class LocalTransportThreadModelTest {
l.register(ch).sync().channel().connect(ADDR).sync(); l.register(ch).sync().channel().connect(ADDR).sync();
final int COUNT = 10485760; final int COUNT = 1048576 * 4;
for (int i = 0; i < COUNT;) { for (int i = 0; i < COUNT;) {
Queue<Object> buf = ch.pipeline().inboundMessageBuffer();
// Thread-safe bridge must be returned.
Assert.assertTrue(buf instanceof BlockingQueue);
for (int j = 0; i < COUNT && j < COUNT / 8; j ++) { for (int j = 0; i < COUNT && j < COUNT / 8; j ++) {
ch.pipeline().inboundMessageBuffer().add(Integer.valueOf(i ++)); buf.add(Integer.valueOf(i ++));
if (h1.exception.get() != null) { if (h1.exception.get() != null) {
throw h1.exception.get(); throw h1.exception.get();
} }
@ -215,6 +218,52 @@ public class LocalTransportThreadModelTest {
} }
ch.pipeline().fireInboundBufferUpdated(); ch.pipeline().fireInboundBufferUpdated();
} }
while (h1.inCnt < COUNT || h2.inCnt < COUNT || h3.inCnt < COUNT) {
if (h1.exception.get() != null) {
throw h1.exception.get();
}
if (h2.exception.get() != null) {
throw h2.exception.get();
}
if (h3.exception.get() != null) {
throw h3.exception.get();
}
Thread.sleep(10);
}
for (int i = 0; i < COUNT;) {
Queue<Object> buf = ch.pipeline().outboundMessageBuffer();
for (int j = 0; i < COUNT && j < COUNT / 8; j ++) {
buf.add(Integer.valueOf(i ++));
if (h1.exception.get() != null) {
throw h1.exception.get();
}
if (h2.exception.get() != null) {
throw h2.exception.get();
}
if (h3.exception.get() != null) {
throw h3.exception.get();
}
}
ch.pipeline().flush();
}
while (h1.outCnt < COUNT || h2.outCnt < COUNT || h3.outCnt < COUNT) {
if (h1.exception.get() != null) {
throw h1.exception.get();
}
if (h2.exception.get() != null) {
throw h2.exception.get();
}
if (h3.exception.get() != null) {
throw h3.exception.get();
}
Thread.sleep(10);
}
} }
private static class ThreadNameAuditor extends ChannelHandlerAdapter<Object, Object> { private static class ThreadNameAuditor extends ChannelHandlerAdapter<Object, Object> {
@ -262,16 +311,78 @@ public class LocalTransportThreadModelTest {
} }
} }
private static class MessageForwarder extends ChannelInboundMessageHandlerAdapter<Object> { private static class MessageForwarder extends ChannelHandlerAdapter<Object, Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>(); private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private int counter; private volatile int inCnt;
private volatile int outCnt;
private volatile Thread t;
@Override @Override
public void messageReceived(ChannelInboundHandlerContext<Object> ctx, public ChannelBufferHolder<Object> newInboundBuffer(
Object msg) throws Exception { ChannelInboundHandlerContext<Object> ctx) throws Exception {
Assert.assertEquals(counter ++, msg); return ChannelBufferHolders.messageBuffer();
ctx.nextInboundMessageBuffer().add(msg); }
@Override
public ChannelBufferHolder<Object> newOutboundBuffer(
ChannelOutboundHandlerContext<Object> ctx) throws Exception {
return ChannelBufferHolders.messageBuffer();
}
@Override
public void inboundBufferUpdated(
ChannelInboundHandlerContext<Object> ctx) throws Exception {
Thread t = this.t;
if (t == null) {
this.t = Thread.currentThread();
} else {
Assert.assertSame(t, Thread.currentThread());
}
Queue<Object> in = ctx.inbound().messageBuffer();
Queue<Object> out = ctx.nextInboundMessageBuffer();
// Ensure the bridge buffer is returned.
Assert.assertTrue(out instanceof BlockingQueue);
for (;;) {
Object msg = in.poll();
if (msg == null) {
break;
}
int expected = inCnt ++;
Assert.assertEquals(expected, msg);
out.add(msg);
}
ctx.fireInboundBufferUpdated();
}
@Override
public void flush(ChannelOutboundHandlerContext<Object> ctx,
ChannelFuture future) throws Exception {
Assert.assertSame(t, Thread.currentThread());
Queue<Object> in = ctx.outbound().messageBuffer();
Queue<Object> out = ctx.nextOutboundMessageBuffer();
// Ensure the bridge buffer is returned.
if (ctx.pipeline().first() != this) {
Assert.assertTrue(out instanceof BlockingQueue);
}
for (;;) {
Object msg = in.poll();
if (msg == null) {
break;
}
int expected = outCnt ++;
Assert.assertEquals(expected, msg);
out.add(msg);
}
ctx.flush(future);
} }
@Override @Override
@ -284,10 +395,12 @@ public class LocalTransportThreadModelTest {
} }
} }
private static class MessageDiscarder extends ChannelInboundHandlerAdapter<Object> { private static class MessageDiscarder extends ChannelHandlerAdapter<Object, Object> {
private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>(); private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
private int counter; private volatile int inCnt;
private volatile int outCnt;
private volatile Thread t;
@Override @Override
public ChannelBufferHolder<Object> newInboundBuffer( public ChannelBufferHolder<Object> newInboundBuffer(
@ -295,15 +408,55 @@ public class LocalTransportThreadModelTest {
return ChannelBufferHolders.messageBuffer(); return ChannelBufferHolders.messageBuffer();
} }
@Override
public ChannelBufferHolder<Object> newOutboundBuffer(
ChannelOutboundHandlerContext<Object> ctx) throws Exception {
return ChannelBufferHolders.messageBuffer();
}
@Override @Override
public void inboundBufferUpdated( public void inboundBufferUpdated(
ChannelInboundHandlerContext<Object> ctx) throws Exception { ChannelInboundHandlerContext<Object> ctx) throws Exception {
Thread t = this.t;
if (t == null) {
this.t = Thread.currentThread();
} else {
Assert.assertSame(t, Thread.currentThread());
}
Queue<Object> in = ctx.inbound().messageBuffer(); Queue<Object> in = ctx.inbound().messageBuffer();
for (;;) { for (;;) {
Object msg = in.poll(); Object msg = in.poll();
Assert.assertEquals(counter ++, msg); if (msg == null) {
break;
}
int expected = inCnt ++;
Assert.assertEquals(expected, msg);
} }
}
@Override
public void flush(ChannelOutboundHandlerContext<Object> ctx,
ChannelFuture future) throws Exception {
Assert.assertSame(t, Thread.currentThread());
Queue<Object> in = ctx.outbound().messageBuffer();
Queue<Object> out = ctx.nextOutboundMessageBuffer();
// Ensure the bridge buffer is returned.
Assert.assertTrue(out instanceof BlockingQueue);
for (;;) {
Object msg = in.poll();
if (msg == null) {
break;
}
int expected = outCnt ++;
Assert.assertEquals(expected, msg);
out.add(msg);
}
ctx.flush(future);
} }
@Override @Override
@ -317,6 +470,7 @@ public class LocalTransportThreadModelTest {
} }
private static class PrefixThreadFactory implements ThreadFactory { private static class PrefixThreadFactory implements ThreadFactory {
private final String prefix; private final String prefix;
private final AtomicInteger id = new AtomicInteger(); private final AtomicInteger id = new AtomicInteger();
@ -328,6 +482,7 @@ public class LocalTransportThreadModelTest {
public Thread newThread(Runnable r) { public Thread newThread(Runnable r) {
Thread t = new Thread(r); Thread t = new Thread(r);
t.setName(prefix + '-' + id.incrementAndGet()); t.setName(prefix + '-' + id.incrementAndGet());
t.setDaemon(true);
return t; return t;
} }
} }