Allow HTTP2 frame writer to accept arbitrarily large frames

Motivation:

The encoder is currently responsible for chunking frames when writing in order to conform to max frame size. The frame writer would be a better place for this since it could perform a reuse the same promise aggregator for all the write and could also perform a single allocation for all of the frame headers.

Modifications:

Modified `DefaultHttp2FrameWriter` to perform the chunking and modified the contract in the `Http2FrameWriter` interface. Modified `DefaultHttp2ConnectionEncoder` to send give all allocated bytes to the writer.

Result:

Fixes #3966
This commit is contained in:
nmittler 2015-11-18 20:46:11 -08:00
parent 9df16a6c54
commit e4b716a9da
8 changed files with 638 additions and 927 deletions

View File

@ -18,6 +18,8 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGH
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static java.lang.Math.min;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
@ -83,12 +85,12 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
Long maxConcurrentStreams = settings.maxConcurrentStreams();
if (maxConcurrentStreams != null) {
connection.local().maxActiveStreams((int) Math.min(maxConcurrentStreams, Integer.MAX_VALUE));
connection.local().maxActiveStreams((int) min(maxConcurrentStreams, Integer.MAX_VALUE));
}
Long headerTableSize = settings.headerTableSize();
if (headerTableSize != null) {
outboundHeaderTable.maxHeaderTableSize((int) Math.min(headerTableSize, Integer.MAX_VALUE));
outboundHeaderTable.maxHeaderTableSize((int) min(headerTableSize, Integer.MAX_VALUE));
}
Integer maxHeaderListSize = settings.maxHeaderListSize();
@ -333,24 +335,24 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
@Override
public void write(ChannelHandlerContext ctx, int allowedBytes) {
if (!endOfStream && (queue.readableBytes() == 0 || allowedBytes == 0)) {
int queuedData = queue.readableBytes();
if (!endOfStream && (queuedData == 0 || allowedBytes == 0)) {
// Nothing to write and we don't have to force a write because of EOS.
return;
}
int maxFrameSize = frameWriter().configuration().frameSizePolicy().maxFrameSize();
do {
int allowedFrameSize = Math.min(maxFrameSize, allowedBytes);
int writeableData = Math.min(queue.readableBytes(), allowedFrameSize);
ChannelPromise writePromise = ctx.newPromise();
writePromise.addListener(this);
ByteBuf toWrite = queue.remove(writeableData, writePromise);
int writeablePadding = Math.min(allowedFrameSize - writeableData, padding);
padding -= writeablePadding;
allowedBytes -= writeableData + writeablePadding;
frameWriter().writeData(ctx, stream.id(), toWrite, writeablePadding,
endOfStream && size() == 0, writePromise);
} while (size() > 0 && allowedBytes > 0);
// Determine how much data to write.
int writeableData = min(queuedData, allowedBytes);
ChannelPromise writePromise = ctx.newPromise().addListener(this);
ByteBuf toWrite = queue.remove(writeableData, writePromise);
// Determine how much padding to write.
int writeablePadding = min(allowedBytes - writeableData, padding);
padding -= writeablePadding;
// Write the frame(s).
frameWriter().writeData(ctx, stream.id(), toWrite, writeablePadding,
endOfStream && size() == 0, writePromise);
}
@Override

View File

@ -78,7 +78,6 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
/**
* Create a new instance.
* @param validateHeaders {@code true} to validate headers. {@code false} to not validate headers.
* @see #DefaultHttp2HeadersDecoder(boolean)
*/
public DefaultHttp2FrameReader(boolean validateHeaders) {
this(new DefaultHttp2HeadersDecoder(validateHeaders));

View File

@ -15,13 +15,6 @@
package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.handler.codec.http2.Http2FrameWriter.Configuration;
import static io.netty.buffer.Unpooled.directBuffer;
import static io.netty.buffer.Unpooled.unmodifiableBuffer;
import static io.netty.buffer.Unpooled.unreleasableBuffer;
@ -60,6 +53,15 @@ import static io.netty.handler.codec.http2.Http2FrameTypes.RST_STREAM;
import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS;
import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static java.lang.Math.max;
import static java.lang.Math.min;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.handler.codec.http2.Http2FrameWriter.Configuration;
/**
* A {@link Http2FrameWriter} that supports all frame types defined by the HTTP/2 specification.
@ -121,33 +123,52 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
@Override
public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
int padding, boolean endStream, ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
final SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
final DataFrameHeader header = new DataFrameHeader(ctx, streamId);
boolean needToReleaseHeaders = true;
boolean needToReleaseData = true;
try {
verifyStreamId(streamId, STREAM_ID);
verifyPadding(padding);
Http2Flags flags = new Http2Flags().paddingPresent(padding > 0).endOfStream(endStream);
boolean lastFrame;
int remainingData = data.readableBytes();
do {
// Determine how much data and padding to write in this frame. Put all padding at the end.
int frameDataBytes = min(remainingData, maxFrameSize);
int framePaddingBytes = min(padding, max(0, (maxFrameSize - 1) - frameDataBytes));
int payloadLength = data.readableBytes() + padding + flags.getPaddingPresenceFieldLength();
verifyPayloadLength(payloadLength);
// Decrement the remaining counters.
padding -= framePaddingBytes;
remainingData -= frameDataBytes;
ByteBuf buf = ctx.alloc().buffer(DATA_FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, payloadLength, DATA, flags, streamId);
writePaddingLength(buf, padding);
ctx.write(buf, promiseAggregator.newPromise());
// Determine whether or not this is the last frame to be sent.
lastFrame = remainingData == 0 && padding == 0;
// Write the data.
releaseData = false;
ctx.write(data, promiseAggregator.newPromise());
// Only the last frame is not retained. Until then, the outer finally must release.
ByteBuf frameHeader = header.slice(frameDataBytes, framePaddingBytes, lastFrame && endStream);
needToReleaseHeaders = !lastFrame;
ctx.write(lastFrame ? frameHeader : frameHeader.retain(), promiseAggregator.newPromise());
// Write the frame data.
ByteBuf frameData = data.readSlice(frameDataBytes);
// Only the last frame is not retained. Until then, the outer finally must release.
needToReleaseData = !lastFrame;
ctx.write(lastFrame ? frameData : frameData.retain(), promiseAggregator.newPromise());
// Write the frame padding.
if (framePaddingBytes > 0) {
ctx.write(ZERO_BUFFER.slice(0, framePaddingBytes), promiseAggregator.newPromise());
}
} while (!lastFrame);
if (padding > 0) { // Write the required padding.
ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise());
}
return promiseAggregator.doneAllocatingPromises();
} catch (Throwable t) {
if (releaseData) {
if (needToReleaseHeaders) {
header.release();
}
if (needToReleaseData) {
data.release();
}
return promiseAggregator.setFailure(t);
@ -281,7 +302,7 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
int nonFragmentLength = INT_FIELD_LENGTH + padding + flags.getPaddingPresenceFieldLength();
int maxFragmentLength = maxFrameSize - nonFragmentLength;
ByteBuf fragment =
headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)).retain();
headerBlock.readSlice(min(headerBlock.readableBytes(), maxFragmentLength)).retain();
flags.endOfHeaders(!headerBlock.isReadable());
@ -298,7 +319,7 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
ctx.write(fragment, promiseAggregator.newPromise());
if (padding > 0) { // Write out the padding, if any.
ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise());
ctx.write(ZERO_BUFFER.slice(0, padding), promiseAggregator.newPromise());
}
if (!flags.endOfHeaders()) {
@ -407,7 +428,7 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
int nonFragmentBytes = padding + flags.getNumPriorityBytes() + flags.getPaddingPresenceFieldLength();
int maxFragmentLength = maxFrameSize - nonFragmentBytes;
ByteBuf fragment =
headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)).retain();
headerBlock.readSlice(min(headerBlock.readableBytes(), maxFragmentLength)).retain();
// Set the end of headers flag for the first frame.
flags.endOfHeaders(!headerBlock.isReadable());
@ -430,7 +451,7 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
ctx.write(fragment, promiseAggregator.newPromise());
if (padding > 0) { // Write out the padding, if any.
ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise());
ctx.write(ZERO_BUFFER.slice(0, padding), promiseAggregator.newPromise());
}
if (!flags.endOfHeaders()) {
@ -463,15 +484,14 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
if (headerBlock.isReadable()) {
// The frame header (and padding) only changes on the last frame, so allocate it once and re-use
final ByteBuf paddingBuf = padding > 0 ? ZERO_BUFFER.slice(0, padding) : null;
int fragmentReadableBytes = Math.min(headerBlock.readableBytes(), maxFragmentLength);
int fragmentReadableBytes = min(headerBlock.readableBytes(), maxFragmentLength);
int payloadLength = fragmentReadableBytes + nonFragmentLength;
ByteBuf buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, payloadLength, CONTINUATION, flags, streamId);
writePaddingLength(buf, padding);
do {
fragmentReadableBytes = Math.min(headerBlock.readableBytes(), maxFragmentLength);
fragmentReadableBytes = min(headerBlock.readableBytes(), maxFragmentLength);
ByteBuf fragment = headerBlock.readSlice(fragmentReadableBytes).retain();
payloadLength = fragmentReadableBytes + nonFragmentLength;
@ -490,8 +510,8 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
ctx.write(fragment, promiseAggregator.newPromise());
// Write out the padding, if any.
if (paddingBuf != null) {
ctx.write(paddingBuf.retain(), promiseAggregator.newPromise());
if (padding > 0) {
ctx.write(ZERO_BUFFER.slice(0, padding), promiseAggregator.newPromise());
}
} while(headerBlock.isReadable());
}
@ -523,13 +543,6 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
}
}
private void verifyPayloadLength(int payloadLength) {
if (payloadLength > maxFrameSize) {
throw new IllegalArgumentException("Total payload length " + payloadLength
+ " exceeds max frame length.");
}
}
private static void verifyWeight(short weight) {
if (weight < MIN_WEIGHT || weight > MAX_WEIGHT) {
throw new IllegalArgumentException("Invalid weight: " + weight);
@ -553,4 +566,51 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
throw new IllegalArgumentException("Opaque data must be " + PING_FRAME_PAYLOAD_LENGTH + " bytes");
}
}
/**
* Utility class that manages the creation of frame header buffers for {@code DATA} frames. Attempts
* to reuse the same buffer repeatedly when splitting data into multiple frames.
*/
private static final class DataFrameHeader {
private final int streamId;
private final ByteBuf buffer;
private final Http2Flags flags = new Http2Flags();
private int prevData;
private int prevPadding;
private ByteBuf frameHeader;
DataFrameHeader(ChannelHandlerContext ctx, int streamId) {
// All padding will be put at the end, so in the worst case we need 3 headers:
// a repeated no-padding frame of maxFrameSize, a frame that has part data and part
// padding, and a frame that has the remainder of the padding.
buffer = ctx.alloc().buffer(3 * DATA_FRAME_HEADER_LENGTH);
this.streamId = streamId;
}
/**
* Gets the frame header buffer configured for the current frame.
*/
ByteBuf slice(int data, int padding, boolean endOfStream) {
// Since we're reusing the current frame header whenever possible, check if anything changed
// that requires a new header.
if (data != prevData || padding != prevPadding
|| endOfStream != flags.endOfStream() || frameHeader == null) {
// Update the header state.
prevData = data;
prevPadding = padding;
flags.paddingPresent(padding > 0);
flags.endOfStream(endOfStream);
frameHeader = buffer.readSlice(DATA_FRAME_HEADER_LENGTH).writerIndex(0);
int payloadLength = data + padding + flags.getPaddingPresenceFieldLength();
writeFrameHeaderInternal(frameHeader, payloadLength, DATA, flags, streamId);
writePaddingLength(frameHeader, padding);
}
return frameHeader.slice();
}
void release() {
buffer.release();
}
}
}

View File

@ -24,7 +24,8 @@ import io.netty.channel.ChannelPromise;
*/
public interface Http2DataWriter {
/**
* Writes a {@code DATA} frame to the remote endpoint.
* Writes a {@code DATA} frame to the remote endpoint. This will result in one or more
* frames being written to the context.
*
* @param ctx the context to use for writing.
* @param streamId the stream for which to send the frame.

View File

@ -123,9 +123,6 @@ public class DefaultHttp2ConnectionEncoderTest {
@Mock
private Http2FrameWriter.Configuration writerConfig;
@Mock
private Http2FrameSizePolicy frameSizePolicy;
@Mock
private Http2LifecycleManager lifecycleManager;
@ -166,8 +163,6 @@ public class DefaultHttp2ConnectionEncoderTest {
when(connection.remote()).thenReturn(remote);
when(remote.flowController()).thenReturn(remoteFlow);
when(writer.configuration()).thenReturn(writerConfig);
when(writerConfig.frameSizePolicy()).thenReturn(frameSizePolicy);
when(frameSizePolicy.maxFrameSize()).thenReturn(64);
when(local.createIdleStream(eq(STREAM_ID))).thenReturn(stream);
when(local.reservePushStream(eq(PUSH_STREAM_ID), eq(stream))).thenReturn(pushStream);
when(remote.createIdleStream(eq(STREAM_ID))).thenReturn(stream);
@ -286,81 +281,16 @@ public class DefaultHttp2ConnectionEncoderTest {
}
@Test
public void dataLargerThanMaxFrameSizeShouldBeSplit() throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(3);
final ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise);
assertEquals(payloadCaptor.getValue().size(), 8);
payloadCaptor.getValue().write(ctx, 8);
// writer was called 3 times
assertEquals(3, writtenData.size());
assertEquals("abc", writtenData.get(0));
assertEquals("def", writtenData.get(1));
assertEquals("gh", writtenData.get(2));
assertEquals(0, data.refCnt());
}
@Test
public void paddingSplitOverFrame() throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(5);
final ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 5, true, promise);
assertEquals(payloadCaptor.getValue().size(), 13);
payloadCaptor.getValue().write(ctx, 13);
// writer was called 3 times
assertEquals(3, writtenData.size());
assertEquals("abcde", writtenData.get(0));
assertEquals(0, (int) writtenPadding.get(0));
assertEquals("fgh", writtenData.get(1));
assertEquals(2, (int) writtenPadding.get(1));
assertEquals("", writtenData.get(2));
assertEquals(3, (int) writtenPadding.get(2));
assertEquals(0, data.refCnt());
}
@Test
public void frameShouldSplitPadding() throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(5);
ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 10, true, promise);
assertEquals(payloadCaptor.getValue().size(), 18);
payloadCaptor.getValue().write(ctx, 18);
// writer was called 4 times
assertEquals(4, writtenData.size());
assertEquals("abcde", writtenData.get(0));
assertEquals(0, (int) writtenPadding.get(0));
assertEquals("fgh", writtenData.get(1));
assertEquals(2, (int) writtenPadding.get(1));
assertEquals("", writtenData.get(2));
assertEquals(5, (int) writtenPadding.get(2));
assertEquals("", writtenData.get(3));
assertEquals(3, (int) writtenPadding.get(3));
assertEquals(0, data.refCnt());
}
@Test
public void emptyFrameShouldSplitPadding() throws Exception {
public void emptyFrameShouldWritePadding() throws Exception {
ByteBuf data = Unpooled.buffer(0);
assertSplitPaddingOnEmptyBuffer(data);
assertEquals(0, data.refCnt());
}
@Test
public void singletonEmptyBufferShouldSplitPadding() throws Exception {
assertSplitPaddingOnEmptyBuffer(Unpooled.EMPTY_BUFFER);
}
private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(5);
encoder.writeData(ctx, STREAM_ID, data, 10, true, promise);
assertEquals(payloadCaptor.getValue().size(), 10);
payloadCaptor.getValue().write(ctx, 10);
// writer was called 2 times
assertEquals(2, writtenData.size());
assertEquals(1, writtenData.size());
assertEquals("", writtenData.get(0));
assertEquals(5, (int) writtenPadding.get(0));
assertEquals("", writtenData.get(1));
assertEquals(5, (int) writtenPadding.get(1));
assertEquals(10, (int) writtenPadding.get(0));
assertEquals(0, data.refCnt());
}
@Test

View File

@ -1,390 +0,0 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.AsciiString;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.EventExecutor;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_HEADER_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Integration tests for {@link DefaultHttp2FrameReader} and {@link DefaultHttp2FrameWriter}.
*/
public class DefaultHttp2FrameIOTest {
private DefaultHttp2FrameReader reader;
private DefaultHttp2FrameWriter writer;
private ByteBufAllocator alloc;
private ByteBuf buffer;
private ByteBuf data;
@Mock
private ChannelHandlerContext ctx;
@Mock
private Http2FrameListener listener;
@Mock
private ChannelPromise promise;
@Mock
private ChannelPromise aggregatePromise;
@Mock
private Channel channel;
@Mock
private EventExecutor executor;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
alloc = UnpooledByteBufAllocator.DEFAULT;
buffer = alloc.buffer();
data = dummyData();
when(executor.inEventLoop()).thenReturn(true);
when(ctx.alloc()).thenReturn(alloc);
when(ctx.channel()).thenReturn(channel);
when(ctx.executor()).thenReturn(executor);
when(ctx.newPromise()).thenReturn(promise);
when(promise.isDone()).thenReturn(true);
when(promise.isSuccess()).thenReturn(true);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
ChannelFutureListener l = (ChannelFutureListener) in.getArguments()[0];
l.operationComplete(promise);
return null;
}
}).when(promise).addListener(any(ChannelFutureListener.class));
doAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock in) throws Throwable {
if (in.getArguments()[0] instanceof ByteBuf) {
ByteBuf tmp = (ByteBuf) in.getArguments()[0];
try {
buffer.writeBytes(tmp);
} finally {
tmp.release();
}
}
if (in.getArguments()[1] instanceof ChannelPromise) {
return ((ChannelPromise) in.getArguments()[1]).setSuccess();
}
return null;
}
}).when(ctx).write(any(), any(ChannelPromise.class));
reader = new DefaultHttp2FrameReader(false);
writer = new DefaultHttp2FrameWriter();
}
@After
public void tearDown() {
buffer.release();
data.release();
}
@Test
public void emptyDataShouldRoundtrip() throws Exception {
final ByteBuf data = Unpooled.EMPTY_BUFFER;
writer.writeData(ctx, 1000, data, 0, false, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onDataRead(eq(ctx), eq(1000), eq(data), eq(0), eq(false));
}
@Test
public void dataShouldRoundtrip() throws Exception {
writer.writeData(ctx, 1000, data.retain().duplicate(), 0, false, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onDataRead(eq(ctx), eq(1000), eq(data), eq(0), eq(false));
}
@Test
public void dataWithPaddingShouldRoundtrip() throws Exception {
writer.writeData(ctx, 1, data.retain().duplicate(), 0xFF, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onDataRead(eq(ctx), eq(1), eq(data), eq(0xFF), eq(true));
}
@Test
public void priorityShouldRoundtrip() throws Exception {
writer.writePriority(ctx, 1, 2, (short) 255, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPriorityRead(eq(ctx), eq(1), eq(2), eq((short) 255), eq(true));
}
@Test
public void rstStreamShouldRoundtrip() throws Exception {
writer.writeRstStream(ctx, 1, MAX_UNSIGNED_INT, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onRstStreamRead(eq(ctx), eq(1), eq(MAX_UNSIGNED_INT));
}
@Test
public void emptySettingsShouldRoundtrip() throws Exception {
writer.writeSettings(ctx, new Http2Settings(), promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onSettingsRead(eq(ctx), eq(new Http2Settings()));
}
@Test
public void settingsShouldStripShouldRoundtrip() throws Exception {
Http2Settings settings = new Http2Settings();
settings.pushEnabled(true);
settings.headerTableSize(4096);
settings.initialWindowSize(123);
settings.maxConcurrentStreams(456);
writer.writeSettings(ctx, settings, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onSettingsRead(eq(ctx), eq(settings));
}
@Test
public void settingsAckShouldRoundtrip() throws Exception {
writer.writeSettingsAck(ctx, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onSettingsAckRead(eq(ctx));
}
@Test
public void pingShouldRoundtrip() throws Exception {
writer.writePing(ctx, false, data.retain().duplicate(), promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPingRead(eq(ctx), eq(data));
}
@Test
public void pingAckShouldRoundtrip() throws Exception {
writer.writePing(ctx, true, data.retain().duplicate(), promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPingAckRead(eq(ctx), eq(data));
}
@Test
public void goAwayShouldRoundtrip() throws Exception {
writer.writeGoAway(ctx, 1, MAX_UNSIGNED_INT, data.retain().duplicate(), promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onGoAwayRead(eq(ctx), eq(1), eq(MAX_UNSIGNED_INT), eq(data));
}
@Test
public void windowUpdateShouldRoundtrip() throws Exception {
writer.writeWindowUpdate(ctx, 1, Integer.MAX_VALUE, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onWindowUpdateRead(eq(ctx), eq(1), eq(Integer.MAX_VALUE));
}
@Test
public void emptyHeadersShouldRoundtrip() throws Exception {
Http2Headers headers = EmptyHttp2Headers.INSTANCE;
writer.writeHeaders(ctx, 1, headers, 0, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(0), eq(true));
}
@Test
public void emptyHeadersWithPaddingShouldRoundtrip() throws Exception {
Http2Headers headers = EmptyHttp2Headers.INSTANCE;
writer.writeHeaders(ctx, 1, headers, 0xFF, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(0xFF), eq(true));
}
@Test
public void binaryHeadersWithoutPriorityShouldRoundtrip() throws Exception {
Http2Headers headers = dummyBinaryHeaders();
writer.writeHeaders(ctx, 1, headers, 0, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(0), eq(true));
}
@Test
public void headersWithoutPriorityShouldRoundtrip() throws Exception {
Http2Headers headers = dummyHeaders();
writer.writeHeaders(ctx, 1, headers, 0, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(0), eq(true));
}
@Test
public void headersWithPaddingWithoutPriorityShouldRoundtrip() throws Exception {
Http2Headers headers = dummyHeaders();
writer.writeHeaders(ctx, 1, headers, 0xFF, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(0xFF), eq(true));
}
@Test
public void headersWithPriorityShouldRoundtrip() throws Exception {
Http2Headers headers = dummyHeaders();
writer.writeHeaders(ctx, 1, headers, 2, (short) 3, true, 0, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener)
.onHeadersRead(eq(ctx), eq(1), eq(headers), eq(2), eq((short) 3), eq(true), eq(0), eq(true));
}
@Test
public void headersWithPaddingWithPriorityShouldRoundtrip() throws Exception {
Http2Headers headers = dummyHeaders();
writer.writeHeaders(ctx, 1, headers, 2, (short) 3, true, 0xFF, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(2), eq((short) 3), eq(true), eq(0xFF),
eq(true));
}
@Test
public void continuedHeadersShouldRoundtrip() throws Exception {
Http2Headers headers = largeHeaders();
writer.writeHeaders(ctx, 1, headers, 2, (short) 3, true, 0, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener)
.onHeadersRead(eq(ctx), eq(1), eq(headers), eq(2), eq((short) 3), eq(true), eq(0), eq(true));
}
@Test
public void continuedHeadersWithPaddingShouldRoundtrip() throws Exception {
Http2Headers headers = largeHeaders();
writer.writeHeaders(ctx, 1, headers, 2, (short) 3, true, 0xFF, true, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onHeadersRead(eq(ctx), eq(1), eq(headers), eq(2), eq((short) 3), eq(true), eq(0xFF),
eq(true));
}
@Test
public void headersThatAreTooBigShouldFail() throws Exception {
Http2Headers headers = headersOfSize(DEFAULT_MAX_HEADER_SIZE + 1);
writer.writeHeaders(ctx, 1, headers, 2, (short) 3, true, 0xFF, true, promise);
try {
reader.readFrame(ctx, buffer, listener);
fail();
} catch (Http2Exception e) {
verify(listener, never()).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(),
anyBoolean());
}
}
@Test
public void emptypushPromiseShouldRoundtrip() throws Exception {
Http2Headers headers = EmptyHttp2Headers.INSTANCE;
writer.writePushPromise(ctx, 1, 2, headers, 0, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPushPromiseRead(eq(ctx), eq(1), eq(2), eq(headers), eq(0));
}
@Test
public void pushPromiseShouldRoundtrip() throws Exception {
Http2Headers headers = dummyHeaders();
writer.writePushPromise(ctx, 1, 2, headers, 0, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPushPromiseRead(eq(ctx), eq(1), eq(2), eq(headers), eq(0));
}
@Test
public void pushPromiseWithPaddingShouldRoundtrip() throws Exception {
Http2Headers headers = dummyHeaders();
writer.writePushPromise(ctx, 1, 2, headers, 0xFF, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPushPromiseRead(eq(ctx), eq(1), eq(2), eq(headers), eq(0xFF));
}
@Test
public void continuedPushPromiseShouldRoundtrip() throws Exception {
Http2Headers headers = largeHeaders();
writer.writePushPromise(ctx, 1, 2, headers, 0, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPushPromiseRead(eq(ctx), eq(1), eq(2), eq(headers), eq(0));
}
@Test
public void continuedPushPromiseWithPaddingShouldRoundtrip() throws Exception {
Http2Headers headers = largeHeaders();
writer.writePushPromise(ctx, 1, 2, headers, 0xFF, promise);
reader.readFrame(ctx, buffer, listener);
verify(listener).onPushPromiseRead(eq(ctx), eq(1), eq(2), eq(headers), eq(0xFF));
}
private ByteBuf dummyData() {
return alloc.buffer().writeBytes("abcdefgh".getBytes(CharsetUtil.UTF_8));
}
private static Http2Headers dummyBinaryHeaders() {
DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
for (int ix = 0; ix < 10; ++ix) {
headers.add(randomString(), randomString());
}
return headers;
}
private static Http2Headers dummyHeaders() {
return new DefaultHttp2Headers(false).method(new AsciiString("GET")).scheme(new AsciiString("https"))
.authority(new AsciiString("example.org")).path(new AsciiString("/some/path"))
.add(new AsciiString("accept"), new AsciiString("*/*"));
}
private static Http2Headers largeHeaders() {
DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
for (int i = 0; i < 100; ++i) {
String key = "this-is-a-test-header-key-" + i;
String value = "this-is-a-test-header-value-" + i;
headers.add(new AsciiString(key), new AsciiString(value));
}
return headers;
}
private Http2Headers headersOfSize(final int minSize) {
final AsciiString singleByte = new AsciiString(new byte[]{0}, false);
DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
for (int size = 0; size < minSize; size += 2) {
headers.add(singleByte, singleByte);
}
return headers;
}
}

View File

@ -15,391 +15,474 @@
package io.netty.handler.codec.http2;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import static io.netty.buffer.Unpooled.EMPTY_BUFFER;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_HEADER_SIZE;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.lang.Math.min;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.EmptyByteBuf;
import io.netty.buffer.ReadOnlyByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.EventExecutor;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests encoding/decoding each HTTP2 frame type.
*/
public class Http2FrameRoundtripTest {
private static final byte[] MESSAGE = "hello world".getBytes(UTF_8);
private static final int STREAM_ID = 0x7FFFFFFF;
private static final int WINDOW_UPDATE = 0x7FFFFFFF;
private static final long ERROR_CODE = 0xFFFFFFFFL;
@Mock
private Http2FrameListener serverListener;
private Http2FrameListener listener;
private Http2FrameWriter frameWriter;
private ServerBootstrap sb;
private Bootstrap cb;
private Channel serverChannel;
private Channel clientChannel;
private CountDownLatch requestLatch;
private Http2TestUtil.FrameAdapter serverAdapter;
@Mock
private ChannelHandlerContext ctx;
@Mock
private EventExecutor executor;
@Mock
private Channel channel;
@Mock
private ByteBufAllocator alloc;
private Http2FrameWriter writer;
private Http2FrameReader reader;
private List<ByteBuf> needReleasing = new LinkedList<ByteBuf>();
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
when(ctx.alloc()).thenReturn(alloc);
when(ctx.executor()).thenReturn(executor);
when(ctx.channel()).thenReturn(channel);
doAnswer(new Answer<ByteBuf>() {
@Override
public ByteBuf answer(InvocationOnMock in) throws Throwable {
return Unpooled.buffer();
}
}).when(alloc).buffer();
doAnswer(new Answer<ByteBuf>() {
@Override
public ByteBuf answer(InvocationOnMock in) throws Throwable {
return Unpooled.buffer((Integer) in.getArguments()[0]);
}
}).when(alloc).buffer(anyInt());
doAnswer(new Answer<ChannelPromise>() {
@Override
public ChannelPromise answer(InvocationOnMock invocation) throws Throwable {
return new DefaultChannelPromise(channel);
}
}).when(ctx).newPromise();
writer = new DefaultHttp2FrameWriter();
reader = new DefaultHttp2FrameReader(false);
}
@After
public void teardown() throws Exception {
if (clientChannel != null) {
clientChannel.close().sync();
clientChannel = null;
public void teardown() {
try {
// Release all of the buffers.
for (ByteBuf buf : needReleasing) {
buf.release();
}
// Now verify that all of the reference counts are zero.
for (ByteBuf buf : needReleasing) {
int expectedFinalRefCount = 0;
if (buf instanceof ReadOnlyByteBuf || buf instanceof EmptyByteBuf) {
// Special case for when we're writing slices of the padding buffer.
expectedFinalRefCount = 1;
}
assertEquals(expectedFinalRefCount, buf.refCnt());
}
} finally {
needReleasing.clear();
}
if (serverChannel != null) {
serverChannel.close().sync();
serverChannel = null;
}
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
serverGroup.sync();
serverChildGroup.sync();
clientGroup.sync();
serverAdapter = null;
}
@Test
public void dataFrameShouldMatch() throws Exception {
final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
try {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeData(ctx(), 0x7FFFFFFF, data.slice().retain(), 100, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
any(ByteBuf.class), eq(100), eq(true));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
} finally {
data.release();
public void emptyDataShouldMatch() throws Exception {
final ByteBuf data = EMPTY_BUFFER;
writer.writeData(ctx, STREAM_ID, data.slice(), 0, false, ctx.newPromise());
readFrames();
verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(0), eq(false));
}
@Test
public void dataShouldMatch() throws Exception {
final ByteBuf data = data(10);
writer.writeData(ctx, STREAM_ID, data.slice(), 0, false, ctx.newPromise());
readFrames();
verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(0), eq(false));
}
@Test
public void dataWithPaddingShouldMatch() throws Exception {
final ByteBuf data = data(10);
writer.writeData(ctx, STREAM_ID, data.slice(), 0xFF, true, ctx.newPromise());
readFrames();
verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(0xFF), eq(true));
}
@Test
public void largeDataFrameShouldMatch() throws Exception {
// Create a large message to force chunking.
final ByteBuf originalData = data(1024 * 1024);
final int originalPadding = 100;
final boolean endOfStream = true;
writer.writeData(ctx, STREAM_ID, originalData.slice(), originalPadding,
endOfStream, ctx.newPromise());
readFrames();
// Verify that at least one frame was sent with eos=false and exactly one with eos=true.
verify(listener, atLeastOnce()).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class),
anyInt(), eq(false));
verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class),
anyInt(), eq(true));
// Capture the read data and padding.
ArgumentCaptor<ByteBuf> dataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
ArgumentCaptor<Integer> paddingCaptor = ArgumentCaptor.forClass(Integer.class);
verify(listener, atLeastOnce()).onDataRead(eq(ctx), eq(STREAM_ID), dataCaptor.capture(),
paddingCaptor.capture(), anyBoolean());
// Make sure the data matches the original.
for (ByteBuf chunk : dataCaptor.getAllValues()) {
ByteBuf originalChunk = originalData.readSlice(chunk.readableBytes());
assertEquals(originalChunk, chunk);
}
assertFalse(originalData.isReadable());
// Make sure the padding matches the original.
int totalReadPadding = 0;
for (int framePadding : paddingCaptor.getAllValues()) {
totalReadPadding += framePadding;
}
assertEquals(originalPadding, totalReadPadding);
}
@Test
public void emptyHeadersShouldMatch() throws Exception {
final Http2Headers headers = EmptyHttp2Headers.INSTANCE;
writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true));
}
@Test
public void emptyHeadersWithPaddingShouldMatch() throws Exception {
final Http2Headers headers = EmptyHttp2Headers.INSTANCE;
writer.writeHeaders(ctx, STREAM_ID, headers, 0xFF, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0xFF), eq(true));
}
@Test
public void binaryHeadersWithoutPriorityShouldMatch() throws Exception {
final Http2Headers headers = binaryHeaders();
writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true));
}
@Test
public void headersFrameWithoutPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeHeaders(ctx(), 0x7FFFFFFF, headers, 0, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(headers), eq(0), eq(true));
writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true));
}
@Test
public void headersFrameWithPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeHeaders(ctx(), 0x7FFFFFFF, headers, 4, (short) 255,
true, 0, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(headers), eq(4), eq((short) 255), eq(true), eq(0), eq(true));
writer.writeHeaders(ctx, STREAM_ID, headers, 4, (short) 255, true, 0, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(4), eq((short) 255),
eq(true), eq(0), eq(true));
}
@Test
public void goAwayFrameShouldMatch() throws Exception {
final String text = "test";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[3]).toString(UTF_8));
return null;
}
}).when(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), eq(data));
bootstrapEnv(1);
public void headersWithPaddingWithoutPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
writer.writeHeaders(ctx, STREAM_ID, headers, 0xFF, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0xFF), eq(true));
}
@Test
public void headersWithPaddingWithPriorityShouldMatch() throws Exception {
final Http2Headers headers = headers();
writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 0xFF, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true),
eq(0xFF), eq(true));
}
@Test
public void continuedHeadersShouldMatch() throws Exception {
final Http2Headers headers = largeHeaders();
writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 0, true, ctx.newPromise());
readFrames();
verify(listener)
.onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true), eq(0), eq(true));
}
@Test
public void continuedHeadersWithPaddingShouldMatch() throws Exception {
final Http2Headers headers = largeHeaders();
writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 0xFF, true, ctx.newPromise());
readFrames();
verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true),
eq(0xFF), eq(true));
}
@Test
public void headersThatAreTooBigShouldFail() throws Exception {
final Http2Headers headers = headersOfSize(DEFAULT_MAX_HEADER_SIZE + 1);
writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 0xFF, true, ctx.newPromise());
try {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeGoAway(ctx(), 0x7FFFFFFF, 0xFFFFFFFFL, data.duplicate().retain(), newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL), any(ByteBuf.class));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
} finally {
data.release();
readFrames();
fail();
} catch (Http2Exception e) {
verify(listener, never()).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(),
anyBoolean());
}
}
@Test
public void pingFrameShouldMatch() throws Exception {
String text = "01234567";
final ByteBuf data = Unpooled.copiedBuffer(text, UTF_8);
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[1]).toString(UTF_8));
return null;
}
}).when(serverListener).onPingAckRead(any(ChannelHandlerContext.class), eq(data));
try {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePing(ctx(), true, data.duplicate().retain(), newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPingAckRead(any(ChannelHandlerContext.class), any(ByteBuf.class));
assertEquals(1, receivedBuffers.size());
for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally {
data.release();
}
}
@Test
public void priorityFrameShouldMatch() throws Exception {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePriority(ctx(), 0x7FFFFFFF, 1, (short) 1, true, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPriorityRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(1), eq((short) 1), eq(true));
public void emptyPushPromiseShouldMatch() throws Exception {
final Http2Headers headers = EmptyHttp2Headers.INSTANCE;
writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0, ctx.newPromise());
readFrames();
verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0));
}
@Test
public void pushPromiseFrameShouldMatch() throws Exception {
final Http2Headers headers = headers();
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writePushPromise(ctx(), 0x7FFFFFFF, 1, headers, 5, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onPushPromiseRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(1), eq(headers), eq(5));
writer.writePushPromise(ctx, STREAM_ID, 1, headers, 5, ctx.newPromise());
readFrames();
verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(1), eq(headers), eq(5));
}
@Test
public void pushPromiseWithPaddingShouldMatch() throws Exception {
final Http2Headers headers = headers();
writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0xFF, ctx.newPromise());
readFrames();
verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0xFF));
}
@Test
public void continuedPushPromiseShouldMatch() throws Exception {
final Http2Headers headers = largeHeaders();
writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0, ctx.newPromise());
readFrames();
verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0));
}
@Test
public void continuedPushPromiseWithPaddingShouldMatch() throws Exception {
final Http2Headers headers = largeHeaders();
writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0xFF, ctx.newPromise());
readFrames();
verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0xFF));
}
@Test
public void goAwayFrameShouldMatch() throws Exception {
final String text = "test";
final ByteBuf data = buf(text.getBytes());
writer.writeGoAway(ctx, STREAM_ID, ERROR_CODE, data.slice(), ctx.newPromise());
readFrames();
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(listener).onGoAwayRead(eq(ctx), eq(STREAM_ID), eq(ERROR_CODE), captor.capture());
assertEquals(data, captor.getValue());
}
@Test
public void pingFrameShouldMatch() throws Exception {
final ByteBuf data = buf("01234567".getBytes(UTF_8));
writer.writePing(ctx, false, data.slice(), ctx.newPromise());
readFrames();
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(listener).onPingRead(eq(ctx), captor.capture());
assertEquals(data, captor.getValue());
}
@Test
public void pingAckFrameShouldMatch() throws Exception {
final ByteBuf data = buf("01234567".getBytes(UTF_8));
writer.writePing(ctx, true, data.slice(), ctx.newPromise());
readFrames();
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(listener).onPingAckRead(eq(ctx), captor.capture());
assertEquals(data, captor.getValue());
}
@Test
public void priorityFrameShouldMatch() throws Exception {
writer.writePriority(ctx, STREAM_ID, 1, (short) 1, true, ctx.newPromise());
readFrames();
verify(listener).onPriorityRead(eq(ctx), eq(STREAM_ID), eq(1), eq((short) 1), eq(true));
}
@Test
public void rstStreamFrameShouldMatch() throws Exception {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeRstStream(ctx(), 0x7FFFFFFF, 0xFFFFFFFFL, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0xFFFFFFFFL));
writer.writeRstStream(ctx, STREAM_ID, ERROR_CODE, ctx.newPromise());
readFrames();
verify(listener).onRstStreamRead(eq(ctx), eq(STREAM_ID), eq(ERROR_CODE));
}
@Test
public void settingsFrameShouldMatch() throws Exception {
bootstrapEnv(1);
public void emptySettingsFrameShouldMatch() throws Exception {
final Http2Settings settings = new Http2Settings();
settings.initialWindowSize(10);
settings.maxConcurrentStreams(1000);
writer.writeSettings(ctx, settings, ctx.newPromise());
readFrames();
verify(listener).onSettingsRead(eq(ctx), eq(settings));
}
@Test
public void settingsShouldStripShouldMatch() throws Exception {
final Http2Settings settings = new Http2Settings();
settings.pushEnabled(true);
settings.headerTableSize(4096);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeSettings(ctx(), settings, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onSettingsRead(any(ChannelHandlerContext.class), eq(settings));
settings.initialWindowSize(123);
settings.maxConcurrentStreams(456);
writer.writeSettings(ctx, settings, ctx.newPromise());
readFrames();
verify(listener).onSettingsRead(eq(ctx), eq(settings));
}
@Test
public void settingsAckShouldMatch() throws Exception {
writer.writeSettingsAck(ctx, ctx.newPromise());
readFrames();
verify(listener).onSettingsAckRead(eq(ctx));
}
@Test
public void windowUpdateFrameShouldMatch() throws Exception {
bootstrapEnv(1);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
frameWriter.writeWindowUpdate(ctx(), 0x7FFFFFFF, 0x7FFFFFFF, newPromise());
ctx().flush();
}
});
awaitRequests();
verify(serverListener).onWindowUpdateRead(any(ChannelHandlerContext.class), eq(0x7FFFFFFF),
eq(0x7FFFFFFF));
writer.writeWindowUpdate(ctx, STREAM_ID, WINDOW_UPDATE, ctx.newPromise());
readFrames();
verify(listener).onWindowUpdateRead(eq(ctx), eq(STREAM_ID), eq(WINDOW_UPDATE));
}
@Test
public void stressTest() throws Exception {
final Http2Headers headers = headers();
final String text = "hello world";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
final int numStreams = 10000;
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>(numStreams));
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data), eq(0), eq(true));
try {
final int expectedFrames = numStreams * 2;
bootstrapEnv(expectedFrames);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
for (int i = 1; i < numStreams + 1; ++i) {
frameWriter.writeHeaders(ctx(), i, headers, 0, (short) 16, false, 0, false, newPromise());
frameWriter.writeData(ctx(), i, data.duplicate().retain(), 0, true, newPromise());
ctx().flush();
}
}
});
awaitRequests(60);
verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class), anyInt(),
any(ByteBuf.class), eq(0), eq(true));
assertEquals(numStreams, receivedBuffers.size());
for (String receivedData : receivedBuffers) {
assertEquals(text, receivedData);
}
} finally {
data.release();
private void readFrames() throws Http2Exception {
// Now read all of the written frames.
ByteBuf write = captureWrites();
reader.readFrame(ctx, write, listener);
}
private ByteBuf data(int size) {
byte[] data = new byte[size];
for (int ix = 0; ix < data.length;) {
int length = min(MESSAGE.length, data.length - ix);
System.arraycopy(MESSAGE, 0, data, ix, length);
ix += length;
}
return buf(data);
}
private void awaitRequests(long seconds) throws InterruptedException {
assertTrue(requestLatch.await(seconds, SECONDS));
private ByteBuf buf(byte[] bytes) {
return Unpooled.wrappedBuffer(bytes);
}
private void awaitRequests() throws InterruptedException {
awaitRequests(5);
private <T extends ByteBuf> T releaseLater(T buf) {
needReleasing.add(buf);
return buf;
}
private void bootstrapEnv(int requestCountDown) throws Exception {
requestLatch = new CountDownLatch(requestCountDown);
frameWriter = new DefaultHttp2FrameWriter();
sb = new ServerBootstrap();
cb = new Bootstrap();
sb.group(new DefaultEventLoopGroup());
sb.channel(LocalServerChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverAdapter = new Http2TestUtil.FrameAdapter(serverListener, requestLatch);
p.addLast(serverAdapter);
}
});
cb.group(new DefaultEventLoopGroup());
cb.channel(LocalChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(new Http2TestUtil.FrameAdapter(null, null));
}
});
serverChannel = sb.bind(new LocalAddress("Http2FrameRoundtripTest")).sync().channel();
ChannelFuture ccf = cb.connect(serverChannel.localAddress());
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
private ChannelHandlerContext ctx() {
return clientChannel.pipeline().firstContext();
}
private ChannelPromise newPromise() {
return ctx().newPromise();
private ByteBuf captureWrites() {
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(ctx, atLeastOnce()).write(captor.capture(), isA(ChannelPromise.class));
CompositeByteBuf composite = releaseLater(Unpooled.compositeBuffer());
for (ByteBuf buf : captor.getAllValues()) {
buf = releaseLater(buf.retain());
composite.addComponent(buf);
composite.writerIndex(composite.writerIndex() + buf.readableBytes());
}
return composite;
}
private static Http2Headers headers() {
return new DefaultHttp2Headers(false).method(new AsciiString("GET")).scheme(new AsciiString("https"))
.authority(new AsciiString("example.org")).path(new AsciiString("/some/path/resource2"))
return new DefaultHttp2Headers(false).method(AsciiString.of("GET")).scheme(AsciiString.of("https"))
.authority(AsciiString.of("example.org")).path(AsciiString.of("/some/path/resource2"))
.add(randomString(), randomString());
}
private static Http2Headers largeHeaders() {
DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
for (int i = 0; i < 100; ++i) {
String key = "this-is-a-test-header-key-" + i;
String value = "this-is-a-test-header-value-" + i;
headers.add(AsciiString.of(key), AsciiString.of(value));
}
return headers;
}
private Http2Headers headersOfSize(final int minSize) {
final AsciiString singleByte = new AsciiString(new byte[]{0}, false);
DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
for (int size = 0; size < minSize; size += 2) {
headers.add(singleByte, singleByte);
}
return headers;
}
private static Http2Headers binaryHeaders() {
DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
for (int ix = 0; ix < 10; ++ix) {
headers.add(randomString(), randomString());
}
return headers;
}
}

View File

@ -22,16 +22,21 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURREN
import static io.netty.handler.codec.http2.Http2Error.CANCEL;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
@ -41,12 +46,9 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException;
import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.ImmediateEventExecutor;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -57,6 +59,9 @@ import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;
import java.util.ArrayList;
import java.util.List;
/**
* Tests for {@link StreamBufferingEncoder}.
*/
@ -81,9 +86,6 @@ public class StreamBufferingEncoderTest {
@Mock
private EventExecutor executor;
@Mock
private ChannelPromise promise;
/**
* Init fields and do mocking.
*/
@ -98,7 +100,7 @@ public class StreamBufferingEncoderTest {
when(frameSizePolicy.maxFrameSize()).thenReturn(DEFAULT_MAX_FRAME_SIZE);
when(writer.writeData(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean(),
any(ChannelPromise.class))).thenAnswer(successAnswer());
when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), eq(promise))).thenAnswer(
when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), any(ChannelPromise.class))).thenAnswer(
successAnswer());
when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
any(ChannelPromise.class)))
@ -121,9 +123,13 @@ public class StreamBufferingEncoderTest {
when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
when(executor.inEventLoop()).thenReturn(true);
when(ctx.newPromise()).thenReturn(promise);
doAnswer(new Answer<ChannelPromise>() {
@Override
public ChannelPromise answer(InvocationOnMock invocation) throws Throwable {
return newPromise();
}
}).when(ctx).newPromise();
when(ctx.executor()).thenReturn(executor);
when(promise.channel()).thenReturn(channel);
when(channel.isActive()).thenReturn(false);
when(channel.config()).thenReturn(config);
when(channel.isWritable()).thenReturn(true);
@ -140,34 +146,34 @@ public class StreamBufferingEncoderTest {
@Test
public void multipleWritesToActiveStream() {
encoder.writeSettingsAck(ctx, promise);
encoderWriteHeaders(3, promise);
encoder.writeSettingsAck(ctx, newPromise());
encoderWriteHeaders(3, newPromise());
assertEquals(0, encoder.numBufferedStreams());
ByteBuf data = data();
final int expectedBytes = data.readableBytes() * 3;
encoder.writeData(ctx, 3, data, 0, false, promise);
encoder.writeData(ctx, 3, data(), 0, false, promise);
encoder.writeData(ctx, 3, data(), 0, false, promise);
encoderWriteHeaders(3, promise);
encoder.writeData(ctx, 3, data, 0, false, newPromise());
encoder.writeData(ctx, 3, data(), 0, false, newPromise());
encoder.writeData(ctx, 3, data(), 0, false, newPromise());
encoderWriteHeaders(3, newPromise());
writeVerifyWriteHeaders(times(2), 3, promise);
writeVerifyWriteHeaders(times(2), 3);
// Contiguous data writes are coalesced
ArgumentCaptor<ByteBuf> bufCaptor = ArgumentCaptor.forClass(ByteBuf.class);
verify(writer, times(1))
.writeData(eq(ctx), eq(3), bufCaptor.capture(), eq(0), eq(false), eq(promise));
.writeData(eq(ctx), eq(3), bufCaptor.capture(), eq(0), eq(false), any(ChannelPromise.class));
assertEquals(expectedBytes, bufCaptor.getValue().readableBytes());
}
@Test
public void ensureCanCreateNextStreamWhenStreamCloses() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(1);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(3, newPromise());
assertEquals(0, encoder.numBufferedStreams());
// This one gets buffered.
encoderWriteHeaders(5, promise);
encoderWriteHeaders(5, newPromise());
assertEquals(1, connection.numActiveStreams());
assertEquals(1, encoder.numBufferedStreams());
@ -179,51 +185,53 @@ public class StreamBufferingEncoderTest {
// Ensure that no streams are currently active and that only the HEADERS from the first
// stream were written.
writeVerifyWriteHeaders(times(1), 3, promise);
writeVerifyWriteHeaders(never(), 5, promise);
writeVerifyWriteHeaders(times(1), 3);
writeVerifyWriteHeaders(never(), 5);
assertEquals(0, connection.numActiveStreams());
assertEquals(1, encoder.numBufferedStreams());
}
@Test
public void alternatingWritesToActiveAndBufferedStreams() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(1);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(3, newPromise());
assertEquals(0, encoder.numBufferedStreams());
encoderWriteHeaders(5, promise);
encoderWriteHeaders(5, newPromise());
assertEquals(1, connection.numActiveStreams());
assertEquals(1, encoder.numBufferedStreams());
encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, false, promise);
writeVerifyWriteHeaders(times(1), 3, promise);
encoder.writeData(ctx, 5, EMPTY_BUFFER, 0, false, promise);
encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, false, newPromise());
writeVerifyWriteHeaders(times(1), 3);
encoder.writeData(ctx, 5, EMPTY_BUFFER, 0, false, newPromise());
verify(writer, never())
.writeData(eq(ctx), eq(5), any(ByteBuf.class), eq(0), eq(false), eq(promise));
.writeData(eq(ctx), eq(5), any(ByteBuf.class), eq(0), eq(false), eq(newPromise()));
}
@Test
public void bufferingNewStreamFailsAfterGoAwayReceived() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(0);
connection.goAwayReceived(1, 8, EMPTY_BUFFER);
promise = mock(ChannelPromise.class);
ChannelPromise promise = newPromise();
encoderWriteHeaders(3, promise);
assertEquals(0, encoder.numBufferedStreams());
verify(promise).setFailure(any(Throwable.class));
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
}
@Test
public void receivingGoAwayFailsBufferedStreams() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(5);
int streamId = 3;
List<ChannelFuture> futures = new ArrayList<ChannelFuture>();
for (int i = 0; i < 9; i++) {
encoderWriteHeaders(streamId, promise);
futures.add(encoderWriteHeaders(streamId, newPromise()));
streamId += 2;
}
assertEquals(4, encoder.numBufferedStreams());
@ -231,40 +239,47 @@ public class StreamBufferingEncoderTest {
connection.goAwayReceived(11, 8, EMPTY_BUFFER);
assertEquals(5, connection.numActiveStreams());
// The 4 buffered streams must have been failed.
verify(promise, times(4)).setFailure(any(Throwable.class));
int failCount = 0;
for (ChannelFuture f : futures) {
if (f.cause() != null) {
failCount++;
}
}
assertEquals(4, failCount);
assertEquals(0, encoder.numBufferedStreams());
}
@Test
public void sendingGoAwayShouldNotFailStreams() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(1);
encoderWriteHeaders(3, promise);
ChannelFuture f1 = encoderWriteHeaders(3, newPromise());
assertEquals(0, encoder.numBufferedStreams());
encoderWriteHeaders(5, promise);
ChannelFuture f2 = encoderWriteHeaders(5, newPromise());
assertEquals(1, encoder.numBufferedStreams());
encoderWriteHeaders(7, promise);
ChannelFuture f3 = encoderWriteHeaders(7, newPromise());
assertEquals(2, encoder.numBufferedStreams());
ByteBuf empty = Unpooled.buffer(0);
encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, promise);
encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, newPromise());
assertEquals(1, connection.numActiveStreams());
assertEquals(2, encoder.numBufferedStreams());
verify(promise, never()).setFailure(any(Http2GoAwayException.class));
assertFalse(f1.isDone());
assertFalse(f2.isDone());
assertFalse(f3.isDone());
}
@Test
public void endStreamDoesNotFailBufferedStream() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(0);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(3, newPromise());
assertEquals(1, encoder.numBufferedStreams());
encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, true, promise);
encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, true, newPromise());
assertEquals(0, connection.numActiveStreams());
assertEquals(1, encoder.numBufferedStreams());
@ -272,7 +287,7 @@ public class StreamBufferingEncoderTest {
// Simulate that we received a SETTINGS frame which
// increased MAX_CONCURRENT_STREAMS to 1.
setMaxConcurrentStreams(1);
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
assertEquals(1, connection.numActiveStreams());
assertEquals(0, encoder.numBufferedStreams());
@ -281,75 +296,80 @@ public class StreamBufferingEncoderTest {
@Test
public void rstStreamClosesBufferedStream() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(0);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(3, newPromise());
assertEquals(1, encoder.numBufferedStreams());
verify(promise, never()).setSuccess();
ChannelPromise rstStreamPromise = mock(ChannelPromise.class);
ChannelPromise rstStreamPromise = newPromise();
encoder.writeRstStream(ctx, 3, CANCEL.code(), rstStreamPromise);
verify(promise).setSuccess();
verify(rstStreamPromise).setSuccess();
assertTrue(rstStreamPromise.isSuccess());
assertEquals(0, encoder.numBufferedStreams());
}
@Test
public void bufferUntilActiveStreamsAreReset() {
encoder.writeSettingsAck(ctx, promise);
public void bufferUntilActiveStreamsAreReset() throws Exception {
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(1);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(3, newPromise());
assertEquals(0, encoder.numBufferedStreams());
encoderWriteHeaders(5, promise);
encoderWriteHeaders(5, newPromise());
assertEquals(1, encoder.numBufferedStreams());
encoderWriteHeaders(7, promise);
encoderWriteHeaders(7, newPromise());
assertEquals(2, encoder.numBufferedStreams());
writeVerifyWriteHeaders(times(1), 3, promise);
writeVerifyWriteHeaders(never(), 5, promise);
writeVerifyWriteHeaders(never(), 7, promise);
writeVerifyWriteHeaders(times(1), 3);
writeVerifyWriteHeaders(never(), 5);
writeVerifyWriteHeaders(never(), 7);
encoder.writeRstStream(ctx, 3, CANCEL.code(), promise);
encoder.writeRstStream(ctx, 3, CANCEL.code(), newPromise());
connection.remote().flowController().writePendingBytes();
writeVerifyWriteHeaders(times(1), 5);
writeVerifyWriteHeaders(never(), 7);
assertEquals(1, connection.numActiveStreams());
assertEquals(1, encoder.numBufferedStreams());
encoder.writeRstStream(ctx, 5, CANCEL.code(), promise);
encoder.writeRstStream(ctx, 5, CANCEL.code(), newPromise());
connection.remote().flowController().writePendingBytes();
writeVerifyWriteHeaders(times(1), 7);
assertEquals(1, connection.numActiveStreams());
assertEquals(0, encoder.numBufferedStreams());
encoder.writeRstStream(ctx, 7, CANCEL.code(), promise);
encoder.writeRstStream(ctx, 7, CANCEL.code(), newPromise());
assertEquals(0, connection.numActiveStreams());
assertEquals(0, encoder.numBufferedStreams());
}
@Test
public void bufferUntilMaxStreamsIncreased() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(2);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(5, promise);
encoderWriteHeaders(7, promise);
encoderWriteHeaders(9, promise);
encoderWriteHeaders(3, newPromise());
encoderWriteHeaders(5, newPromise());
encoderWriteHeaders(7, newPromise());
encoderWriteHeaders(9, newPromise());
assertEquals(2, encoder.numBufferedStreams());
writeVerifyWriteHeaders(times(1), 3, promise);
writeVerifyWriteHeaders(times(1), 5, promise);
writeVerifyWriteHeaders(never(), 7, promise);
writeVerifyWriteHeaders(never(), 9, promise);
writeVerifyWriteHeaders(times(1), 3);
writeVerifyWriteHeaders(times(1), 5);
writeVerifyWriteHeaders(never(), 7);
writeVerifyWriteHeaders(never(), 9);
// Simulate that we received a SETTINGS frame which
// increased MAX_CONCURRENT_STREAMS to 5.
setMaxConcurrentStreams(5);
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
assertEquals(0, encoder.numBufferedStreams());
writeVerifyWriteHeaders(times(1), 7, promise);
writeVerifyWriteHeaders(times(1), 9, promise);
writeVerifyWriteHeaders(times(1), 7);
writeVerifyWriteHeaders(times(1), 9);
encoderWriteHeaders(11, promise);
encoderWriteHeaders(11, newPromise());
writeVerifyWriteHeaders(times(1), 11, promise);
writeVerifyWriteHeaders(times(1), 11);
assertEquals(5, connection.local().numActiveStreams());
}
@ -359,11 +379,11 @@ public class StreamBufferingEncoderTest {
int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS;
int numStreams = initialLimit * 2;
for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) {
encoderWriteHeaders(nextStreamId, promise);
encoderWriteHeaders(nextStreamId, newPromise());
if (ix < initialLimit) {
writeVerifyWriteHeaders(times(1), nextStreamId, promise);
writeVerifyWriteHeaders(times(1), nextStreamId);
} else {
writeVerifyWriteHeaders(never(), nextStreamId, promise);
writeVerifyWriteHeaders(never(), nextStreamId);
}
}
assertEquals(numStreams / 2, encoder.numBufferedStreams());
@ -380,11 +400,11 @@ public class StreamBufferingEncoderTest {
int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS;
int numStreams = initialLimit * 2;
for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) {
encoderWriteHeaders(nextStreamId, promise);
encoderWriteHeaders(nextStreamId, newPromise());
if (ix < initialLimit) {
writeVerifyWriteHeaders(times(1), nextStreamId, promise);
writeVerifyWriteHeaders(times(1), nextStreamId);
} else {
writeVerifyWriteHeaders(never(), nextStreamId, promise);
writeVerifyWriteHeaders(never(), nextStreamId);
}
}
assertEquals(numStreams / 2, encoder.numBufferedStreams());
@ -400,56 +420,59 @@ public class StreamBufferingEncoderTest {
public void exhaustedStreamsDoNotBuffer() throws Http2Exception {
// Write the highest possible stream ID for the client.
// This will cause the next stream ID to be negative.
encoderWriteHeaders(Integer.MAX_VALUE, promise);
encoderWriteHeaders(Integer.MAX_VALUE, newPromise());
// Disallow any further streams.
setMaxConcurrentStreams(0);
// Simulate numeric overflow for the next stream ID.
encoderWriteHeaders(-1, promise);
ChannelFuture f = encoderWriteHeaders(-1, newPromise());
// Verify that the write fails.
verify(promise).setFailure(any(Http2Exception.class));
assertNotNull(f.cause());
}
@Test
public void closedBufferedStreamReleasesByteBuf() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(0);
ByteBuf data = mock(ByteBuf.class);
encoderWriteHeaders(3, promise);
ChannelFuture f1 = encoderWriteHeaders(3, newPromise());
assertEquals(1, encoder.numBufferedStreams());
encoder.writeData(ctx, 3, data, 0, false, promise);
ChannelFuture f2 = encoder.writeData(ctx, 3, data, 0, false, newPromise());
ChannelPromise rstPromise = mock(ChannelPromise.class);
encoder.writeRstStream(ctx, 3, CANCEL.code(), rstPromise);
assertEquals(0, encoder.numBufferedStreams());
verify(rstPromise).setSuccess();
verify(promise, times(2)).setSuccess();
assertTrue(f1.isSuccess());
assertTrue(f2.isSuccess());
verify(data).release();
}
@Test
public void closeShouldCancelAllBufferedStreams() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
connection.local().maxActiveStreams(0);
encoderWriteHeaders(3, promise);
encoderWriteHeaders(5, promise);
encoderWriteHeaders(7, promise);
ChannelFuture f1 = encoderWriteHeaders(3, newPromise());
ChannelFuture f2 = encoderWriteHeaders(5, newPromise());
ChannelFuture f3 = encoderWriteHeaders(7, newPromise());
encoder.close();
verify(promise, times(3)).setFailure(any(Http2ChannelClosedException.class));
assertNotNull(f1.cause());
assertNotNull(f2.cause());
assertNotNull(f3.cause());
}
@Test
public void headersAfterCloseShouldImmediatelyFail() {
encoder.writeSettingsAck(ctx, promise);
encoder.writeSettingsAck(ctx, newPromise());
encoder.close();
encoderWriteHeaders(3, promise);
verify(promise).setFailure(any(Http2ChannelClosedException.class));
ChannelFuture f = encoderWriteHeaders(3, newPromise());
assertNotNull(f.cause());
}
private void setMaxConcurrentStreams(int newValue) {
@ -462,21 +485,21 @@ public class StreamBufferingEncoderTest {
}
}
private void encoderWriteHeaders(int streamId, ChannelPromise promise) {
private ChannelFuture encoderWriteHeaders(int streamId, ChannelPromise promise) {
encoder.writeHeaders(ctx, streamId, new DefaultHttp2Headers(), 0, DEFAULT_PRIORITY_WEIGHT,
false, 0, false, promise);
try {
encoder.flowController().writePendingBytes();
return promise;
} catch (Http2Exception e) {
throw new RuntimeException(e);
}
}
private void writeVerifyWriteHeaders(VerificationMode mode, int streamId,
ChannelPromise promise) {
private void writeVerifyWriteHeaders(VerificationMode mode, int streamId) {
verify(writer, mode).writeHeaders(eq(ctx), eq(streamId), any(Http2Headers.class), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0),
eq(false), eq(promise));
eq(false), any(ChannelPromise.class));
}
private Answer<ChannelFuture> successAnswer() {
@ -487,14 +510,17 @@ public class StreamBufferingEncoderTest {
ReferenceCountUtil.safeRelease(a);
}
ChannelPromise future =
new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
ChannelPromise future = newPromise();
future.setSuccess();
return future;
}
};
}
private ChannelPromise newPromise() {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
}
private static ByteBuf data() {
ByteBuf buf = Unpooled.buffer(10);
for (int i = 0; i < buf.writableBytes(); i++) {