diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java index abff8b48cd..313fab02f1 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java @@ -166,10 +166,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { throw new IllegalStateException(String.format( "Stream %d in unexpected state: %s", stream.id(), stream.state())); } - - if (endOfStream) { - lifecycleManager.closeLocalSide(stream, promise); - } } catch (Throwable e) { data.release(); return promise.setFailure(e); @@ -219,9 +215,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { flowController().sendFlowControlled(ctx, stream, new FlowControlledHeaders(ctx, stream, headers, streamDependency, weight, exclusive, padding, endOfStream, promise)); - if (endOfStream) { - lifecycleManager.closeLocalSide(stream, promise); - } return promise; } catch (Http2NoMoreStreamIdsException e) { lifecycleManager.onException(ctx, e); @@ -427,16 +420,16 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { @Override public boolean write(int allowedBytes) { - if (data == null) { - return false; - } - if (allowedBytes == 0 && size() != 0) { - // No point writing an empty DATA frame, wait for a bigger allowance. - return false; - } - int maxFrameSize = frameWriter().configuration().frameSizePolicy().maxFrameSize(); + int bytesWritten = 0; try { - int bytesWritten = 0; + if (data == null) { + return false; + } + if (allowedBytes == 0 && size != 0) { + // No point writing an empty DATA frame, wait for a bigger allowance. + return false; + } + int maxFrameSize = frameWriter().configuration().frameSizePolicy().maxFrameSize(); do { int allowedFrameSize = Math.min(maxFrameSize, allowedBytes - bytesWritten); ByteBuf toWrite; @@ -465,13 +458,11 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { writePromise.addListener(this); } frameWriter().writeData(ctx, stream.id(), toWrite, writeablePadding, - size == bytesWritten && endOfStream, writePromise); + size == bytesWritten && endOfStream, writePromise); } while (size != bytesWritten && allowedBytes > bytesWritten); - size -= bytesWritten; return true; - } catch (Throwable e) { - error(e); - return false; + } finally { + size -= bytesWritten; } } } @@ -538,9 +529,17 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { this.endOfStream = endOfStream; this.stream = stream; this.promise = promise; + // Ensure error() gets called in case something goes wrong after the frame is passed to Netty. promise.addListener(this); } + @Override + public void writeComplete() { + if (endOfStream) { + lifecycleManager.closeLocalSide(stream, promise); + } + } + @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java index 810d0d3642..3cd72143f4 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java @@ -23,11 +23,12 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull; import static java.lang.Math.max; import static java.lang.Math.min; import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Stream.State; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Comparator; -import java.util.Queue; +import java.util.Deque; /** * Basic implementation of {@link Http2RemoteFlowController}. @@ -73,9 +74,27 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll @Override public void streamInactive(Http2Stream stream) { - // Any pending frames can never be written, clear and + // Any pending frames can never be written, cancel and // write errors for any pending frames. - state(stream).clear(); + state(stream).cancel(); + } + + @Override + public void streamHalfClosed(Http2Stream stream) { + if (State.HALF_CLOSED_LOCAL.equals(stream.state())) { + /** + * When this method is called there should not be any + * pending frames left if the API is used correctly. However, + * it is possible that a erroneous application can sneak + * in a frame even after having already written a frame with the + * END_STREAM flag set, as the stream state might not transition + * immediately to HALF_CLOSED_LOCAL / CLOSED due to flow control + * delaying the write. + * + * This is to cancel any such illegal writes. + */ + state(stream).cancel(); + } } @Override @@ -156,13 +175,19 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll } // Save the context. We'll use this later when we write pending bytes. this.ctx = ctx; + FlowState state; try { - FlowState state = state(stream); + state = state(stream); state.newFrame(payload); - state.writeBytes(state.writableWindow()); + } catch (Throwable t) { + payload.error(t); + return; + } + state.writeBytes(state.writableWindow()); + try { flush(); - } catch (Throwable e) { - payload.error(e); + } catch (Throwable t) { + payload.error(t); } } @@ -228,7 +253,6 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll return 0; } int bytesAllocated = 0; - // If the number of streamable bytes for this tree will fit in the connection window // then there is no need to prioritize the bytes...everyone sends what they have if (state.streamableBytesForTree() <= connectionWindow) { @@ -312,12 +336,16 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll * The outbound flow control state for a single stream. */ final class FlowState { - private final Queue pendingWriteQueue; + private final Deque pendingWriteQueue; private final Http2Stream stream; private int window; private int pendingBytes; private int streamableBytesForTree; private int allocated; + // Set to true while a frame is being written, false otherwise. + private boolean writing; + // Set to true if cancel() was called. + private boolean cancelled; FlowState(Http2Stream stream, int initialWindowSize) { this.stream = stream; @@ -422,7 +450,13 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll /** * Clears the pending queue and writes errors for each remaining frame. */ - void clear() { + void cancel() { + cancelled = true; + // Ensure that the queue can't be modified while + // we are writing. + if (writing) { + return; + } for (;;) { Frame frame = pendingWriteQueue.poll(); if (frame == null) { @@ -496,19 +530,52 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll */ int write(int allowedBytes) { int before = payload.size(); - needFlush |= payload.write(Math.max(0, allowedBytes)); - int writtenBytes = before - payload.size(); + int writtenBytes = 0; try { - connectionState().incrementStreamWindow(-writtenBytes); - incrementStreamWindow(-writtenBytes); + if (writing) { + throw new IllegalStateException("write is not re-entrant"); + } + // Write the portion of the frame. + writing = true; + needFlush |= payload.write(Math.max(0, allowedBytes)); + if (!cancelled && payload.size() == 0) { + // This frame has been fully written, remove this frame + // and notify the payload. Since we remove this frame + // first, we're guaranteed that its error method will not + // be called when we call cancel. + pendingWriteQueue.remove(); + payload.writeComplete(); + } + } catch (Throwable e) { + // Mark the state as cancelled, we'll clear the pending queue + // via cancel() below. + cancelled = true; + } finally { + writing = false; + // Make sure we always decrement the flow control windows + // by the bytes written. + writtenBytes = before - payload.size(); + decrementFlowControlWindow(writtenBytes); + decrementPendingBytes(writtenBytes); + // If a cancellation occurred while writing, call cancel again to + // clear and error all of the pending writes. + if (cancelled) { + cancel(); + } + } + return writtenBytes; + } + + /** + * Decrement the per stream and connection flow control window by {@code bytes}. + */ + void decrementFlowControlWindow(int bytes) { + try { + connectionState().incrementStreamWindow(-bytes); + incrementStreamWindow(-bytes); } catch (Http2Exception e) { // Should never get here since we're decrementing. throw new RuntimeException("Invalid window state when writing frame: " + e.getMessage(), e); } - decrementPendingBytes(writtenBytes); - if (payload.size() == 0) { - pendingWriteQueue.remove(); - } - return writtenBytes; } /** diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java index 3b3d6f3211..d1b71c9f1b 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java @@ -52,18 +52,37 @@ public interface Http2RemoteFlowController extends Http2FlowController { int size(); /** - * Signal an error and release any retained buffers. + * Called to indicate that an error occurred before this object could be completely written. + *

+ * The {@link Http2RemoteFlowController} will make exactly one call to either {@link #error(Throwable)} or + * {@link #writeComplete()}. + *

+ * * @param cause of the error. */ void error(Throwable cause); + /** + * Called after this object has been successfully written. + *

+ * The {@link Http2RemoteFlowController} will make exactly one call to either {@link #error(Throwable)} or + * {@link #writeComplete()}. + *

+ */ + void writeComplete(); + /** * Writes up to {@code allowedBytes} of the encapsulated payload to the stream. Note that * a value of 0 may be passed which will allow payloads with flow-control size == 0 to be * written. The flow-controller may call this method multiple times with different values until * the payload is fully written. + *

+ * When an exception is thrown the {@link Http2RemoteFlowController} will make a call to + * {@link #error(Throwable)}. + *

* * @param allowedBytes an upper bound on the number of bytes the payload can write at this time. + * @throws Exception if an error occurs. The method must not call {@link #error(Throwable)} by itself. * @return {@code true} if a flush is required, {@code false} otherwise. */ boolean write(int allowedBytes); diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java index 49bd07375e..45ccc99120 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java @@ -18,20 +18,20 @@ import static io.netty.buffer.Unpooled.wrappedBuffer; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.emptyPingBuf; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; -import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2Stream.State.OPEN; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; import static io.netty.util.CharsetUtil.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyShort; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; @@ -49,6 +49,7 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled; import io.netty.util.concurrent.ImmediateEventExecutor; import org.junit.Before; @@ -175,9 +176,6 @@ public class DefaultHttp2ConnectionEncoderTest { fail("Stream already closed"); } else { streamClosed = (Boolean) invocationOnMock.getArguments()[4]; - if (streamClosed) { - assertSame(promise, receivedPromise); - } } writtenPadding.add((Integer) invocationOnMock.getArguments()[3]); ByteBuf data = (ByteBuf) invocationOnMock.getArguments()[2]; @@ -189,6 +187,21 @@ public class DefaultHttp2ConnectionEncoderTest { return future; } }); + when(writer.writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), + anyInt(), anyBoolean(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + ChannelPromise receivedPromise = (ChannelPromise) invocationOnMock.getArguments()[8]; + if (streamClosed) { + fail("Stream already closed"); + } else { + streamClosed = (Boolean) invocationOnMock.getArguments()[5]; + } + receivedPromise.trySuccess(); + return future; + } + }); payloadCaptor = ArgumentCaptor.forClass(Http2RemoteFlowController.FlowControlled.class); doNothing().when(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), payloadCaptor.capture()); when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); @@ -227,19 +240,7 @@ public class DefaultHttp2ConnectionEncoderTest { } @Test - public void dataWriteShouldHalfCloseStream() throws Exception { - reset(future); - final ByteBuf data = dummyData(); - encoder.writeData(ctx, STREAM_ID, data, 0, true, promise); - assertEquals(1, payloadCaptor.getAllValues().size()); - // Write the DATA frame completely - assertTrue(payloadCaptor.getValue().write(Integer.MAX_VALUE)); - verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise)); - assertEquals(0, data.refCnt()); - } - - @Test - public void dataLargerThanMaxFrameSizeShouldBeSplit() { + public void dataLargerThanMaxFrameSizeShouldBeSplit() throws Exception { when(frameSizePolicy.maxFrameSize()).thenReturn(3); final ByteBuf data = dummyData(); encoder.writeData(ctx, STREAM_ID, data, 0, true, promise); @@ -254,7 +255,7 @@ public class DefaultHttp2ConnectionEncoderTest { } @Test - public void paddingSplitOverFrame() { + public void paddingSplitOverFrame() throws Exception { when(frameSizePolicy.maxFrameSize()).thenReturn(5); final ByteBuf data = dummyData(); encoder.writeData(ctx, STREAM_ID, data, 5, true, promise); @@ -272,7 +273,7 @@ public class DefaultHttp2ConnectionEncoderTest { } @Test - public void frameShouldSplitPadding() { + public void frameShouldSplitPadding() throws Exception { when(frameSizePolicy.maxFrameSize()).thenReturn(5); ByteBuf data = dummyData(); encoder.writeData(ctx, STREAM_ID, data, 10, true, promise); @@ -292,18 +293,18 @@ public class DefaultHttp2ConnectionEncoderTest { } @Test - public void emptyFrameShouldSplitPadding() { + public void emptyFrameShouldSplitPadding() throws Exception { ByteBuf data = Unpooled.buffer(0); assertSplitPaddingOnEmptyBuffer(data); assertEquals(0, data.refCnt()); } @Test - public void singletonEmptyBufferShouldSplitPadding() { + public void singletonEmptyBufferShouldSplitPadding() throws Exception { assertSplitPaddingOnEmptyBuffer(Unpooled.EMPTY_BUFFER); } - private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) { + private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception { when(frameSizePolicy.maxFrameSize()).thenReturn(5); encoder.writeData(ctx, STREAM_ID, data, 10, true, promise); assertEquals(payloadCaptor.getValue().size(), 10); @@ -324,7 +325,7 @@ public class DefaultHttp2ConnectionEncoderTest { verify(local, never()).createStream(anyInt()); verify(stream, never()).open(anyBoolean()); verify(writer, never()).writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyBoolean(), - eq(promise)); + eq(promise)); assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception); } @@ -344,24 +345,6 @@ public class DefaultHttp2ConnectionEncoderTest { eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise)); } - @Test - public void headersWriteShouldCreateHalfClosedStream() throws Exception { - int streamId = 5; - when(stream.id()).thenReturn(streamId); - when(stream.state()).thenReturn(IDLE); - mockFutureAddListener(true); - when(local.createStream(eq(streamId))).thenReturn(stream); - when(writer.writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))).thenReturn(future); - encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise); - verify(local).createStream(eq(streamId)); - verify(stream).open(eq(true)); - // Trigger the write and mark the promise successful to trigger listeners - payloadCaptor.getValue().write(0); - promise.trySuccess(); - verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise)); - } - @Test public void headersWriteShouldOpenStreamForPush() throws Exception { mockFutureAddListener(true); @@ -372,21 +355,7 @@ public class DefaultHttp2ConnectionEncoderTest { assertNotNull(payloadCaptor.getValue()); payloadCaptor.getValue().write(0); verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise)); - } - - @Test - public void headersWriteShouldClosePushStream() throws Exception { - mockFutureAddListener(true); - when(stream.state()).thenReturn(RESERVED_LOCAL).thenReturn(HALF_CLOSED_LOCAL); - when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))).thenReturn(future); - encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); - verify(stream).open(true); - // Trigger the write and mark the promise successful to trigger listeners - payloadCaptor.getValue().write(0); - promise.trySuccess(); - verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise)); + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise)); } @Test @@ -394,7 +363,7 @@ public class DefaultHttp2ConnectionEncoderTest { when(connection.isGoAway()).thenReturn(true); ChannelFuture future = encoder.writePushPromise(ctx, STREAM_ID, PUSH_STREAM_ID, - EmptyHttp2Headers.INSTANCE, 0, promise); + EmptyHttp2Headers.INSTANCE, 0, promise); assertTrue(future.awaitUninterruptibly().cause() instanceof Http2Exception); } @@ -467,6 +436,64 @@ public class DefaultHttp2ConnectionEncoderTest { verify(writer).writeSettings(eq(ctx), eq(settings), eq(promise)); } + @Test + public void dataWriteShouldCreateHalfClosedStream() { + mockSendFlowControlledWriteEverything(); + ByteBuf data = dummyData(); + encoder.writeData(ctx, STREAM_ID, data.retain(), 0, true, promise); + verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class)); + verify(lifecycleManager).closeLocalSide(stream, promise); + assertEquals(data.toString(UTF_8), writtenData.get(0)); + data.release(); + } + + @Test + public void headersWriteShouldHalfCloseStream() throws Exception { + mockSendFlowControlledWriteEverything(); + int streamId = 5; + when(stream.id()).thenReturn(streamId); + when(stream.state()).thenReturn(IDLE); + mockFutureAddListener(true); + when(local.createStream(eq(streamId))).thenReturn(stream); + when(writer.writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))) + .thenReturn(future); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise); + verify(local).createStream(eq(streamId)); + verify(stream).open(eq(true)); + // Trigger the write and mark the promise successful to trigger listeners + payloadCaptor.getValue().write(0); + promise.trySuccess(); + verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise)); + } + + @Test + public void headersWriteShouldHalfClosePushStream() throws Exception { + mockSendFlowControlledWriteEverything(); + mockFutureAddListener(true); + when(stream.state()).thenReturn(RESERVED_LOCAL).thenReturn(HALF_CLOSED_LOCAL); + when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))) + .thenReturn(future); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + verify(stream).open(true); + + promise.trySuccess(); + verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise)); + } + + private void mockSendFlowControlledWriteEverything() { + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + FlowControlled flowControlled = (FlowControlled) invocationOnMock.getArguments()[2]; + flowControlled.write(Integer.MAX_VALUE); + flowControlled.writeComplete(); + return null; + } + }).when(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), payloadCaptor.capture()); + } + private void mockFutureAddListener(boolean success) { when(future.isSuccess()).thenReturn(success); if (!success) { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java index 43f2902c54..d508e72893 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java @@ -22,6 +22,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; @@ -31,6 +35,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http2.Http2FrameWriter.Configuration; +import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; @@ -40,7 +45,10 @@ import java.util.List; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** * Tests for {@link DefaultHttp2RemoteFlowController}. @@ -987,6 +995,149 @@ public class DefaultHttp2RemoteFlowControllerTest { streamableBytesForTree(streamD)); } + @Test + public void flowControlledWriteThrowsAnException() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); + final Http2Stream stream = stream(STREAM_A); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocationOnMock) { + stream.closeLocalSide(); + return null; + } + }).when(flowControlled).error(any(Throwable.class)); + + int windowBefore = window(STREAM_A); + + controller.sendFlowControlled(ctx, stream, flowControlled); + + verify(flowControlled, times(3)).write(anyInt()); + verify(flowControlled).error(any(Throwable.class)); + verify(flowControlled, never()).writeComplete(); + + assertEquals(90, windowBefore - window(STREAM_A)); + } + + @Test + public void flowControlledWriteAndErrorThrowAnException() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); + final Http2Stream stream = stream(STREAM_A); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocationOnMock) { + throw new RuntimeException("error failed"); + } + }).when(flowControlled).error(any(Throwable.class)); + + int windowBefore = window(STREAM_A); + + boolean exceptionThrown = false; + try { + controller.sendFlowControlled(ctx, stream, flowControlled); + } catch (RuntimeException e) { + exceptionThrown = true; + } finally { + assertTrue(exceptionThrown); + } + + verify(flowControlled, times(3)).write(anyInt()); + verify(flowControlled).error(any(Throwable.class)); + verify(flowControlled, never()).writeComplete(); + + assertEquals(90, windowBefore - window(STREAM_A)); + } + + @Test + public void flowControlledWriteCompleteThrowsAnException() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = + Mockito.mock(Http2RemoteFlowController.FlowControlled.class); + when(flowControlled.size()).thenReturn(100); + when(flowControlled.write(anyInt())).thenAnswer(new Answer() { + private int invocationCount; + @Override + public Boolean answer(InvocationOnMock invocationOnMock) throws Throwable { + switch(invocationCount) { + case 0: + when(flowControlled.size()).thenReturn(50); + invocationCount = 1; + return true; + case 1: + when(flowControlled.size()).thenReturn(20); + invocationCount = 2; + return true; + default: + when(flowControlled.size()).thenReturn(0); + return false; + } + } + }); + final Http2Stream stream = stream(STREAM_A); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocationOnMock) { + throw new RuntimeException("writeComplete failed"); + } + }).when(flowControlled).writeComplete(); + + int windowBefore = window(STREAM_A); + + try { + controller.sendFlowControlled(ctx, stream, flowControlled); + } catch (Exception e) { + fail(); + } + + verify(flowControlled, times(3)).write(anyInt()); + verify(flowControlled, never()).error(any(Throwable.class)); + verify(flowControlled).writeComplete(); + + assertEquals(100, windowBefore - window(STREAM_A)); + } + + @Test + public void closeStreamInFlowControlledError() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = + Mockito.mock(Http2RemoteFlowController.FlowControlled.class); + final Http2Stream stream = stream(STREAM_A); + when(flowControlled.size()).thenReturn(100); + when(flowControlled.write(anyInt())).thenThrow(new RuntimeException("write failed")); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocationOnMock) { + stream.close(); + return null; + } + }).when(flowControlled).error(any(Throwable.class)); + + controller.sendFlowControlled(ctx, stream, flowControlled); + + verify(flowControlled).write(anyInt()); + verify(flowControlled).error(any(Throwable.class)); + verify(flowControlled, never()).writeComplete(); + } + + private static Http2RemoteFlowController.FlowControlled mockedFlowControlledThatThrowsOnWrite() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = + Mockito.mock(Http2RemoteFlowController.FlowControlled.class); + when(flowControlled.size()).thenReturn(100); + when(flowControlled.write(anyInt())).thenAnswer(new Answer() { + private int invocationCount; + @Override + public Boolean answer(InvocationOnMock invocationOnMock) throws Throwable { + switch(invocationCount) { + case 0: + when(flowControlled.size()).thenReturn(50); + invocationCount = 1; + return true; + case 1: + when(flowControlled.size()).thenReturn(20); + invocationCount = 2; + return true; + default: + when(flowControlled.size()).thenReturn(10); + throw new RuntimeException("Write failed"); + } + } + }); + return flowControlled; + } + private static int calculateStreamSizeSum(IntObjectMap streamSizes, List streamIds) { int sum = 0; for (Integer streamId : streamIds) { @@ -1049,6 +1200,10 @@ public class DefaultHttp2RemoteFlowControllerTest { this.t = t; } + @Override + public void writeComplete() { + } + @Override public boolean write(int allowedBytes) { if (allowedBytes <= 0 && currentSize != 0) {