diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java index 8b76b707e2..0a81012ef9 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java @@ -18,12 +18,12 @@ package io.netty.handler.codec.http2; import io.netty.channel.Channel; import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.UnstableApi; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_RESERVED_STREAMS; import static io.netty.handler.codec.http2.Http2PromisedRequestVerifier.ALWAYS_VERIFY; -import static io.netty.util.internal.ObjectUtil.checkPositive; import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.util.Objects.requireNonNull; @@ -107,6 +107,7 @@ public abstract class AbstractHttp2ConnectionHandlerBuilder() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return handlePromise(invocationOnMock, 3); + } + }); + when(writer.writeSettingsAck(any(ChannelHandlerContext.class), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return handlePromise(invocationOnMock, 1); + } + }); + when(writer.writePing(any(ChannelHandlerContext.class), anyBoolean(), anyLong(), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ChannelPromise promise = handlePromise(invocationOnMock, 3); + if (invocationOnMock.getArgument(1) == Boolean.FALSE) { + promise.trySuccess(); + } + return promise; + } + }); + when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ReferenceCountUtil.release(invocationOnMock.getArgument(3)); + return handlePromise(invocationOnMock, 4); + } + }); + Http2Connection connection = new DefaultHttp2Connection(false); + connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); + connection.local().flowController(new DefaultHttp2LocalFlowController(connection).frameWriter(writer)); + + DefaultHttp2ConnectionEncoder defaultEncoder = + new DefaultHttp2ConnectionEncoder(connection, writer); + encoder = new Http2ControlFrameLimitEncoder(defaultEncoder, 2); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, mock(Http2FrameReader.class)); + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(mock(Http2FrameListener.class)) + .codec(decoder, encoder).build(); + + // Set LifeCycleManager on encoder and decoder + when(ctx.channel()).thenReturn(channel); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(executor.inEventLoop()).thenReturn(true); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { + return newPromise(); + } + }).when(ctx).newPromise(); + when(ctx.executor()).thenReturn(executor); + when(channel.isActive()).thenReturn(false); + when(channel.config()).thenReturn(config); + when(channel.isWritable()).thenReturn(true); + when(channel.bytesBeforeUnwritable()).thenReturn(Long.MAX_VALUE); + when(config.getWriteBufferHighWaterMark()).thenReturn(Integer.MAX_VALUE); + when(config.getMessageSizeEstimator()).thenReturn(DefaultMessageSizeEstimator.DEFAULT); + ChannelMetadata metadata = new ChannelMetadata(false, 16); + when(channel.metadata()).thenReturn(metadata); + when(channel.unsafe()).thenReturn(unsafe); + handler.handlerAdded(ctx); + } + + private ChannelPromise handlePromise(InvocationOnMock invocationOnMock, int promiseIdx) { + ChannelPromise promise = invocationOnMock.getArgument(promiseIdx); + if (++numWrites == 2) { + promise.setSuccess(); + } + return promise; + } + + @After + public void teardown() { + // Close and release any buffered frames. + encoder.close(); + } + + @Test + public void testLimitSettingsAck() { + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. + assertTrue(encoder.writeSettingsAck(ctx, newPromise()).isSuccess()); + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + + verifyFlushAndClose(1, true); + } + + @Test + public void testLimitPingAck() { + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. + assertTrue(encoder.writePing(ctx, true, 8, newPromise()).isSuccess()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + + verifyFlushAndClose(1, true); + } + + @Test + public void testNotLimitPing() { + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + + verifyFlushAndClose(0, false); + } + + @Test + public void testLimitRst() { + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. + assertTrue(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isSuccess()); + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + + verifyFlushAndClose(1, true); + } + + @Test + public void testLimit() { + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isSuccess()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isSuccess()); + + verifyFlushAndClose(1, true); + } + + private void verifyFlushAndClose(int invocations, boolean failed) { + verify(ctx, atLeast(invocations)).flush(); + verify(ctx, times(invocations)).close(); + if (failed) { + verify(writer, times(1)).writeGoAway(eq(ctx), eq(0), eq(ENHANCE_YOUR_CALM.code()), + any(ByteBuf.class), any(ChannelPromise.class)); + } + } + + private static void assertWriteFailure(ChannelFuture future) { + Http2Exception exception = (Http2Exception) future.cause(); + assertEquals(ShutdownHint.HARD_SHUTDOWN, exception.shutdownHint()); + assertEquals(Http2Error.ENHANCE_YOUR_CALM, exception.error()); + } + + private ChannelPromise newPromise() { + return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + } +}