Fixing concurrency issue with HTTP/2 test mocking

Motivation:

Some tests occasionally appear unstable, throwing a
org.mockito.exceptions.misusing.UnfinishedStubbingException. Mockito
stubbing does not work properly in multi-threaded environments, so any
stubbing has to be done before the threads are started.

Modifications:

Modified tests to perform any custom stubbing before the client/server
bootstrap logic executes.

Result:

HTTP/2 tests should be more stable.
This commit is contained in:
nmittler 2014-10-04 10:31:02 -07:00
parent 217d921075
commit 276b826b59
5 changed files with 406 additions and 442 deletions

View File

@ -91,48 +91,6 @@ public class DataCompressionHttp2Test {
@Before
public void setup() throws InterruptedException {
MockitoAnnotations.initMocks(this);
alloc = UnpooledByteBufAllocator.DEFAULT;
sb = new ServerBootstrap();
cb = new Bootstrap();
serverLatch(new CountDownLatch(1));
clientLatch(new CountDownLatch(1));
frameWriter = new DefaultHttp2FrameWriter();
serverConnection = new DefaultHttp2Connection(true);
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverAdapter = new Http2TestUtil.FrameAdapter(serverConnection,
new DelegatingDecompressorFrameListener(serverConnection, serverListener),
serverLatch);
p.addLast("reader", serverAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
serverConnectedChannel = ch;
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
clientAdapter = new Http2TestUtil.FrameAdapter(clientListener, clientLatch);
p.addLast("reader", clientAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
@After
@ -157,6 +115,7 @@ public class DataCompressionHttp2Test {
@Test
public void justHeadersNoData() throws Exception {
bootstrapEnv(2, 1);
final Http2Headers headers =
new DefaultHttp2Headers().method(GET).path(PATH).set(CONTENT_ENCODING, GZIP);
// Required because the decompressor intercepts the onXXXRead events before
@ -175,7 +134,7 @@ public class DataCompressionHttp2Test {
@Test
public void gzipEncodingSingleEmptyMessage() throws Exception {
serverLatch(new CountDownLatch(2));
bootstrapEnv(2, 1);
final String text = "";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final EmbeddedChannel encoder = new EmbeddedChannel(ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP));
@ -209,7 +168,7 @@ public class DataCompressionHttp2Test {
@Test
public void gzipEncodingSingleMessage() throws Exception {
serverLatch(new CountDownLatch(2));
bootstrapEnv(2, 1);
final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final EmbeddedChannel encoder = new EmbeddedChannel(ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP));
@ -243,7 +202,7 @@ public class DataCompressionHttp2Test {
@Test
public void gzipEncodingMultipleMessages() throws Exception {
serverLatch(new CountDownLatch(3));
bootstrapEnv(3, 1);
final String text1 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
final String text2 = "dddddddddddddddddddeeeeeeeeeeeeeeeeeeeffffffffffffffffffff";
final ByteBuf data1 = Unpooled.copiedBuffer(text1.getBytes());
@ -288,7 +247,7 @@ public class DataCompressionHttp2Test {
@Test
public void deflateEncodingSingleLargeMessageReducedWindow() throws Exception {
serverLatch(new CountDownLatch(3));
bootstrapEnv(3, 1);
final int BUFFER_SIZE = 1 << 16;
final ByteBuf data = Unpooled.buffer(BUFFER_SIZE);
final EmbeddedChannel encoder = new EmbeddedChannel(ZlibCodecFactory.newZlibEncoder(ZlibWrapper.ZLIB));
@ -367,18 +326,49 @@ public class DataCompressionHttp2Test {
}
}
private void serverLatch(CountDownLatch latch) {
serverLatch = latch;
if (serverAdapter != null) {
serverAdapter.latch(serverLatch);
}
}
private void bootstrapEnv(int serverCountDown, int clientCountDown) throws Exception {
alloc = UnpooledByteBufAllocator.DEFAULT;
sb = new ServerBootstrap();
cb = new Bootstrap();
private void clientLatch(CountDownLatch latch) {
clientLatch = latch;
if (clientAdapter != null) {
clientAdapter.latch(clientLatch);
}
serverLatch = new CountDownLatch(serverCountDown);
clientLatch = new CountDownLatch(clientCountDown);
frameWriter = new DefaultHttp2FrameWriter();
serverConnection = new DefaultHttp2Connection(true);
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverAdapter = new Http2TestUtil.FrameAdapter(serverConnection,
new DelegatingDecompressorFrameListener(serverConnection, serverListener),
serverLatch);
p.addLast("reader", serverAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
serverConnectedChannel = ch;
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
clientAdapter = new Http2TestUtil.FrameAdapter(clientListener, clientLatch);
p.addLast("reader", clientAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
private void awaitClient() throws Exception {

View File

@ -70,7 +70,6 @@ import org.mockito.stubbing.Answer;
* Testing the {@link Http2ToHttpConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames
*/
public class DefaultHttp2ToHttpConnectionHandlerTest {
private static final int CONNECTION_SETUP_READ_COUNT = 2;
@Mock
private Http2FrameListener clientListener;
@ -88,41 +87,6 @@ public class DefaultHttp2ToHttpConnectionHandlerTest {
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
requestLatch(new CountDownLatch(CONNECTION_SETUP_READ_COUNT + 1));
sb = new ServerBootstrap();
cb = new Bootstrap();
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverFrameCountDown = new Http2TestUtil.FrameCountDown(serverListener, requestLatch);
p.addLast(new Http2ToHttpConnectionHandler(true, serverFrameCountDown));
p.addLast(ignoreSettingsHandler());
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new Http2ToHttpConnectionHandler(false, clientListener));
p.addLast(ignoreSettingsHandler());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
@After
@ -138,6 +102,7 @@ public class DefaultHttp2ToHttpConnectionHandlerTest {
@Test
public void testJustHeadersRequest() throws Exception {
bootstrapEnv(3);
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/example");
try {
final HttpHeaders httpHeaders = request.headers();
@ -173,7 +138,6 @@ public class DefaultHttp2ToHttpConnectionHandlerTest {
@Test
public void testRequestWithBody() throws Exception {
requestLatch(new CountDownLatch(CONNECTION_SETUP_READ_COUNT + 2));
final String text = "foooooogoooo";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
@ -185,6 +149,7 @@ public class DefaultHttp2ToHttpConnectionHandlerTest {
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), eq(true));
bootstrapEnv(4);
try {
final HttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example", data.retain());
final HttpHeaders httpHeaders = request.headers();
@ -216,11 +181,41 @@ public class DefaultHttp2ToHttpConnectionHandlerTest {
}
}
private void requestLatch(CountDownLatch latch) {
requestLatch = latch;
if (serverFrameCountDown != null) {
serverFrameCountDown.messageLatch(latch);
}
private void bootstrapEnv(int requestCountDown) throws Exception {
requestLatch = new CountDownLatch(requestCountDown);
sb = new ServerBootstrap();
cb = new Bootstrap();
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverFrameCountDown = new Http2TestUtil.FrameCountDown(serverListener, requestLatch);
p.addLast(new Http2ToHttpConnectionHandler(true, serverFrameCountDown));
p.addLast(ignoreSettingsHandler());
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new Http2ToHttpConnectionHandler(false, clientListener));
p.addLast(ignoreSettingsHandler());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
private void awaitRequests() throws Exception {

View File

@ -71,9 +71,6 @@ import org.mockito.stubbing.Answer;
* Tests the full HTTP/2 framing stack including the connection and preface handlers.
*/
public class Http2ConnectionRoundtripTest {
private static final int STRESS_TIMEOUT_SECONDS = 30;
private static final int NUM_STREAMS = 5000;
private final byte[] DATA_TEXT = "hello world".getBytes(UTF_8);
@Mock
private Http2FrameListener clientListener;
@ -93,42 +90,6 @@ public class Http2ConnectionRoundtripTest {
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
requestLatch(new CountDownLatch(NUM_STREAMS * 3));
dataLatch(new CountDownLatch(NUM_STREAMS * DATA_TEXT.length));
sb = new ServerBootstrap();
cb = new Bootstrap();
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverFrameCountDown = new Http2TestUtil.FrameCountDown(serverListener, requestLatch, dataLatch);
p.addLast(new Http2ConnectionHandler(true, serverFrameCountDown));
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new Http2ConnectionHandler(false, clientListener));
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
http2Client = clientChannel.pipeline().get(Http2ConnectionHandler.class);
}
@After
@ -144,6 +105,7 @@ public class Http2ConnectionRoundtripTest {
@Test
public void http2ExceptionInPipelineShouldCloseConnection() throws Exception {
bootstrapEnv(1, 1);
// Create a latch to track when the close occurs.
final CountDownLatch closeLatch = new CountDownLatch(1);
@ -156,7 +118,6 @@ public class Http2ConnectionRoundtripTest {
// Create a single stream by sending a HEADERS frame to the server.
final Http2Headers headers = dummyHeaders();
requestLatch(new CountDownLatch(1));
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
@ -183,6 +144,7 @@ public class Http2ConnectionRoundtripTest {
@Test
public void nonHttp2ExceptionInPipelineShouldNotCloseConnection() throws Exception {
bootstrapEnv(1, 1);
// Create a latch to track when the close occurs.
final CountDownLatch closeLatch = new CountDownLatch(1);
@ -195,7 +157,6 @@ public class Http2ConnectionRoundtripTest {
// Create a single stream by sending a HEADERS frame to the server.
final Http2Headers headers = dummyHeaders();
requestLatch(new CountDownLatch(1));
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
@ -243,8 +204,7 @@ public class Http2ConnectionRoundtripTest {
any(ByteBuf.class), eq(0), Mockito.anyBoolean());
try {
// Initialize the data latch based on the number of bytes expected.
requestLatch(new CountDownLatch(2));
dataLatch(new CountDownLatch(length));
bootstrapEnv(length, 2);
// Create the stream and send all of the data at once.
runInChannel(clientChannel, new Http2Runnable() {
@ -282,7 +242,8 @@ public class Http2ConnectionRoundtripTest {
final String pingMsg = "12345678";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final ByteBuf pingData = Unpooled.copiedBuffer(pingMsg, UTF_8);
final List<String> receivedPingBuffers = Collections.synchronizedList(new ArrayList<String>(NUM_STREAMS));
final int numStreams = 5000;
final List<String> receivedPingBuffers = Collections.synchronizedList(new ArrayList<String>(numStreams));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@ -290,7 +251,7 @@ public class Http2ConnectionRoundtripTest {
return null;
}
}).when(serverListener).onPingRead(any(ChannelHandlerContext.class), eq(pingData));
final List<String> receivedDataBuffers = Collections.synchronizedList(new ArrayList<String>(NUM_STREAMS));
final List<String> receivedDataBuffers = Collections.synchronizedList(new ArrayList<String>(numStreams));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
@ -300,10 +261,11 @@ public class Http2ConnectionRoundtripTest {
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data),
eq(0), eq(true));
try {
bootstrapEnv(numStreams * text.length(), numStreams * 3);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
for (int i = 0, nextStream = 3; i < NUM_STREAMS; ++i, nextStream += 2) {
for (int i = 0, nextStream = 3; i < numStreams; ++i, nextStream += 2) {
http2Client.encoder().writeHeaders(ctx(), nextStream, headers, 0,
(short) 16, false, 0, false, newPromise());
http2Client.encoder().writePing(ctx(), false, pingData.slice().retain(),
@ -314,15 +276,15 @@ public class Http2ConnectionRoundtripTest {
}
});
// Wait for all frames to be received.
assertTrue(requestLatch.await(STRESS_TIMEOUT_SECONDS, SECONDS));
verify(serverListener, times(NUM_STREAMS)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
assertTrue(requestLatch.await(30, SECONDS));
verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false));
verify(serverListener, times(NUM_STREAMS)).onPingRead(any(ChannelHandlerContext.class),
verify(serverListener, times(numStreams)).onPingRead(any(ChannelHandlerContext.class),
any(ByteBuf.class));
verify(serverListener, times(NUM_STREAMS)).onDataRead(any(ChannelHandlerContext.class),
verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class),
anyInt(), any(ByteBuf.class), eq(0), eq(true));
assertEquals(NUM_STREAMS, receivedPingBuffers.size());
assertEquals(NUM_STREAMS, receivedDataBuffers.size());
assertEquals(numStreams, receivedPingBuffers.size());
assertEquals(numStreams, receivedDataBuffers.size());
for (String receivedData : receivedDataBuffers) {
assertEquals(text, receivedData);
}
@ -335,18 +297,42 @@ public class Http2ConnectionRoundtripTest {
}
}
private void dataLatch(CountDownLatch latch) {
dataLatch = latch;
if (serverFrameCountDown != null) {
serverFrameCountDown.dataLatch(latch);
}
}
private void bootstrapEnv(int dataCountDown, int requestCountDown) throws Exception {
requestLatch = new CountDownLatch(requestCountDown);
dataLatch = new CountDownLatch(dataCountDown);
sb = new ServerBootstrap();
cb = new Bootstrap();
private void requestLatch(CountDownLatch latch) {
requestLatch = latch;
if (serverFrameCountDown != null) {
serverFrameCountDown.messageLatch(latch);
}
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverFrameCountDown = new Http2TestUtil.FrameCountDown(serverListener, requestLatch, dataLatch);
p.addLast(new Http2ConnectionHandler(true, serverFrameCountDown));
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new Http2ConnectionHandler(false, clientListener));
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
http2Client = clientChannel.pipeline().get(Http2ConnectionHandler.class);
}
private ChannelHandlerContext ctx() {

View File

@ -79,8 +79,277 @@ public class Http2FrameRoundtripTest {
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
}
serverLatch(new CountDownLatch(1));
@After
public void teardown() throws Exception {
serverChannel.close().sync();
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
serverGroup.sync();
serverChildGroup.sync();
clientGroup.sync();
serverAdapter = null;
}
@Test
public void dataFrameShouldMatch() throws Exception {
final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
try {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeData(ctx(), 0x7FFFFFFF, data.slice().retain(), 100, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
} finally {
data.release();
}
}
@Test
public void headersFrameWithoutPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeHeaders(ctx(), 0x7FFFFFFF, headers, 0, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(headers), eq(0), eq(true));
}
@Test
public void headersFrameWithPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeHeaders(ctx(), 0x7FFFFFFF, headers, 4, (short) 255,
true, 0, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(headers), eq(4), eq((short) 255), eq(true), eq(0), eq(true));
}
@Test
public void goAwayFrameShouldMatch() throws Exception {
final String text = "test";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[3]).toString(UTF_8));
return null;
}
}).when(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), eq(data));
bootstrapEnv(1);
try {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeGoAway(ctx(), 0x7FFFFFFF, 0xFFFFFFFFL, data.retain(), newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), any(ByteBuf.class));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
} finally {
data.release();
}
}
@Test
public void pingFrameShouldMatch() throws Exception {
String text = "01234567";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[1]).toString(UTF_8));
return null;
}
}).when(serverListener).onPingAckRead(any(ChannelHandlerContext.class), eq(data));
try {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePing(ctx(), true, data.retain(), newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPingAckRead(any(ChannelHandlerContext.class), any(ByteBuf.class));
assertEquals(1, receivedBuffers.size());
for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally {
data.release();
}
}
@Test
public void priorityFrameShouldMatch() throws Exception {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePriority(ctx(), 0x7FFFFFFF, 1, (short) 1, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPriorityRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(1), eq((short) 1), eq(true));
}
@Test
public void pushPromiseFrameShouldMatch() throws Exception {
final Http2Headers headers = headers();
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePushPromise(ctx(), 0x7FFFFFFF, 1, headers, 5, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPushPromiseRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(1), eq(headers), eq(5));
}
@Test
public void rstStreamFrameShouldMatch() throws Exception {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeRstStream(ctx(), 0x7FFFFFFF, 0xFFFFFFFFL, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL));
}
@Test
public void settingsFrameShouldMatch() throws Exception {
bootstrapEnv(1);
final Http2Settings settings = new Http2Settings();
settings.initialWindowSize(10);
settings.maxConcurrentStreams(1000);
settings.headerTableSize(4096);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeSettings(ctx(), settings, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onSettingsRead(any(ChannelHandlerContext.class), eq(settings));
}
@Test
public void windowUpdateFrameShouldMatch() throws Exception {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeWindowUpdate(ctx(), 0x7FFFFFFF, 0x7FFFFFFF, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onWindowUpdateRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0x7FFFFFFF));
}
@Test
public void stressTest() throws Exception {
final Http2Headers headers = headers();
final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final int numStreams = 10000;
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>(numStreams));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data), eq(0), eq(true));
try {
final int expectedFrames = numStreams * 2;
bootstrapEnv(expectedFrames);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
for (int i = 1; i < numStreams + 1; ++i) {
frameWriter.writeHeaders(ctx(), i, headers, 0, (short) 16, false, 0, false, newPromise());
frameWriter.writeData(ctx(), i, data.retain(), 0, true, newPromise());
ctx().flush();
}
}
});
awaitRequests(30);
verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class), anyInt(),
any(ByteBuf.class), eq(0), eq(true));
assertEquals(numStreams, receivedBuffers.size());
for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally {
data.release();
}
}
private void awaitRequests(long seconds) throws InterruptedException {
assertTrue(requestLatch.await(seconds, SECONDS));
}
private void awaitRequests() throws InterruptedException {
awaitRequests(5);
}
private void bootstrapEnv(int requestCountDown) throws Exception {
requestLatch = new CountDownLatch(requestCountDown);
frameWriter = new DefaultHttp2FrameWriter();
sb = new ServerBootstrap();
@ -117,270 +386,6 @@ public class Http2FrameRoundtripTest {
clientChannel = ccf.channel();
}
@After
public void teardown() throws Exception {
serverChannel.close().sync();
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
serverGroup.sync();
serverChildGroup.sync();
clientGroup.sync();
serverAdapter = null;
}
@Test
public void dataFrameShouldMatch() throws Exception {
final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
try {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeData(ctx(), 0x7FFFFFFF, data.slice().retain(), 100, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
} finally {
data.release();
}
}
@Test
public void headersFrameWithoutPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeHeaders(ctx(), 0x7FFFFFFF, headers, 0, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(headers), eq(0), eq(true));
}
@Test
public void headersFrameWithPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeHeaders(ctx(), 0x7FFFFFFF, headers, 4, (short) 255,
true, 0, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(headers), eq(4), eq((short) 255), eq(true), eq(0), eq(true));
}
@Test
public void goAwayFrameShouldMatch() throws Exception {
final String text = "test";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[3]).toString(UTF_8));
return null;
}
}).when(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), eq(data));
try {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeGoAway(ctx(), 0x7FFFFFFF, 0xFFFFFFFFL, data.retain(), newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), any(ByteBuf.class));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
} finally {
data.release();
}
}
@Test
public void pingFrameShouldMatch() throws Exception {
String text = "01234567";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[1]).toString(UTF_8));
return null;
}
}).when(serverListener).onPingAckRead(any(ChannelHandlerContext.class), eq(data));
try {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePing(ctx(), true, data.retain(), newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPingAckRead(any(ChannelHandlerContext.class), any(ByteBuf.class));
assertEquals(1, receivedBuffers.size());
for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally {
data.release();
}
}
@Test
public void priorityFrameShouldMatch() throws Exception {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePriority(ctx(), 0x7FFFFFFF, 1, (short) 1, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPriorityRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(1), eq((short) 1), eq(true));
}
@Test
public void pushPromiseFrameShouldMatch() throws Exception {
final Http2Headers headers = headers();
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePushPromise(ctx(), 0x7FFFFFFF, 1, headers, 5, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPushPromiseRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(1), eq(headers), eq(5));
}
@Test
public void rstStreamFrameShouldMatch() throws Exception {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeRstStream(ctx(), 0x7FFFFFFF, 0xFFFFFFFFL, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL));
}
@Test
public void settingsFrameShouldMatch() throws Exception {
final Http2Settings settings = new Http2Settings();
settings.initialWindowSize(10);
settings.maxConcurrentStreams(1000);
settings.headerTableSize(4096);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeSettings(ctx(), settings, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onSettingsRead(any(ChannelHandlerContext.class), eq(settings));
}
@Test
public void windowUpdateFrameShouldMatch() throws Exception {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeWindowUpdate(ctx(), 0x7FFFFFFF, 0x7FFFFFFF, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onWindowUpdateRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0x7FFFFFFF));
}
@Test
public void stressTest() throws Exception {
final Http2Headers headers = headers();
final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final int numStreams = 10000;
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>(numStreams));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data), eq(0), eq(true));
try {
final int expectedFrames = numStreams * 2;
serverLatch(new CountDownLatch(expectedFrames));
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
for (int i = 1; i < numStreams + 1; ++i) {
frameWriter.writeHeaders(ctx(), i, headers, 0, (short) 16, false, 0, false, newPromise());
frameWriter.writeData(ctx(), i, data.retain(), 0, true, newPromise());
ctx().flush();
}
}
});
awaitRequests(30);
verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class), anyInt(),
any(ByteBuf.class), eq(0), eq(true));
assertEquals(numStreams, receivedBuffers.size());
for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally {
data.release();
}
}
private void awaitRequests(long seconds) throws InterruptedException {
assertTrue(requestLatch.await(seconds, SECONDS));
}
private void awaitRequests() throws InterruptedException {
awaitRequests(5);
}
private void serverLatch(CountDownLatch latch) {
requestLatch = latch;
if (serverAdapter != null) {
serverAdapter.latch(latch);
}
}
private ChannelHandlerContext ctx() {
return clientChannel.pipeline().firstContext();
}

View File

@ -88,7 +88,7 @@ final class Http2TestUtil {
private final Http2Connection connection;
private final Http2FrameListener listener;
private final DefaultHttp2FrameReader reader;
private volatile CountDownLatch latch;
private final CountDownLatch latch;
FrameAdapter(Http2FrameListener listener, CountDownLatch latch) {
this(null, listener, latch);
@ -103,10 +103,6 @@ final class Http2TestUtil {
this.connection = connection;
this.listener = listener;
this.reader = reader;
latch(latch);
}
public void latch(CountDownLatch latch) {
this.latch = latch;
}
@ -264,8 +260,8 @@ final class Http2TestUtil {
*/
static class FrameCountDown implements Http2FrameListener {
private final Http2FrameListener listener;
private volatile CountDownLatch messageLatch;
private volatile CountDownLatch dataLatch;
private final CountDownLatch messageLatch;
private final CountDownLatch dataLatch;
public FrameCountDown(Http2FrameListener listener, CountDownLatch messageLatch) {
this(listener, messageLatch, null);
@ -277,14 +273,6 @@ final class Http2TestUtil {
this.dataLatch = dataLatch;
}
public void messageLatch(CountDownLatch latch) {
messageLatch = latch;
}
public void dataLatch(CountDownLatch latch) {
dataLatch = latch;
}
@Override
public void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
throws Http2Exception {