netty5/transport/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java
Norman Maurer b57d9f307f Allow per-write promises and disallow promises on flush()
- write() now accepts a ChannelPromise and returns ChannelFuture as most
  users expected.  It makes the user's life much easier because it is
  now much easier to get notified when a specific message has been
  written.
- flush() neither creates a ChannelPromise nor returns a ChannelFuture.
  It is now similar to what read() looks like.
2013-07-11 00:49:48 +09:00

590 lines
23 KiB
Java

/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.local;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.EventExecutorGroup;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import java.util.HashSet;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReference;
public class LocalTransportThreadModelTest {
private static EventLoopGroup group;
private static LocalAddress localAddr;
/**
 * Starts the shared local test server used by every test in this class.
 * <p>
 * The server is bound to {@link LocalAddress#ANY}; the address it ends up with is stored in
 * {@code localAddr}. Each accepted child channel gets a single inbound handler that releases
 * whatever message it receives.
 */
@BeforeClass
public static void init() {
    group = new LocalEventLoopGroup();

    // Child channels discard everything they read.
    ChannelInitializer<LocalChannel> discardingInitializer = new ChannelInitializer<LocalChannel>() {
        @Override
        public void initChannel(LocalChannel ch) throws Exception {
            ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) {
                    ReferenceCountUtil.release(msg);
                }
            });
        }
    };

    ServerBootstrap bootstrap = new ServerBootstrap();
    bootstrap.group(group);
    bootstrap.channel(LocalServerChannel.class);
    bootstrap.childHandler(discardingInitializer);

    // Block until the bind completes, then remember where the server is listening.
    Channel serverChannel = bootstrap.bind(LocalAddress.ANY).syncUninterruptibly().channel();
    localAddr = (LocalAddress) serverChannel.localAddress();
}
/**
 * Requests a graceful shutdown of the shared server event loop group and blocks until the
 * future returned by that request completes.
 *
 * @throws Exception if waiting for the shutdown future fails
 */
@AfterClass
public static void destroy() throws Exception {
group.shutdownGracefully().sync();
}
/**
 * Runs {@link #testStagedExecution()} ten times back to back. Ignored by default.
 *
 * @throws Throwable whatever the first failing run throws
 */
@Test(timeout = 30000)
@Ignore("regression test")
public void testStagedExecutionMultiple() throws Throwable {
    int remainingRuns = 10;
    while (remainingRuns > 0) {
        testStagedExecution();
        remainingRuns --;
    }
}
/**
 * Verifies the staged-execution thread model of a {@link LocalChannel} pipeline.
 * <p>
 * Three {@link ThreadNameAuditor}s are installed: {@code h1} without an explicit executor,
 * {@code h2} on group {@code e1} and {@code h3} on group {@code e2}. Inbound and outbound events are
 * fired from every possible starting point, the channel is closed, and the recorded thread names
 * are checked.
 * <p>
 * Channel setup and the wait loop run inside the {@code try} so that the three executor groups are
 * shut down even when connecting fails or a handler reports an exception.
 *
 * @throws Throwable the first exception captured by one of the auditors, or any setup or
 *                   assertion failure
 */
@Test(timeout = 5000)
public void testStagedExecution() throws Throwable {
    EventLoopGroup l = new LocalEventLoopGroup(4, new DefaultThreadFactory("l"));
    EventExecutorGroup e1 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e1"));
    EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e2"));
    ThreadNameAuditor h1 = new ThreadNameAuditor();
    ThreadNameAuditor h2 = new ThreadNameAuditor();
    ThreadNameAuditor h3 = new ThreadNameAuditor();

    try {
        Channel ch = new LocalChannel();
        // With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'.
        ch.pipeline().addLast(h1);
        // h2 will be always invoked by EventExecutor 'e1'.
        ch.pipeline().addLast(e1, h2);
        // h3 will be always invoked by EventExecutor 'e2'.
        ch.pipeline().addLast(e2, h3);
        l.register(ch).sync().channel().connect(localAddr).sync();

        // Fire inbound events from all possible starting points.
        ch.pipeline().fireChannelRead("1");
        ch.pipeline().context(h1).fireChannelRead("2");
        ch.pipeline().context(h2).fireChannelRead("3");
        ch.pipeline().context(h3).fireChannelRead("4");
        // Fire outbound events from all possible starting points.
        ch.pipeline().write("5");
        ch.pipeline().context(h3).write("6");
        ch.pipeline().context(h2).write("7");
        ch.pipeline().context(h1).writeAndFlush("8").sync();
        ch.close().sync();

        // Wait until all events are handled completely.
        while (h1.outboundThreadNames.size() < 3 || h3.inboundThreadNames.size() < 3 ||
               h1.removalThreadNames.size() < 1) {
            rethrowAuditorFailure(h1);
            rethrowAuditorFailure(h2);
            rethrowAuditorFailure(h3);
            Thread.sleep(10);
        }

        assertStagedThreadModel(h1, h2, h3);
    } finally {
        // Request all shutdowns first, then wait for each group to terminate.
        l.shutdownGracefully();
        e1.shutdownGracefully();
        e2.shutdownGracefully();
        l.terminationFuture().sync();
        e1.terminationFuture().sync();
        e2.terminationFuture().sync();
    }
}

/** Rethrows the exception captured by the given auditor, if it captured one. */
private static void rethrowAuditorFailure(ThreadNameAuditor auditor) throws Throwable {
    Throwable cause = auditor.exception.get();
    if (cause != null) {
        throw cause;
    }
}

/** Asserts that every recorded thread name starts with the given prefix. */
private static void assertAllStartWith(String prefix, Iterable<String> names) {
    for (String name: names) {
        Assert.assertTrue(name.startsWith(prefix));
    }
}

/** Asserts that all events recorded by the given auditor were delivered by one and the same thread. */
private static void assertHandledBySingleThread(ThreadNameAuditor auditor) {
    Set<String> names = new HashSet<String>();
    names.addAll(auditor.inboundThreadNames);
    names.addAll(auditor.outboundThreadNames);
    names.addAll(auditor.removalThreadNames);
    Assert.assertEquals(1, names.size());
}

/**
 * Checks the thread names recorded by the three auditors and dumps all of them to standard output
 * before rethrowing when any assertion fails.
 */
private static void assertStagedThreadModel(ThreadNameAuditor h1, ThreadNameAuditor h2, ThreadNameAuditor h3) {
    String currentName = Thread.currentThread().getName();
    try {
        // Events should never be handled from the current thread.
        Assert.assertFalse(h1.inboundThreadNames.contains(currentName));
        Assert.assertFalse(h2.inboundThreadNames.contains(currentName));
        Assert.assertFalse(h3.inboundThreadNames.contains(currentName));
        Assert.assertFalse(h1.outboundThreadNames.contains(currentName));
        Assert.assertFalse(h2.outboundThreadNames.contains(currentName));
        Assert.assertFalse(h3.outboundThreadNames.contains(currentName));
        Assert.assertFalse(h1.removalThreadNames.contains(currentName));
        Assert.assertFalse(h2.removalThreadNames.contains(currentName));
        Assert.assertFalse(h3.removalThreadNames.contains(currentName));

        // Assert that events were handled by the correct executor.
        assertAllStartWith("l-", h1.inboundThreadNames);
        assertAllStartWith("e1-", h2.inboundThreadNames);
        assertAllStartWith("e2-", h3.inboundThreadNames);
        assertAllStartWith("l-", h1.outboundThreadNames);
        assertAllStartWith("e1-", h2.outboundThreadNames);
        assertAllStartWith("e2-", h3.outboundThreadNames);
        assertAllStartWith("l-", h1.removalThreadNames);
        assertAllStartWith("e1-", h2.removalThreadNames);
        assertAllStartWith("e2-", h3.removalThreadNames);

        // Assert that the events for the same handler were handled by the same thread.
        assertHandledBySingleThread(h1);
        assertHandledBySingleThread(h2);
        assertHandledBySingleThread(h3);

        // Count the number of events
        Assert.assertEquals(1, h1.inboundThreadNames.size());
        Assert.assertEquals(2, h2.inboundThreadNames.size());
        Assert.assertEquals(3, h3.inboundThreadNames.size());
        Assert.assertEquals(3, h1.outboundThreadNames.size());
        Assert.assertEquals(2, h2.outboundThreadNames.size());
        Assert.assertEquals(1, h3.outboundThreadNames.size());
        Assert.assertEquals(1, h1.removalThreadNames.size());
        Assert.assertEquals(1, h2.removalThreadNames.size());
        Assert.assertEquals(1, h3.removalThreadNames.size());
    } catch (AssertionError e) {
        System.out.println("H1I: " + h1.inboundThreadNames);
        System.out.println("H2I: " + h2.inboundThreadNames);
        System.out.println("H3I: " + h3.inboundThreadNames);
        System.out.println("H1O: " + h1.outboundThreadNames);
        System.out.println("H2O: " + h2.outboundThreadNames);
        System.out.println("H3O: " + h3.outboundThreadNames);
        System.out.println("H1R: " + h1.removalThreadNames);
        System.out.println("H2R: " + h2.removalThreadNames);
        System.out.println("H3R: " + h3.removalThreadNames);
        throw e;
    }
}
/**
 * Pushes an ordered stream of integers through a six-handler pipeline whose handlers are attached
 * to different executor groups, first inbound and then outbound. In each direction the test polls
 * until every handler has counted all messages, rethrowing the first exception a handler captured.
 * Ignored by default.
 *
 * @throws Throwable the first exception captured by a handler, or any setup failure
 */
@Test(timeout = 30000)
@Ignore
public void testConcurrentMessageBufferAccess() throws Throwable {
    EventLoopGroup l = new LocalEventLoopGroup(4, new DefaultThreadFactory("l"));
    EventExecutorGroup e1 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e1"));
    EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e2"));
    EventExecutorGroup e3 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e3"));
    EventExecutorGroup e4 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e4"));
    EventExecutorGroup e5 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e5"));

    try {
        final MessageForwarder1 h1 = new MessageForwarder1();
        final MessageForwarder2 h2 = new MessageForwarder2();
        final MessageForwarder3 h3 = new MessageForwarder3();
        final MessageForwarder1 h4 = new MessageForwarder1();
        final MessageForwarder2 h5 = new MessageForwarder2();
        final MessageDiscarder h6 = new MessageDiscarder();
        final Channel ch = new LocalChannel();

        // inbound:  int -> byte[4] -> int -> int -> byte[4] -> int -> /dev/null
        // outbound: int -> int -> byte[4] -> int -> int -> byte[4] -> /dev/null
        ch.pipeline().addLast(h1);
        ch.pipeline().addLast(e1, h2);
        ch.pipeline().addLast(e2, h3);
        ch.pipeline().addLast(e3, h4);
        ch.pipeline().addLast(e4, h5);
        ch.pipeline().addLast(e5, h6);
        l.register(ch).sync().channel().connect(localAddr).sync();

        final int rounds = 1024;
        final int elemsPerRound = 8192;
        final int totalCount = rounds * elemsPerRound;

        // Inbound phase: each round is one task on the channel's event loop.
        for (int begin = 0; begin < totalCount; begin += elemsPerRound) {
            final int from = begin;
            final int to = begin + elemsPerRound;
            ch.eventLoop().execute(new Runnable() {
                @Override
                public void run() {
                    int next = from;
                    while (next < to) {
                        ch.pipeline().fireChannelRead(Integer.valueOf(next));
                        next ++;
                    }
                }
            });
        }

        // Poll until every handler has seen every inbound message.
        for (;;) {
            boolean inboundDone =
                    h1.inCnt >= totalCount && h2.inCnt >= totalCount && h3.inCnt >= totalCount &&
                    h4.inCnt >= totalCount && h5.inCnt >= totalCount && h6.inCnt >= totalCount;
            if (inboundDone) {
                break;
            }
            rethrowIfSet(h1.exception);
            rethrowIfSet(h2.exception);
            rethrowIfSet(h3.exception);
            rethrowIfSet(h4.exception);
            rethrowIfSet(h5.exception);
            rethrowIfSet(h6.exception);
            Thread.sleep(10);
        }

        // Outbound phase: each round is one task on the executor of the last handler's context.
        for (int begin = 0; begin < totalCount; begin += elemsPerRound) {
            final int from = begin;
            final int to = begin + elemsPerRound;
            ch.pipeline().context(h6).executor().execute(new Runnable() {
                @Override
                public void run() {
                    int next = from;
                    while (next < to) {
                        ch.write(Integer.valueOf(next));
                        next ++;
                    }
                    ch.flush();
                }
            });
        }

        // Poll until every handler has seen every outbound message.
        for (;;) {
            boolean outboundDone =
                    h1.outCnt >= totalCount && h2.outCnt >= totalCount && h3.outCnt >= totalCount &&
                    h4.outCnt >= totalCount && h5.outCnt >= totalCount && h6.outCnt >= totalCount;
            if (outboundDone) {
                break;
            }
            rethrowIfSet(h1.exception);
            rethrowIfSet(h2.exception);
            rethrowIfSet(h3.exception);
            rethrowIfSet(h4.exception);
            rethrowIfSet(h5.exception);
            rethrowIfSet(h6.exception);
            Thread.sleep(10);
        }

        ch.close().sync();
    } finally {
        // Request all shutdowns first, then wait for each group to terminate.
        l.shutdownGracefully();
        e1.shutdownGracefully();
        e2.shutdownGracefully();
        e3.shutdownGracefully();
        e4.shutdownGracefully();
        e5.shutdownGracefully();
        l.terminationFuture().sync();
        e1.terminationFuture().sync();
        e2.terminationFuture().sync();
        e3.terminationFuture().sync();
        e4.terminationFuture().sync();
        e5.terminationFuture().sync();
    }
}

/** Rethrows the throwable held by the given reference, if one has been stored. */
private static void rethrowIfSet(AtomicReference<Throwable> failure) throws Throwable {
    Throwable cause = failure.get();
    if (cause != null) {
        throw cause;
    }
}
/**
 * Duplex handler that records, per event kind, the name of the thread that delivered the event.
 * <p>
 * Inbound reads and outbound writes are recorded and then passed on unchanged; handler removal is
 * recorded as well. The first exception reported through {@code exceptionCaught} is kept in
 * {@code exception}.
 */
private static class ThreadNameAuditor extends ChannelDuplexHandler {
    private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
    private final Queue<String> inboundThreadNames = new ConcurrentLinkedQueue<String>();
    private final Queue<String> outboundThreadNames = new ConcurrentLinkedQueue<String>();
    private final Queue<String> removalThreadNames = new ConcurrentLinkedQueue<String>();

    /** Returns the name of the thread executing the caller. */
    private static String callerThreadName() {
        return Thread.currentThread().getName();
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        String name = callerThreadName();
        removalThreadNames.add(name);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        String name = callerThreadName();
        inboundThreadNames.add(name);
        ctx.fireChannelRead(msg);
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        String name = callerThreadName();
        outboundThreadNames.add(name);
        ctx.write(msg, promise);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Only the first failure is retained for the test thread to rethrow.
        exception.compareAndSet(null, cause);
        String prefix = '[' + callerThreadName() + "] ";
        System.err.print(prefix);
        cause.printStackTrace();
        super.exceptionCaught(ctx, cause);
    }
}
/**
 * Converts integers into a binary stream.
 * <p>
 * Inbound: each {@link Integer} must equal the running inbound counter; it is written as a 4-byte
 * int into a freshly allocated buffer, which is passed on. Outbound: the incoming {@link ByteBuf} is
 * read back as ints, each of which must equal the running outbound counter. The ints are forwarded
 * unless this handler is the first one in the pipeline. An empty buffer carrying the caller's
 * promise is then written and flushed.
 * <p>
 * The thread of the first {@code channelRead} call is remembered and every later call must come
 * from that same thread. Buffers owned by this handler are released even when an assertion fails.
 */
private static class MessageForwarder1 extends ChannelDuplexHandler {
    private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
    private volatile int inCnt;
    private volatile int outCnt;
    private volatile Thread t;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Thread t = this.t;
        if (t == null) {
            this.t = Thread.currentThread();
        } else {
            Assert.assertSame(t, Thread.currentThread());
        }

        int m = ((Integer) msg).intValue();
        int expected = inCnt ++;
        Assert.assertEquals(expected, m);

        // Allocate only after validation so a failed assertion cannot strand an unreleased buffer.
        ByteBuf out = ctx.alloc().buffer(4);
        out.writeInt(m);
        ctx.fireChannelRead(out);
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        ByteBuf m = (ByteBuf) msg;
        try {
            Assert.assertSame(t, Thread.currentThread());

            // Don't let the write request go to the server-side channel - just swallow.
            boolean swallow = this == ctx.pipeline().first();

            int count = m.readableBytes() / 4;
            for (int j = 0; j < count; j ++) {
                int actual = m.readInt();
                int expected = outCnt ++;
                Assert.assertEquals(expected, actual);
                if (!swallow) {
                    ctx.write(actual);
                }
            }
            ctx.writeAndFlush(Unpooled.EMPTY_BUFFER, promise);
        } finally {
            // The consumed buffer is released on the failure path as well as on success.
            m.release();
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Only the first failure is retained for the test thread to rethrow.
        exception.compareAndSet(null, cause);
        super.exceptionCaught(ctx, cause);
    }
}
/**
 * Converts a binary stream into integers.
 * <p>
 * Inbound: the incoming {@link ByteBuf} is read as ints, each of which must equal the running
 * inbound counter and is passed on individually. Outbound: each {@link Integer} must equal the
 * running outbound counter and is written as a 4-byte int into a freshly allocated buffer, which is
 * forwarded with the caller's promise.
 * <p>
 * The thread of the first {@code channelRead} call is remembered and every later call must come
 * from that same thread. Buffers owned by this handler are released even when an assertion fails.
 */
private static class MessageForwarder2 extends ChannelDuplexHandler {
    private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
    private volatile int inCnt;
    private volatile int outCnt;
    private volatile Thread t;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ByteBuf m = (ByteBuf) msg;
        try {
            Thread t = this.t;
            if (t == null) {
                this.t = Thread.currentThread();
            } else {
                Assert.assertSame(t, Thread.currentThread());
            }

            int count = m.readableBytes() / 4;
            for (int j = 0; j < count; j ++) {
                int actual = m.readInt();
                int expected = inCnt ++;
                Assert.assertEquals(expected, actual);
                ctx.fireChannelRead(actual);
            }
        } finally {
            // The consumed buffer is released on the failure path as well as on success.
            m.release();
        }
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        Assert.assertSame(t, Thread.currentThread());

        int m = (Integer) msg;
        int expected = outCnt ++;
        Assert.assertEquals(expected, m);

        // Allocate only after validation so a failed assertion cannot strand an unreleased buffer.
        ByteBuf out = ctx.alloc().buffer(4);
        out.writeInt(m);
        ctx.write(out, promise);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Only the first failure is retained for the test thread to rethrow.
        exception.compareAndSet(null, cause);
        super.exceptionCaught(ctx, cause);
    }
}
/**
 * Simply forwards the received object to the next handler.
 * <p>
 * In both directions the message must be an {@link Integer} equal to the running counter for that
 * direction. The thread of the first {@code channelRead} call is remembered and every later call
 * must come from that same thread. The first exception reported through {@code exceptionCaught}
 * is kept in {@code exception} and also printed to standard error.
 */
private static class MessageForwarder3 extends ChannelDuplexHandler {
    private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
    private volatile int inCnt;
    private volatile int outCnt;
    private volatile Thread t;

    /** Remembers the calling thread on first use; afterwards asserts the caller is that thread. */
    private void bindOrVerifyThread() {
        Thread owner = t;
        if (owner != null) {
            Assert.assertSame(owner, Thread.currentThread());
            return;
        }
        t = Thread.currentThread();
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        bindOrVerifyThread();

        int received = (Integer) msg;
        int wanted = inCnt;
        inCnt = wanted + 1;
        Assert.assertEquals(wanted, received);
        ctx.fireChannelRead(msg);
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        Assert.assertSame(t, Thread.currentThread());

        int received = (Integer) msg;
        int wanted = outCnt;
        outCnt = wanted + 1;
        Assert.assertEquals(wanted, received);
        ctx.write(msg, promise);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Only the first failure is retained for the test thread to rethrow.
        exception.compareAndSet(null, cause);
        String prefix = '[' + Thread.currentThread().getName() + "] ";
        System.err.print(prefix);
        cause.printStackTrace();
        super.exceptionCaught(ctx, cause);
    }
}
/**
 * Discards all received messages.
 * <p>
 * Inbound integers are checked against the running inbound counter and then dropped. Outbound
 * integers are checked against the running outbound counter and forwarded with the caller's
 * promise. The thread of the first {@code channelRead} call is remembered and every later call
 * must come from that same thread.
 */
private static class MessageDiscarder extends ChannelDuplexHandler {
    private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
    private volatile int inCnt;
    private volatile int outCnt;
    private volatile Thread t;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        final Thread owner = t;
        final Thread caller = Thread.currentThread();
        if (owner != null) {
            Assert.assertSame(owner, caller);
        } else {
            t = caller;
        }

        // Validate the sequence number; the message is intentionally not passed on.
        final int received = (Integer) msg;
        final int wanted = inCnt;
        inCnt = wanted + 1;
        Assert.assertEquals(wanted, received);
    }

    @Override
    public void write(
            ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        Assert.assertSame(t, Thread.currentThread());

        final int received = (Integer) msg;
        final int wanted = outCnt;
        outCnt = wanted + 1;
        Assert.assertEquals(wanted, received);
        ctx.write(msg, promise);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Only the first failure is retained for the test thread to rethrow.
        exception.compareAndSet(null, cause);
        super.exceptionCaught(ctx, cause);
    }
}
}