Updating HTTP/2 tests to not retain buffers for validation

Motivation:

To eliminate the tests as being a cause of leaks, removing the automatic
retaining of ByteBufs in Http2TestUtil.

Modifications:

Each test that relied on retaining buffers for validation has been
modified to copy the buffer into a list of Strings that are manually
validated after the message is received.

Result:

The HTTP/2 tests should (hopefully) no longer be reporting erroneous
leaks due to the testing code, itself.
This commit is contained in:
nmittler 2014-09-22 09:38:34 -07:00
parent 3807f9fc8e
commit 0790678c72
6 changed files with 146 additions and 97 deletions

View File

@ -108,7 +108,7 @@ public class DataCompressionHttp2Test {
ChannelPipeline p = ch.pipeline(); ChannelPipeline p = ch.pipeline();
serverAdapter = new Http2TestUtil.FrameAdapter(serverConnection, serverAdapter = new Http2TestUtil.FrameAdapter(serverConnection,
new DelegatingDecompressorFrameListener(serverConnection, serverListener), new DelegatingDecompressorFrameListener(serverConnection, serverListener),
serverLatch, false); serverLatch);
p.addLast("reader", serverAdapter); p.addLast("reader", serverAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler()); p.addLast(Http2CodecUtil.ignoreSettingsHandler());
serverConnectedChannel = ch; serverConnectedChannel = ch;
@ -121,7 +121,7 @@ public class DataCompressionHttp2Test {
@Override @Override
protected void initChannel(Channel ch) throws Exception { protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline(); ChannelPipeline p = ch.pipeline();
clientAdapter = new Http2TestUtil.FrameAdapter(clientListener, clientLatch, false); clientAdapter = new Http2TestUtil.FrameAdapter(clientListener, clientLatch);
p.addLast("reader", clientAdapter); p.addLast("reader", clientAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler()); p.addLast(Http2CodecUtil.ignoreSettingsHandler());
} }

View File

@ -28,6 +28,7 @@ import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyShort; import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
@ -51,6 +52,7 @@ import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -58,9 +60,10 @@ import java.util.concurrent.TimeUnit;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** /**
* Testing the {@link DelegatingHttp2HttpConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames * Testing the {@link DelegatingHttp2HttpConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames
@ -68,8 +71,6 @@ import org.mockito.MockitoAnnotations;
public class DelegatingHttp2HttpConnectionHandlerTest { public class DelegatingHttp2HttpConnectionHandlerTest {
private static final int CONNECTION_SETUP_READ_COUNT = 2; private static final int CONNECTION_SETUP_READ_COUNT = 2;
private List<ByteBuf> capturedData;
@Mock @Mock
private Http2FrameListener clientListener; private Http2FrameListener clientListener;
@ -125,12 +126,6 @@ public class DelegatingHttp2HttpConnectionHandlerTest {
@After @After
public void teardown() throws Exception { public void teardown() throws Exception {
if (capturedData != null) {
for (int i = 0; i < capturedData.size(); ++i) {
capturedData.get(i).release();
}
capturedData = null;
}
serverChannel.close().sync(); serverChannel.close().sync();
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
@ -180,6 +175,15 @@ public class DelegatingHttp2HttpConnectionHandlerTest {
requestLatch(new CountDownLatch(CONNECTION_SETUP_READ_COUNT + 2)); requestLatch(new CountDownLatch(CONNECTION_SETUP_READ_COUNT + 2));
final String text = "foooooogoooo"; final String text = "foooooogoooo";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8); final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = new ArrayList<String>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), eq(true));
try { try {
final HttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example", data.retain()); final HttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example", data.retain());
final HttpHeaders httpHeaders = request.headers(); final HttpHeaders httpHeaders = request.headers();
@ -200,13 +204,12 @@ public class DelegatingHttp2HttpConnectionHandlerTest {
writeFuture.awaitUninterruptibly(2, SECONDS); writeFuture.awaitUninterruptibly(2, SECONDS);
assertTrue(writeFuture.isSuccess()); assertTrue(writeFuture.isSuccess());
awaitRequests(); awaitRequests();
final ArgumentCaptor<ByteBuf> dataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0), verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0),
anyShort(), anyBoolean(), eq(0), eq(false)); anyShort(), anyBoolean(), eq(0), eq(false));
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), dataCaptor.capture(), eq(0), verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
eq(true)); eq(true));
capturedData = dataCaptor.getAllValues(); assertEquals(1, receivedBuffers.size());
assertEquals(data, capturedData.get(0)); assertEquals(text, receivedBuffers.get(0));
} finally { } finally {
data.release(); data.release();
} }

View File

@ -20,11 +20,13 @@ import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
@ -41,11 +43,12 @@ import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable; import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil; import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import java.io.ByteArrayOutputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -54,9 +57,11 @@ import java.util.concurrent.TimeUnit;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** /**
* Tests the full HTTP/2 framing stack including the connection and preface handlers. * Tests the full HTTP/2 framing stack including the connection and preface handlers.
@ -144,7 +149,16 @@ public class Http2ConnectionRoundtripTest {
final byte[] bytes = new byte[length]; final byte[] bytes = new byte[length];
new Random().nextBytes(bytes); new Random().nextBytes(bytes);
final ByteBuf data = Unpooled.wrappedBuffer(bytes); final ByteBuf data = Unpooled.wrappedBuffer(bytes);
List<ByteBuf> capturedData = null; final ByteArrayOutputStream out = new ByteArrayOutputStream(length);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
ByteBuf buf = (ByteBuf) in.getArguments()[2];
buf.readBytes(out, buf.readableBytes());
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), Mockito.anyBoolean());
try { try {
// Initialize the data latch based on the number of bytes expected. // Initialize the data latch based on the number of bytes expected.
requestLatch(new CountDownLatch(2)); requestLatch(new CountDownLatch(2));
@ -163,18 +177,18 @@ public class Http2ConnectionRoundtripTest {
assertTrue(dataLatch.await(5, TimeUnit.SECONDS)); assertTrue(dataLatch.await(5, TimeUnit.SECONDS));
// Verify that headers were received and only one DATA frame was received with endStream set. // Verify that headers were received and only one DATA frame was received with endStream set.
final ArgumentCaptor<ByteBuf> dataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0),
eq((short) 16), eq(false), eq(0), eq(false)); eq((short) 16), eq(false), eq(0), eq(false));
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), dataCaptor.capture(), eq(0), verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
eq(true)); eq(true));
// Verify we received all the bytes. // Verify we received all the bytes.
capturedData = dataCaptor.getAllValues(); out.flush();
assertEquals(data, capturedData.get(0)); byte[] received = out.toByteArray();
assertArrayEquals(bytes, received);
} finally { } finally {
data.release(); data.release();
release(capturedData); out.close();
} }
} }
@ -183,10 +197,25 @@ public class Http2ConnectionRoundtripTest {
final Http2Headers headers = dummyHeaders(); final Http2Headers headers = dummyHeaders();
final String text = "hello world"; final String text = "hello world";
final String pingMsg = "12345678"; final String pingMsg = "12345678";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final ByteBuf pingData = Unpooled.copiedBuffer(pingMsg.getBytes()); final ByteBuf pingData = Unpooled.copiedBuffer(pingMsg, UTF_8);
List<ByteBuf> capturedData = null; final List<String> receivedPingBuffers = new ArrayList<String>(NUM_STREAMS);
List<ByteBuf> capturedPingData = null; doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedPingBuffers.add(((ByteBuf) in.getArguments()[1]).toString(UTF_8));
return null;
}
}).when(serverListener).onPingRead(any(ChannelHandlerContext.class), eq(pingData));
final List<String> receivedDataBuffers = new ArrayList<String>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedDataBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data),
eq(0), eq(true));
try { try {
runInChannel(clientChannel, new Http2Runnable() { runInChannel(clientChannel, new Http2Runnable() {
@Override @Override
@ -194,8 +223,8 @@ public class Http2ConnectionRoundtripTest {
for (int i = 0, nextStream = 3; i < NUM_STREAMS; ++i, nextStream += 2) { for (int i = 0, nextStream = 3; i < NUM_STREAMS; ++i, nextStream += 2) {
http2Client.writeHeaders(ctx(), nextStream, headers, 0, (short) 16, false, 0, false, http2Client.writeHeaders(ctx(), nextStream, headers, 0, (short) 16, false, 0, false,
newPromise()); newPromise());
http2Client.writePing(ctx(), pingData.retain(), newPromise()); http2Client.writePing(ctx(), pingData.slice().retain(), newPromise());
http2Client.writeData(ctx(), nextStream, data.retain(), 0, true, newPromise()); http2Client.writeData(ctx(), nextStream, data.slice().retain(), 0, true, newPromise());
} }
} }
}); });
@ -203,28 +232,21 @@ public class Http2ConnectionRoundtripTest {
assertTrue(requestLatch.await(STRESS_TIMEOUT_SECONDS, SECONDS)); assertTrue(requestLatch.await(STRESS_TIMEOUT_SECONDS, SECONDS));
verify(serverListener, times(NUM_STREAMS)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), verify(serverListener, times(NUM_STREAMS)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false)); eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false));
final ArgumentCaptor<ByteBuf> dataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
final ArgumentCaptor<ByteBuf> pingDataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
verify(serverListener, times(NUM_STREAMS)).onPingRead(any(ChannelHandlerContext.class), verify(serverListener, times(NUM_STREAMS)).onPingRead(any(ChannelHandlerContext.class),
pingDataCaptor.capture()); any(ByteBuf.class));
capturedPingData = pingDataCaptor.getAllValues(); verify(serverListener, times(NUM_STREAMS)).onDataRead(any(ChannelHandlerContext.class),
verify(serverListener, times(NUM_STREAMS)).onDataRead(any(ChannelHandlerContext.class), anyInt(), anyInt(), any(ByteBuf.class), eq(0), eq(true));
dataCaptor.capture(), eq(0), eq(true)); assertEquals(NUM_STREAMS, receivedPingBuffers.size());
capturedData = dataCaptor.getAllValues(); assertEquals(NUM_STREAMS, receivedDataBuffers.size());
data.resetReaderIndex(); for (String receivedData : receivedDataBuffers) {
pingData.resetReaderIndex(); assertEquals(text, receivedData);
int i;
for (i = 0; i < capturedPingData.size(); ++i) {
assertEquals(pingData, capturedPingData.get(i));
} }
for (i = 0; i < capturedData.size(); ++i) { for (String receivedPing : receivedPingBuffers) {
assertEquals(capturedData.get(i).toString(CharsetUtil.UTF_8), data, capturedData.get(i)); assertEquals(pingMsg, receivedPing);
} }
} finally { } finally {
data.release(); data.release();
pingData.release(); pingData.release();
release(capturedData);
release(capturedPingData);
} }
} }
@ -250,14 +272,6 @@ public class Http2ConnectionRoundtripTest {
return ctx().newPromise(); return ctx().newPromise();
} }
private static void release(List<ByteBuf> capturedData) {
if (capturedData != null) {
for (int i = 0; i < capturedData.size(); ++i) {
capturedData.get(i).release();
}
}
}
private Http2Headers dummyHeaders() { private Http2Headers dummyHeaders() {
return new DefaultHttp2Headers().method(as("GET")).scheme(as("https")) return new DefaultHttp2Headers().method(as("GET")).scheme(as("https"))
.authority(as("example.org")).path(as("/some/path/resource2")).add(randomString(), randomString()); .authority(as("example.org")).path(as("/some/path/resource2")).add(randomString(), randomString());

View File

@ -25,6 +25,7 @@ import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
@ -45,6 +46,7 @@ import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -52,9 +54,10 @@ import java.util.concurrent.TimeUnit;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** /**
* Tests encoding/decoding each HTTP2 frame type. * Tests encoding/decoding each HTTP2 frame type.
@ -64,7 +67,6 @@ public class Http2FrameRoundtripTest {
@Mock @Mock
private Http2FrameListener serverListener; private Http2FrameListener serverListener;
private ArgumentCaptor<ByteBuf> dataCaptor;
private Http2FrameWriter frameWriter; private Http2FrameWriter frameWriter;
private ServerBootstrap sb; private ServerBootstrap sb;
private Bootstrap cb; private Bootstrap cb;
@ -79,7 +81,6 @@ public class Http2FrameRoundtripTest {
serverLatch(new CountDownLatch(1)); serverLatch(new CountDownLatch(1));
frameWriter = new DefaultHttp2FrameWriter(); frameWriter = new DefaultHttp2FrameWriter();
dataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
sb = new ServerBootstrap(); sb = new ServerBootstrap();
cb = new Bootstrap(); cb = new Bootstrap();
@ -90,7 +91,7 @@ public class Http2FrameRoundtripTest {
@Override @Override
protected void initChannel(Channel ch) throws Exception { protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline(); ChannelPipeline p = ch.pipeline();
serverAdapter = new Http2TestUtil.FrameAdapter(serverListener, requestLatch, true); serverAdapter = new Http2TestUtil.FrameAdapter(serverListener, requestLatch);
p.addLast("reader", serverAdapter); p.addLast("reader", serverAdapter);
p.addLast(Http2CodecUtil.ignoreSettingsHandler()); p.addLast(Http2CodecUtil.ignoreSettingsHandler());
} }
@ -102,7 +103,7 @@ public class Http2FrameRoundtripTest {
@Override @Override
protected void initChannel(Channel ch) throws Exception { protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline(); ChannelPipeline p = ch.pipeline();
p.addLast("reader", new Http2TestUtil.FrameAdapter(null, null, true)); p.addLast("reader", new Http2TestUtil.FrameAdapter(null, null));
p.addLast(Http2CodecUtil.ignoreSettingsHandler()); p.addLast(Http2CodecUtil.ignoreSettingsHandler());
} }
}); });
@ -117,10 +118,6 @@ public class Http2FrameRoundtripTest {
@After @After
public void teardown() throws Exception { public void teardown() throws Exception {
List<ByteBuf> capturedData = dataCaptor.getAllValues();
for (int i = 0; i < capturedData.size(); ++i) {
capturedData.get(i).release();
}
serverChannel.close().sync(); serverChannel.close().sync();
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
@ -134,20 +131,29 @@ public class Http2FrameRoundtripTest {
@Test @Test
public void dataFrameShouldMatch() throws Exception { public void dataFrameShouldMatch() throws Exception {
final String text = "hello world"; final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = new ArrayList<String>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
try { try {
runInChannel(clientChannel, new Http2Runnable() { runInChannel(clientChannel, new Http2Runnable() {
@Override @Override
public void run() { public void run() {
frameWriter.writeData(ctx(), 0x7FFFFFFF, data.retain(), 100, true, newPromise()); frameWriter.writeData(ctx(), 0x7FFFFFFF, data.slice().retain(), 100, true, newPromise());
ctx().flush(); ctx().flush();
} }
}); });
awaitRequests(); awaitRequests();
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF), verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
dataCaptor.capture(), eq(100), eq(true)); any(ByteBuf.class), eq(100), eq(true));
List<ByteBuf> capturedData = dataCaptor.getAllValues(); assertEquals(1, receivedBuffers.size());
assertEquals(data, capturedData.get(0)); assertEquals(text, receivedBuffers.get(0));
} finally { } finally {
data.release(); data.release();
} }
@ -188,6 +194,15 @@ public class Http2FrameRoundtripTest {
public void goAwayFrameShouldMatch() throws Exception { public void goAwayFrameShouldMatch() throws Exception {
final String text = "test"; final String text = "test";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final List<String> receivedBuffers = new ArrayList<String>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[3]).toString(UTF_8));
return null;
}
}).when(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), eq(data));
try { try {
runInChannel(clientChannel, new Http2Runnable() { runInChannel(clientChannel, new Http2Runnable() {
@Override @Override
@ -198,9 +213,9 @@ public class Http2FrameRoundtripTest {
}); });
awaitRequests(); awaitRequests();
verify(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF), verify(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), dataCaptor.capture()); eq(0xFFFFFFFFL), any(ByteBuf.class));
List<ByteBuf> capturedData = dataCaptor.getAllValues(); assertEquals(1, receivedBuffers.size());
assertEquals(data, capturedData.get(0)); assertEquals(text, receivedBuffers.get(0));
} finally { } finally {
data.release(); data.release();
} }
@ -208,7 +223,16 @@ public class Http2FrameRoundtripTest {
@Test @Test
public void pingFrameShouldMatch() throws Exception { public void pingFrameShouldMatch() throws Exception {
final ByteBuf data = Unpooled.copiedBuffer("01234567", UTF_8); String text = "01234567";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = new ArrayList<String>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[1]).toString(UTF_8));
return null;
}
}).when(serverListener).onPingAckRead(any(ChannelHandlerContext.class), eq(data));
try { try {
runInChannel(clientChannel, new Http2Runnable() { runInChannel(clientChannel, new Http2Runnable() {
@Override @Override
@ -218,9 +242,11 @@ public class Http2FrameRoundtripTest {
} }
}); });
awaitRequests(); awaitRequests();
verify(serverListener).onPingAckRead(any(ChannelHandlerContext.class), dataCaptor.capture()); verify(serverListener).onPingAckRead(any(ChannelHandlerContext.class), any(ByteBuf.class));
List<ByteBuf> capturedData = dataCaptor.getAllValues(); assertEquals(1, receivedBuffers.size());
assertEquals(data, capturedData.get(0)); for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally { } finally {
data.release(); data.release();
} }
@ -305,8 +331,16 @@ public class Http2FrameRoundtripTest {
final Http2Headers headers = headers(); final Http2Headers headers = headers();
final String text = "hello world"; final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
try {
final int numStreams = 10000; final int numStreams = 10000;
final List<String> receivedBuffers = new ArrayList<String>(numStreams);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data), eq(0), eq(true));
try {
final int expectedFrames = numStreams * 2; final int expectedFrames = numStreams * 2;
serverLatch(new CountDownLatch(expectedFrames)); serverLatch(new CountDownLatch(expectedFrames));
runInChannel(clientChannel, new Http2Runnable() { runInChannel(clientChannel, new Http2Runnable() {
@ -321,10 +355,10 @@ public class Http2FrameRoundtripTest {
}); });
awaitRequests(30); awaitRequests(30);
verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class), anyInt(), verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class), anyInt(),
dataCaptor.capture(), eq(0), eq(true)); any(ByteBuf.class), eq(0), eq(true));
List<ByteBuf> capturedData = dataCaptor.getAllValues(); assertEquals(numStreams, receivedBuffers.size());
for (int i = 0; i < capturedData.size(); ++i) { for (String receivedData : receivedBuffers) {
assertEquals(data, capturedData.get(i)); assertEquals(text, receivedData);
} }
} finally { } finally {
data.release(); data.release();

View File

@ -85,27 +85,24 @@ final class Http2TestUtil {
} }
static class FrameAdapter extends ByteToMessageDecoder { static class FrameAdapter extends ByteToMessageDecoder {
private final boolean retainBufs;
private final Http2Connection connection; private final Http2Connection connection;
private final Http2FrameListener listener; private final Http2FrameListener listener;
private final DefaultHttp2FrameReader reader; private final DefaultHttp2FrameReader reader;
private CountDownLatch latch; private CountDownLatch latch;
FrameAdapter(Http2FrameListener listener, CountDownLatch latch, boolean retainBufs) { FrameAdapter(Http2FrameListener listener, CountDownLatch latch) {
this(null, listener, latch, retainBufs); this(null, listener, latch);
} }
FrameAdapter(Http2Connection connection, Http2FrameListener listener, CountDownLatch latch, FrameAdapter(Http2Connection connection, Http2FrameListener listener, CountDownLatch latch) {
boolean retainBufs) { this(connection, new DefaultHttp2FrameReader(), listener, latch);
this(connection, new DefaultHttp2FrameReader(), listener, latch, retainBufs);
} }
FrameAdapter(Http2Connection connection, DefaultHttp2FrameReader reader, Http2FrameListener listener, FrameAdapter(Http2Connection connection, DefaultHttp2FrameReader reader, Http2FrameListener listener,
CountDownLatch latch, boolean retainBufs) { CountDownLatch latch) {
this.connection = connection; this.connection = connection;
this.listener = listener; this.listener = listener;
this.reader = reader; this.reader = reader;
this.retainBufs = retainBufs;
latch(latch); latch(latch);
} }
@ -150,7 +147,7 @@ final class Http2TestUtil {
public void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, public void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding,
boolean endOfStream) throws Http2Exception { boolean endOfStream) throws Http2Exception {
Http2Stream stream = getOrCreateStream(streamId, endOfStream); Http2Stream stream = getOrCreateStream(streamId, endOfStream);
listener.onDataRead(ctx, streamId, retainBufs ? data.retain() : data, padding, endOfStream); listener.onDataRead(ctx, streamId, data, padding, endOfStream);
if (endOfStream) { if (endOfStream) {
closeStream(stream, true); closeStream(stream, true);
} }
@ -218,13 +215,13 @@ final class Http2TestUtil {
@Override @Override
public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
listener.onPingRead(ctx, retainBufs ? data.retain() : data); listener.onPingRead(ctx, data);
latch.countDown(); latch.countDown();
} }
@Override @Override
public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
listener.onPingAckRead(ctx, retainBufs ? data.retain() : data); listener.onPingAckRead(ctx, data);
latch.countDown(); latch.countDown();
} }
@ -239,7 +236,7 @@ final class Http2TestUtil {
@Override @Override
public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception { throws Http2Exception {
listener.onGoAwayRead(ctx, lastStreamId, errorCode, retainBufs ? debugData.retain() : debugData); listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
latch.countDown(); latch.countDown();
} }
@ -291,10 +288,11 @@ final class Http2TestUtil {
@Override @Override
public void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) public void onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
throws Http2Exception { throws Http2Exception {
listener.onDataRead(ctx, streamId, data.retain(), padding, endOfStream); int numBytes = data.readableBytes();
listener.onDataRead(ctx, streamId, data, padding, endOfStream);
messageLatch.countDown(); messageLatch.countDown();
if (dataLatch != null) { if (dataLatch != null) {
for (int i = 0; i < data.readableBytes(); ++i) { for (int i = 0; i < numBytes; ++i) {
dataLatch.countDown(); dataLatch.countDown();
} }
} }
@ -341,13 +339,13 @@ final class Http2TestUtil {
@Override @Override
public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
listener.onPingRead(ctx, data.retain()); listener.onPingRead(ctx, data);
messageLatch.countDown(); messageLatch.countDown();
} }
@Override @Override
public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception { public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
listener.onPingAckRead(ctx, data.retain()); listener.onPingAckRead(ctx, data);
messageLatch.countDown(); messageLatch.countDown();
} }
@ -361,7 +359,7 @@ final class Http2TestUtil {
@Override @Override
public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception { throws Http2Exception {
listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData.retain()); listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
messageLatch.countDown(); messageLatch.countDown();
} }

View File

@ -720,7 +720,7 @@ public class InboundHttp2ToHttpAdapterTest {
private final class HttpAdapterFrameAdapter extends Http2TestUtil.FrameAdapter { private final class HttpAdapterFrameAdapter extends Http2TestUtil.FrameAdapter {
HttpAdapterFrameAdapter(Http2Connection connection, Http2FrameListener listener, CountDownLatch latch) { HttpAdapterFrameAdapter(Http2Connection connection, Http2FrameListener listener, CountDownLatch latch) {
super(connection, listener, latch, false); super(connection, listener, latch);
} }
@Override @Override