From d2615ab532b3d149df175d142bbb423947b1a05b Mon Sep 17 00:00:00 2001 From: nmittler Date: Tue, 2 Jun 2015 20:15:59 -0700 Subject: [PATCH] Porting BufferingHttp2ConnectionEncoder from gRPC Motivation: gRPC's BufferingHttp2ConnectionEncoder is a generic utility that simplifies client-side applications that want to allow stream creation without worrying about violating the SETTINGS_MAX_CONCURRENT_STREAMS limit. Since it's not gRPC-specific it makes sense to move it into Netty proper. Modifications: Adding the BufferingHttp2ConnectionEncoder and it's unit test. Result: Netty now supports buffering stream creation. --- .../handler/codec/http2/Http2CodecUtil.java | 6 + .../codec/http2/StreamBufferingEncoder.java | 337 ++++++++++++++ .../http2/StreamBufferingEncoderTest.java | 435 ++++++++++++++++++ 3 files changed, 778 insertions(+) create mode 100644 codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java create mode 100644 codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java index 27c69846ce..d9e46c2ec7 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java @@ -91,6 +91,12 @@ public final class Http2CodecUtil { public static final int DEFAULT_MAX_HEADER_SIZE = 8192; public static final int DEFAULT_MAX_FRAME_SIZE = MAX_FRAME_SIZE_LOWER_BOUND; + /** + * The assumed minimum value for {@code SETTINGS_MAX_CONCURRENT_STREAMS} as + * recommended by the HTTP/2 spec. + */ + public static final int SMALLEST_MAX_CONCURRENT_STREAMS = 100; + /** * Indicates whether or not the given value for max frame size falls within the valid range. */ diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java new file mode 100644 index 0000000000..7da5cf7229 --- /dev/null +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java @@ -0,0 +1,337 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; + +import java.util.ArrayDeque; +import java.util.Iterator; +import java.util.Map; +import java.util.Queue; +import java.util.TreeMap; + +/** + * Implementation of a {@link Http2ConnectionEncoder} that dispatches all method call to another + * {@link Http2ConnectionEncoder}, until {@code SETTINGS_MAX_CONCURRENT_STREAMS} is reached. + *

+ *

When this limit is hit, instead of rejecting any new streams this implementation buffers newly + * created streams and their corresponding frames. Once an active stream gets closed or the maximum + * number of concurrent streams is increased, this encoder will automatically try to empty its + * buffer and create as many new streams as possible. + *

+ *

+ * If a {@code GOAWAY} frame is received from the remote endpoint, all buffered writes for streams + * with an ID less than the specified {@code lastStreamId} will immediately fail with a + * {@link StreamBufferingEncoder.GoAwayException}. + *

+ *

This implementation makes the buffering mostly transparent and is expected to be used as a + * drop-in decorator of {@link DefaultHttp2ConnectionEncoder}. + *

+ */ +public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder { + + /** + * Thrown by {@link StreamBufferingEncoder} if buffered streams are terminated due to + * receipt of a {@code GOAWAY}. + */ + public static final class GoAwayException extends Http2Exception { + private static final long serialVersionUID = 1326785622777291198L; + private final int lastStreamId; + private final long errorCode; + private final ByteBuf debugData; + + public GoAwayException(int lastStreamId, long errorCode, ByteBuf debugData) { + super(Http2Error.STREAM_CLOSED); + this.lastStreamId = lastStreamId; + this.errorCode = errorCode; + this.debugData = debugData; + } + + public int lastStreamId() { + return lastStreamId; + } + + public long errorCode() { + return errorCode; + } + + public ByteBuf debugData() { + return debugData; + } + } + + /** + * Buffer for any streams and corresponding frames that could not be created due to the maximum + * concurrent stream limit being hit. + */ + private final TreeMap pendingStreams = new TreeMap(); + private int maxConcurrentStreams; + + public StreamBufferingEncoder(Http2ConnectionEncoder delegate) { + this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS); + } + + public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) { + super(delegate); + this.maxConcurrentStreams = initialMaxConcurrentStreams; + connection().addListener(new Http2ConnectionAdapter() { + + @Override + public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { + cancelGoAwayStreams(lastStreamId, errorCode, debugData); + } + + @Override + public void onStreamClosed(Http2Stream stream) { + tryCreatePendingStreams(); + } + }); + } + + /** + * Indicates the number of streams that are currently buffered, awaiting creation. + */ + public int numBufferedStreams() { + return pendingStreams.size(); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream, ChannelPromise promise) { + return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT, + false, padding, endStream, promise); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + if (isExistingStream(streamId) || connection().goAwayReceived()) { + return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream, promise); + } + if (canCreateStream()) { + return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream, promise); + } + PendingStream pendingStream = pendingStreams.get(streamId); + if (pendingStream == null) { + pendingStream = new PendingStream(ctx, streamId); + pendingStreams.put(streamId, pendingStream); + } + pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive, + padding, endOfStream, promise)); + return promise; + } + + @Override + public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, + ChannelPromise promise) { + if (isExistingStream(streamId)) { + return super.writeRstStream(ctx, streamId, errorCode, promise); + } + // Since the delegate doesn't know about any buffered streams we have to handle cancellation + // of the promises and releasing of the ByteBufs here. + PendingStream stream = pendingStreams.remove(streamId); + if (stream != null) { + // Sending a RST_STREAM to a buffered stream will succeed the promise of all frames + // associated with the stream, as sending a RST_STREAM means that someone "doesn't care" + // about the stream anymore and thus there is not point in failing the promises and invoking + // error handling routines. + stream.close(null); + promise.setSuccess(); + } else { + promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); + } + return promise; + } + + @Override + public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endOfStream, ChannelPromise promise) { + if (isExistingStream(streamId)) { + return super.writeData(ctx, streamId, data, padding, endOfStream, promise); + } + PendingStream pendingStream = pendingStreams.get(streamId); + if (pendingStream != null) { + pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise)); + } else { + ReferenceCountUtil.safeRelease(data); + promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); + } + return promise; + } + + @Override + public void remoteSettings(Http2Settings settings) throws Http2Exception { + // Need to let the delegate decoder handle the settings first, so that it sees the + // new setting before we attempt to create any new streams. + super.remoteSettings(settings); + + // Get the updated value for SETTINGS_MAX_CONCURRENT_STREAMS. + maxConcurrentStreams = connection().local().maxActiveStreams(); + + // Try to create new streams up to the new threshold. + tryCreatePendingStreams(); + } + + @Override + public void close() { + super.close(); + cancelPendingStreams(); + } + + private void tryCreatePendingStreams() { + while (!pendingStreams.isEmpty() && canCreateStream()) { + Map.Entry entry = pendingStreams.pollFirstEntry(); + PendingStream pendingStream = entry.getValue(); + pendingStream.sendFrames(); + } + } + + private void cancelPendingStreams() { + Exception e = new Exception("Connection closed."); + while (!pendingStreams.isEmpty()) { + PendingStream stream = pendingStreams.pollFirstEntry().getValue(); + stream.close(e); + } + } + + private void cancelGoAwayStreams(int lastStreamId, long errorCode, ByteBuf debugData) { + Iterator iter = pendingStreams.values().iterator(); + Exception e = new GoAwayException(lastStreamId, errorCode, debugData); + while (iter.hasNext()) { + PendingStream stream = iter.next(); + if (stream.streamId > lastStreamId) { + iter.remove(); + stream.close(e); + } + } + } + + /** + * Determines whether or not we're allowed to create a new stream right now. + */ + private boolean canCreateStream() { + return connection().local().numActiveStreams() < maxConcurrentStreams; + } + + private boolean isExistingStream(int streamId) { + return streamId <= connection().local().lastStreamCreated(); + } + + private static final class PendingStream { + final ChannelHandlerContext ctx; + final int streamId; + final Queue frames = new ArrayDeque(2); + + PendingStream(ChannelHandlerContext ctx, int streamId) { + this.ctx = ctx; + this.streamId = streamId; + } + + void sendFrames() { + for (Frame frame : frames) { + frame.send(ctx, streamId); + } + } + + void close(Throwable t) { + for (Frame frame : frames) { + frame.release(t); + } + } + } + + private abstract static class Frame { + final ChannelPromise promise; + + Frame(ChannelPromise promise) { + this.promise = promise; + } + + /** + * Release any resources (features, buffers, ...) associated with the frame. + */ + void release(Throwable t) { + if (t == null) { + promise.setSuccess(); + } else { + promise.setFailure(t); + } + } + + abstract void send(ChannelHandlerContext ctx, int streamId); + } + + private final class HeadersFrame extends Frame { + final Http2Headers headers; + final int streamDependency; + final short weight; + final boolean exclusive; + final int padding; + final boolean endOfStream; + + HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + super(promise); + this.headers = headers; + this.streamDependency = streamDependency; + this.weight = weight; + this.exclusive = exclusive; + this.padding = padding; + this.endOfStream = endOfStream; + } + + @Override + void send(ChannelHandlerContext ctx, int streamId) { + writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, + endOfStream, promise); + } + } + + private final class DataFrame extends Frame { + final ByteBuf data; + final int padding; + final boolean endOfStream; + + DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) { + super(promise); + this.data = data; + this.padding = padding; + this.endOfStream = endOfStream; + } + + @Override + void release(Throwable t) { + super.release(t); + ReferenceCountUtil.safeRelease(data); + } + + @Override + void send(ChannelHandlerContext ctx, int streamId) { + writeData(ctx, streamId, data, padding, endOfStream, promise); + } + } +} diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java new file mode 100644 index 0000000000..3c75c874e9 --- /dev/null +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java @@ -0,0 +1,435 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.CANCEL; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http2.StreamBufferingEncoder.GoAwayException; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; + +/** + * Tests for {@link StreamBufferingEncoder}. + */ +public class StreamBufferingEncoderTest { + + private StreamBufferingEncoder encoder; + + private Http2Connection connection; + + @Mock + private Http2FrameWriter writer; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private Channel channel; + + @Mock + private ChannelPromise promise; + + /** + * Init fields and do mocking. + */ + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + Http2FrameWriter.Configuration configuration = mock(Http2FrameWriter.Configuration.class); + Http2FrameSizePolicy frameSizePolicy = mock(Http2FrameSizePolicy.class); + when(writer.configuration()).thenReturn(configuration); + when(configuration.frameSizePolicy()).thenReturn(frameSizePolicy); + when(frameSizePolicy.maxFrameSize()).thenReturn(DEFAULT_MAX_FRAME_SIZE); + when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), eq(promise))).thenAnswer( + successAnswer()); + when(writer.writeGoAway(eq(ctx), anyInt(), anyLong(), any(ByteBuf.class), + any(ChannelPromise.class))) + .thenAnswer(successAnswer()); + + connection = new DefaultHttp2Connection(false); + + DefaultHttp2ConnectionEncoder defaultEncoder = + new DefaultHttp2ConnectionEncoder(connection, writer); + encoder = new StreamBufferingEncoder(defaultEncoder); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, + mock(Http2FrameReader.class), mock(Http2FrameListener.class)); + + Http2ConnectionHandler handler = new Http2ConnectionHandler(decoder, encoder); + // Set LifeCycleManager on encoder and decoder + when(ctx.channel()).thenReturn(channel); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(channel.isActive()).thenReturn(false); + handler.handlerAdded(ctx); + } + + @Test + public void multipleWritesToActiveStream() { + encoder.writeSettingsAck(ctx, promise); + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + encoder.writeData(ctx, 3, data(), 0, false, promise); + encoder.writeData(ctx, 3, data(), 0, false, promise); + encoder.writeData(ctx, 3, data(), 0, false, promise); + encoderWriteHeaders(3, promise); + + writeVerifyWriteHeaders(times(2), 3, promise); + verify(writer, times(3)) + .writeData(eq(ctx), eq(3), any(ByteBuf.class), eq(0), eq(false), eq(promise)); + } + + @Test + public void ensureCanCreateNextStreamWhenStreamCloses() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + + // This one gets buffered. + encoderWriteHeaders(5, promise); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + // Now prevent us from creating another stream. + setMaxConcurrentStreams(0); + + // Close the previous stream. + connection.stream(3).close(); + + // Ensure that no streams are currently active and that only the HEADERS from the first + // stream were written. + writeVerifyWriteHeaders(times(1), 3, promise); + writeVerifyWriteHeaders(never(), 5, promise); + assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + } + + @Test + public void alternatingWritesToActiveAndBufferedStreams() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + + encoderWriteHeaders(5, promise); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeData(ctx, 3, Unpooled.buffer(0), 0, false, promise); + writeVerifyWriteHeaders(times(1), 3, promise); + encoder.writeData(ctx, 5, Unpooled.buffer(0), 0, false, promise); + verify(writer, never()) + .writeData(eq(ctx), eq(5), any(ByteBuf.class), eq(0), eq(false), eq(promise)); + } + + @Test + public void bufferingNewStreamFailsAfterGoAwayReceived() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(0); + connection.goAwayReceived(1, 8, null); + + promise = mock(ChannelPromise.class); + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + verify(promise).setFailure(any(Throwable.class)); + } + + @Test + public void receivingGoAwayFailsBufferedStreams() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(5); + + int streamId = 3; + for (int i = 0; i < 9; i++) { + encoderWriteHeaders(streamId, promise); + streamId += 2; + } + assertEquals(4, encoder.numBufferedStreams()); + + connection.goAwayReceived(11, 8, null); + + assertEquals(5, connection.numActiveStreams()); + // The 4 buffered streams must have been failed. + verify(promise, times(4)).setFailure(any(Throwable.class)); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void sendingGoAwayShouldNotFailStreams() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + encoderWriteHeaders(5, promise); + assertEquals(1, encoder.numBufferedStreams()); + encoderWriteHeaders(7, promise); + assertEquals(2, encoder.numBufferedStreams()); + + ByteBuf empty = Unpooled.buffer(0); + encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, promise); + + assertEquals(1, connection.numActiveStreams()); + assertEquals(2, encoder.numBufferedStreams()); + verify(promise, never()).setFailure(any(GoAwayException.class)); + } + + @Test + public void endStreamDoesNotFailBufferedStream() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(0); + + encoderWriteHeaders(3, promise); + assertEquals(1, encoder.numBufferedStreams()); + + ByteBuf empty = Unpooled.buffer(0); + encoder.writeData(ctx, 3, empty, 0, true, promise); + + assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + // Simulate that we received a SETTINGS frame which + // increased MAX_CONCURRENT_STREAMS to 1. + setMaxConcurrentStreams(1); + encoder.writeSettingsAck(ctx, promise); + + assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(HALF_CLOSED_LOCAL, connection.stream(3).state()); + } + + @Test + public void rstStreamClosesBufferedStream() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(0); + + encoderWriteHeaders(3, promise); + assertEquals(1, encoder.numBufferedStreams()); + + verify(promise, never()).setSuccess(); + ChannelPromise rstStreamPromise = mock(ChannelPromise.class); + encoder.writeRstStream(ctx, 3, CANCEL.code(), rstStreamPromise); + verify(promise).setSuccess(); + verify(rstStreamPromise).setSuccess(); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void bufferUntilActiveStreamsAreReset() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + encoderWriteHeaders(5, promise); + assertEquals(1, encoder.numBufferedStreams()); + encoderWriteHeaders(7, promise); + assertEquals(2, encoder.numBufferedStreams()); + + writeVerifyWriteHeaders(times(1), 3, promise); + writeVerifyWriteHeaders(never(), 5, promise); + writeVerifyWriteHeaders(never(), 7, promise); + + encoder.writeRstStream(ctx, 3, CANCEL.code(), promise); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + encoder.writeRstStream(ctx, 5, CANCEL.code(), promise); + assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + encoder.writeRstStream(ctx, 7, CANCEL.code(), promise); + assertEquals(0, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void bufferUntilMaxStreamsIncreased() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(2); + + encoderWriteHeaders(3, promise); + encoderWriteHeaders(5, promise); + encoderWriteHeaders(7, promise); + encoderWriteHeaders(9, promise); + assertEquals(2, encoder.numBufferedStreams()); + + writeVerifyWriteHeaders(times(1), 3, promise); + writeVerifyWriteHeaders(times(1), 5, promise); + writeVerifyWriteHeaders(never(), 7, promise); + writeVerifyWriteHeaders(never(), 9, promise); + + // Simulate that we received a SETTINGS frame which + // increased MAX_CONCURRENT_STREAMS to 5. + setMaxConcurrentStreams(5); + encoder.writeSettingsAck(ctx, promise); + + assertEquals(0, encoder.numBufferedStreams()); + writeVerifyWriteHeaders(times(1), 7, promise); + writeVerifyWriteHeaders(times(1), 9, promise); + + encoderWriteHeaders(11, promise); + + writeVerifyWriteHeaders(times(1), 11, promise); + + assertEquals(5, connection.local().numActiveStreams()); + } + + @Test + public void bufferUntilSettingsReceived() throws Http2Exception { + int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; + int numStreams = initialLimit * 2; + for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { + encoderWriteHeaders(nextStreamId, promise); + if (ix < initialLimit) { + writeVerifyWriteHeaders(times(1), nextStreamId, promise); + } else { + writeVerifyWriteHeaders(never(), nextStreamId, promise); + } + } + assertEquals(numStreams / 2, encoder.numBufferedStreams()); + + // Simulate that we received a SETTINGS frame. + setMaxConcurrentStreams(initialLimit * 2); + + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(numStreams, connection.local().numActiveStreams()); + } + + @Test + public void bufferUntilSettingsReceivedWithNoMaxConcurrentStreamValue() throws Http2Exception { + int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; + int numStreams = initialLimit * 2; + for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { + encoderWriteHeaders(nextStreamId, promise); + if (ix < initialLimit) { + writeVerifyWriteHeaders(times(1), nextStreamId, promise); + } else { + writeVerifyWriteHeaders(never(), nextStreamId, promise); + } + } + assertEquals(numStreams / 2, encoder.numBufferedStreams()); + + // Simulate that we received an empty SETTINGS frame. + encoder.remoteSettings(new Http2Settings()); + + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(numStreams, connection.local().numActiveStreams()); + } + + @Test + public void exhaustedStreamsDoNotBuffer() throws Http2Exception { + // Write the highest possible stream ID for the client. + // This will cause the next stream ID to be negative. + encoderWriteHeaders(Integer.MAX_VALUE, promise); + + // Disallow any further streams. + setMaxConcurrentStreams(0); + + // Simulate numeric overflow for the next stream ID. + encoderWriteHeaders(-1, promise); + + // Verify that the write fails. + verify(promise).setFailure(any(Http2Exception.class)); + } + + @Test + public void closedBufferedStreamReleasesByteBuf() { + encoder.writeSettingsAck(ctx, promise); + setMaxConcurrentStreams(0); + ByteBuf data = mock(ByteBuf.class); + encoderWriteHeaders(3, promise); + assertEquals(1, encoder.numBufferedStreams()); + encoder.writeData(ctx, 3, data, 0, false, promise); + + ChannelPromise rstPromise = mock(ChannelPromise.class); + encoder.writeRstStream(ctx, 3, CANCEL.code(), rstPromise); + + assertEquals(0, encoder.numBufferedStreams()); + verify(rstPromise).setSuccess(); + verify(promise, times(2)).setSuccess(); + verify(data).release(); + } + + private void setMaxConcurrentStreams(int newValue) { + try { + encoder.remoteSettings(new Http2Settings().maxConcurrentStreams(newValue)); + } catch (Http2Exception e) { + throw new RuntimeException(e); + } + } + + private void encoderWriteHeaders(int streamId, ChannelPromise promise) { + encoder.writeHeaders(ctx, streamId, new DefaultHttp2Headers(), 0, DEFAULT_PRIORITY_WEIGHT, + false, 0, false, promise); + } + + private void writeVerifyWriteHeaders(VerificationMode mode, int streamId, + ChannelPromise promise) { + verify(writer, mode).writeHeaders(eq(ctx), eq(streamId), any(Http2Headers.class), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), + eq(false), eq(promise)); + } + + private Answer successAnswer() { + return new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + ChannelPromise future = + new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + future.setSuccess(); + return future; + } + }; + } + + private static ByteBuf data() { + ByteBuf buf = Unpooled.buffer(10); + for (int i = 0; i < buf.writableBytes(); i++) { + buf.writeByte(i); + } + return buf; + } +}