Fix premature cancelation of pending frames in HTTP2 Flow Control.

Motivation:

If HEADERS or DATA frames are pending due to a too small flow control
window, a frame with the END_STREAM flag set will wrongfully cancel
all pending frames (including itself).

Also see grpc/grpc-java#145

Modifications:

The transition of the stream state to CLOSE / HALF_CLOSE due to a
set END_STREAM flag is delayed until the frame with the flag is
actually written to the Channel.

Result:

Flow control works correctly. Frames with END_STREAM flag will no
longer cancel their preceding frames.
This commit is contained in:
Jakob Buchgraber 2015-03-04 15:55:55 -08:00 committed by nmittler
parent 0767da12fb
commit 88beae6838
5 changed files with 366 additions and 99 deletions

View File

@ -166,10 +166,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
throw new IllegalStateException(String.format( throw new IllegalStateException(String.format(
"Stream %d in unexpected state: %s", stream.id(), stream.state())); "Stream %d in unexpected state: %s", stream.id(), stream.state()));
} }
if (endOfStream) {
lifecycleManager.closeLocalSide(stream, promise);
}
} catch (Throwable e) { } catch (Throwable e) {
data.release(); data.release();
return promise.setFailure(e); return promise.setFailure(e);
@ -219,9 +215,6 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
flowController().sendFlowControlled(ctx, stream, flowController().sendFlowControlled(ctx, stream,
new FlowControlledHeaders(ctx, stream, headers, streamDependency, weight, new FlowControlledHeaders(ctx, stream, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise)); exclusive, padding, endOfStream, promise));
if (endOfStream) {
lifecycleManager.closeLocalSide(stream, promise);
}
return promise; return promise;
} catch (Http2NoMoreStreamIdsException e) { } catch (Http2NoMoreStreamIdsException e) {
lifecycleManager.onException(ctx, e); lifecycleManager.onException(ctx, e);
@ -427,16 +420,16 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
@Override @Override
public boolean write(int allowedBytes) { public boolean write(int allowedBytes) {
int bytesWritten = 0;
try {
if (data == null) { if (data == null) {
return false; return false;
} }
if (allowedBytes == 0 && size() != 0) { if (allowedBytes == 0 && size != 0) {
// No point writing an empty DATA frame, wait for a bigger allowance. // No point writing an empty DATA frame, wait for a bigger allowance.
return false; return false;
} }
int maxFrameSize = frameWriter().configuration().frameSizePolicy().maxFrameSize(); int maxFrameSize = frameWriter().configuration().frameSizePolicy().maxFrameSize();
try {
int bytesWritten = 0;
do { do {
int allowedFrameSize = Math.min(maxFrameSize, allowedBytes - bytesWritten); int allowedFrameSize = Math.min(maxFrameSize, allowedBytes - bytesWritten);
ByteBuf toWrite; ByteBuf toWrite;
@ -467,11 +460,9 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
frameWriter().writeData(ctx, stream.id(), toWrite, writeablePadding, frameWriter().writeData(ctx, stream.id(), toWrite, writeablePadding,
size == bytesWritten && endOfStream, writePromise); size == bytesWritten && endOfStream, writePromise);
} while (size != bytesWritten && allowedBytes > bytesWritten); } while (size != bytesWritten && allowedBytes > bytesWritten);
size -= bytesWritten;
return true; return true;
} catch (Throwable e) { } finally {
error(e); size -= bytesWritten;
return false;
} }
} }
} }
@ -538,9 +529,17 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
this.endOfStream = endOfStream; this.endOfStream = endOfStream;
this.stream = stream; this.stream = stream;
this.promise = promise; this.promise = promise;
// Ensure error() gets called in case something goes wrong after the frame is passed to Netty.
promise.addListener(this); promise.addListener(this);
} }
@Override
public void writeComplete() {
if (endOfStream) {
lifecycleManager.closeLocalSide(stream, promise);
}
}
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) { if (!future.isSuccess()) {

View File

@ -23,11 +23,12 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static java.lang.Math.max; import static java.lang.Math.max;
import static java.lang.Math.min; import static java.lang.Math.min;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http2.Http2Stream.State;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
import java.util.Queue; import java.util.Deque;
/** /**
* Basic implementation of {@link Http2RemoteFlowController}. * Basic implementation of {@link Http2RemoteFlowController}.
@ -73,9 +74,27 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
@Override @Override
public void streamInactive(Http2Stream stream) { public void streamInactive(Http2Stream stream) {
// Any pending frames can never be written, clear and // Any pending frames can never be written, cancel and
// write errors for any pending frames. // write errors for any pending frames.
state(stream).clear(); state(stream).cancel();
}
@Override
public void streamHalfClosed(Http2Stream stream) {
if (State.HALF_CLOSED_LOCAL.equals(stream.state())) {
/**
* When this method is called there should not be any
* pending frames left if the API is used correctly. However,
* it is possible that a erroneous application can sneak
* in a frame even after having already written a frame with the
* END_STREAM flag set, as the stream state might not transition
* immediately to HALF_CLOSED_LOCAL / CLOSED due to flow control
* delaying the write.
*
* This is to cancel any such illegal writes.
*/
state(stream).cancel();
}
} }
@Override @Override
@ -156,13 +175,19 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
} }
// Save the context. We'll use this later when we write pending bytes. // Save the context. We'll use this later when we write pending bytes.
this.ctx = ctx; this.ctx = ctx;
FlowState state;
try { try {
FlowState state = state(stream); state = state(stream);
state.newFrame(payload); state.newFrame(payload);
} catch (Throwable t) {
payload.error(t);
return;
}
state.writeBytes(state.writableWindow()); state.writeBytes(state.writableWindow());
try {
flush(); flush();
} catch (Throwable e) { } catch (Throwable t) {
payload.error(e); payload.error(t);
} }
} }
@ -228,7 +253,6 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
return 0; return 0;
} }
int bytesAllocated = 0; int bytesAllocated = 0;
// If the number of streamable bytes for this tree will fit in the connection window // If the number of streamable bytes for this tree will fit in the connection window
// then there is no need to prioritize the bytes...everyone sends what they have // then there is no need to prioritize the bytes...everyone sends what they have
if (state.streamableBytesForTree() <= connectionWindow) { if (state.streamableBytesForTree() <= connectionWindow) {
@ -312,12 +336,16 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
* The outbound flow control state for a single stream. * The outbound flow control state for a single stream.
*/ */
final class FlowState { final class FlowState {
private final Queue<Frame> pendingWriteQueue; private final Deque<Frame> pendingWriteQueue;
private final Http2Stream stream; private final Http2Stream stream;
private int window; private int window;
private int pendingBytes; private int pendingBytes;
private int streamableBytesForTree; private int streamableBytesForTree;
private int allocated; private int allocated;
// Set to true while a frame is being written, false otherwise.
private boolean writing;
// Set to true if cancel() was called.
private boolean cancelled;
FlowState(Http2Stream stream, int initialWindowSize) { FlowState(Http2Stream stream, int initialWindowSize) {
this.stream = stream; this.stream = stream;
@ -422,7 +450,13 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
/** /**
* Clears the pending queue and writes errors for each remaining frame. * Clears the pending queue and writes errors for each remaining frame.
*/ */
void clear() { void cancel() {
cancelled = true;
// Ensure that the queue can't be modified while
// we are writing.
if (writing) {
return;
}
for (;;) { for (;;) {
Frame frame = pendingWriteQueue.poll(); Frame frame = pendingWriteQueue.poll();
if (frame == null) { if (frame == null) {
@ -496,19 +530,52 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
*/ */
int write(int allowedBytes) { int write(int allowedBytes) {
int before = payload.size(); int before = payload.size();
needFlush |= payload.write(Math.max(0, allowedBytes)); int writtenBytes = 0;
int writtenBytes = before - payload.size();
try { try {
connectionState().incrementStreamWindow(-writtenBytes); if (writing) {
incrementStreamWindow(-writtenBytes); throw new IllegalStateException("write is not re-entrant");
}
// Write the portion of the frame.
writing = true;
needFlush |= payload.write(Math.max(0, allowedBytes));
if (!cancelled && payload.size() == 0) {
// This frame has been fully written, remove this frame
// and notify the payload. Since we remove this frame
// first, we're guaranteed that its error method will not
// be called when we call cancel.
pendingWriteQueue.remove();
payload.writeComplete();
}
} catch (Throwable e) {
// Mark the state as cancelled, we'll clear the pending queue
// via cancel() below.
cancelled = true;
} finally {
writing = false;
// Make sure we always decrement the flow control windows
// by the bytes written.
writtenBytes = before - payload.size();
decrementFlowControlWindow(writtenBytes);
decrementPendingBytes(writtenBytes);
// If a cancellation occurred while writing, call cancel again to
// clear and error all of the pending writes.
if (cancelled) {
cancel();
}
}
return writtenBytes;
}
/**
* Decrement the per stream and connection flow control window by {@code bytes}.
*/
void decrementFlowControlWindow(int bytes) {
try {
connectionState().incrementStreamWindow(-bytes);
incrementStreamWindow(-bytes);
} catch (Http2Exception e) { // Should never get here since we're decrementing. } catch (Http2Exception e) { // Should never get here since we're decrementing.
throw new RuntimeException("Invalid window state when writing frame: " + e.getMessage(), e); throw new RuntimeException("Invalid window state when writing frame: " + e.getMessage(), e);
} }
decrementPendingBytes(writtenBytes);
if (payload.size() == 0) {
pendingWriteQueue.remove();
}
return writtenBytes;
} }
/** /**

View File

@ -52,18 +52,37 @@ public interface Http2RemoteFlowController extends Http2FlowController {
int size(); int size();
/** /**
* Signal an error and release any retained buffers. * Called to indicate that an error occurred before this object could be completely written.
* <p>
* The {@link Http2RemoteFlowController} will make exactly one call to either {@link #error(Throwable)} or
* {@link #writeComplete()}.
* </p>
*
* @param cause of the error. * @param cause of the error.
*/ */
void error(Throwable cause); void error(Throwable cause);
/**
* Called after this object has been successfully written.
* <p>
* The {@link Http2RemoteFlowController} will make exactly one call to either {@link #error(Throwable)} or
* {@link #writeComplete()}.
* </p>
*/
void writeComplete();
/** /**
* Writes up to {@code allowedBytes} of the encapsulated payload to the stream. Note that * Writes up to {@code allowedBytes} of the encapsulated payload to the stream. Note that
* a value of 0 may be passed which will allow payloads with flow-control size == 0 to be * a value of 0 may be passed which will allow payloads with flow-control size == 0 to be
* written. The flow-controller may call this method multiple times with different values until * written. The flow-controller may call this method multiple times with different values until
* the payload is fully written. * the payload is fully written.
* <p>
* When an exception is thrown the {@link Http2RemoteFlowController} will make a call to
* {@link #error(Throwable)}.
* </p>
* *
* @param allowedBytes an upper bound on the number of bytes the payload can write at this time. * @param allowedBytes an upper bound on the number of bytes the payload can write at this time.
* @throws Exception if an error occurs. The method must not call {@link #error(Throwable)} by itself.
* @return {@code true} if a flush is required, {@code false} otherwise. * @return {@code true} if a flush is required, {@code false} otherwise.
*/ */
boolean write(int allowedBytes); boolean write(int allowedBytes);

View File

@ -18,20 +18,20 @@ import static io.netty.buffer.Unpooled.wrappedBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.emptyPingBuf; import static io.netty.handler.codec.http2.Http2CodecUtil.emptyPingBuf;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2Stream.State.OPEN; import static io.netty.handler.codec.http2.Http2Stream.State.OPEN;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doNothing;
@ -49,6 +49,7 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled;
import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor;
import org.junit.Before; import org.junit.Before;
@ -175,9 +176,6 @@ public class DefaultHttp2ConnectionEncoderTest {
fail("Stream already closed"); fail("Stream already closed");
} else { } else {
streamClosed = (Boolean) invocationOnMock.getArguments()[4]; streamClosed = (Boolean) invocationOnMock.getArguments()[4];
if (streamClosed) {
assertSame(promise, receivedPromise);
}
} }
writtenPadding.add((Integer) invocationOnMock.getArguments()[3]); writtenPadding.add((Integer) invocationOnMock.getArguments()[3]);
ByteBuf data = (ByteBuf) invocationOnMock.getArguments()[2]; ByteBuf data = (ByteBuf) invocationOnMock.getArguments()[2];
@ -189,6 +187,21 @@ public class DefaultHttp2ConnectionEncoderTest {
return future; return future;
} }
}); });
when(writer.writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
anyInt(), anyBoolean(), any(ChannelPromise.class)))
.then(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise receivedPromise = (ChannelPromise) invocationOnMock.getArguments()[8];
if (streamClosed) {
fail("Stream already closed");
} else {
streamClosed = (Boolean) invocationOnMock.getArguments()[5];
}
receivedPromise.trySuccess();
return future;
}
});
payloadCaptor = ArgumentCaptor.forClass(Http2RemoteFlowController.FlowControlled.class); payloadCaptor = ArgumentCaptor.forClass(Http2RemoteFlowController.FlowControlled.class);
doNothing().when(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), payloadCaptor.capture()); doNothing().when(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), payloadCaptor.capture());
when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
@ -227,19 +240,7 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void dataWriteShouldHalfCloseStream() throws Exception { public void dataLargerThanMaxFrameSizeShouldBeSplit() throws Exception {
reset(future);
final ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise);
assertEquals(1, payloadCaptor.getAllValues().size());
// Write the DATA frame completely
assertTrue(payloadCaptor.getValue().write(Integer.MAX_VALUE));
verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise));
assertEquals(0, data.refCnt());
}
@Test
public void dataLargerThanMaxFrameSizeShouldBeSplit() {
when(frameSizePolicy.maxFrameSize()).thenReturn(3); when(frameSizePolicy.maxFrameSize()).thenReturn(3);
final ByteBuf data = dummyData(); final ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise); encoder.writeData(ctx, STREAM_ID, data, 0, true, promise);
@ -254,7 +255,7 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void paddingSplitOverFrame() { public void paddingSplitOverFrame() throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(5); when(frameSizePolicy.maxFrameSize()).thenReturn(5);
final ByteBuf data = dummyData(); final ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 5, true, promise); encoder.writeData(ctx, STREAM_ID, data, 5, true, promise);
@ -272,7 +273,7 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void frameShouldSplitPadding() { public void frameShouldSplitPadding() throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(5); when(frameSizePolicy.maxFrameSize()).thenReturn(5);
ByteBuf data = dummyData(); ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data, 10, true, promise); encoder.writeData(ctx, STREAM_ID, data, 10, true, promise);
@ -292,18 +293,18 @@ public class DefaultHttp2ConnectionEncoderTest {
} }
@Test @Test
public void emptyFrameShouldSplitPadding() { public void emptyFrameShouldSplitPadding() throws Exception {
ByteBuf data = Unpooled.buffer(0); ByteBuf data = Unpooled.buffer(0);
assertSplitPaddingOnEmptyBuffer(data); assertSplitPaddingOnEmptyBuffer(data);
assertEquals(0, data.refCnt()); assertEquals(0, data.refCnt());
} }
@Test @Test
public void singletonEmptyBufferShouldSplitPadding() { public void singletonEmptyBufferShouldSplitPadding() throws Exception {
assertSplitPaddingOnEmptyBuffer(Unpooled.EMPTY_BUFFER); assertSplitPaddingOnEmptyBuffer(Unpooled.EMPTY_BUFFER);
} }
private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) { private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception {
when(frameSizePolicy.maxFrameSize()).thenReturn(5); when(frameSizePolicy.maxFrameSize()).thenReturn(5);
encoder.writeData(ctx, STREAM_ID, data, 10, true, promise); encoder.writeData(ctx, STREAM_ID, data, 10, true, promise);
assertEquals(payloadCaptor.getValue().size(), 10); assertEquals(payloadCaptor.getValue().size(), 10);
@ -344,24 +345,6 @@ public class DefaultHttp2ConnectionEncoderTest {
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise)); eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
} }
@Test
public void headersWriteShouldCreateHalfClosedStream() throws Exception {
int streamId = 5;
when(stream.id()).thenReturn(streamId);
when(stream.state()).thenReturn(IDLE);
mockFutureAddListener(true);
when(local.createStream(eq(streamId))).thenReturn(stream);
when(writer.writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))).thenReturn(future);
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(local).createStream(eq(streamId));
verify(stream).open(eq(true));
// Trigger the write and mark the promise successful to trigger listeners
payloadCaptor.getValue().write(0);
promise.trySuccess();
verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise));
}
@Test @Test
public void headersWriteShouldOpenStreamForPush() throws Exception { public void headersWriteShouldOpenStreamForPush() throws Exception {
mockFutureAddListener(true); mockFutureAddListener(true);
@ -375,20 +358,6 @@ public class DefaultHttp2ConnectionEncoderTest {
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise)); eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
} }
@Test
public void headersWriteShouldClosePushStream() throws Exception {
mockFutureAddListener(true);
when(stream.state()).thenReturn(RESERVED_LOCAL).thenReturn(HALF_CLOSED_LOCAL);
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise))).thenReturn(future);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(stream).open(true);
// Trigger the write and mark the promise successful to trigger listeners
payloadCaptor.getValue().write(0);
promise.trySuccess();
verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise));
}
@Test @Test
public void pushPromiseWriteAfterGoAwayShouldFail() throws Exception { public void pushPromiseWriteAfterGoAwayShouldFail() throws Exception {
when(connection.isGoAway()).thenReturn(true); when(connection.isGoAway()).thenReturn(true);
@ -467,6 +436,64 @@ public class DefaultHttp2ConnectionEncoderTest {
verify(writer).writeSettings(eq(ctx), eq(settings), eq(promise)); verify(writer).writeSettings(eq(ctx), eq(settings), eq(promise));
} }
@Test
public void dataWriteShouldCreateHalfClosedStream() {
mockSendFlowControlledWriteEverything();
ByteBuf data = dummyData();
encoder.writeData(ctx, STREAM_ID, data.retain(), 0, true, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
verify(lifecycleManager).closeLocalSide(stream, promise);
assertEquals(data.toString(UTF_8), writtenData.get(0));
data.release();
}
@Test
public void headersWriteShouldHalfCloseStream() throws Exception {
mockSendFlowControlledWriteEverything();
int streamId = 5;
when(stream.id()).thenReturn(streamId);
when(stream.state()).thenReturn(IDLE);
mockFutureAddListener(true);
when(local.createStream(eq(streamId))).thenReturn(stream);
when(writer.writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise)))
.thenReturn(future);
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(local).createStream(eq(streamId));
verify(stream).open(eq(true));
// Trigger the write and mark the promise successful to trigger listeners
payloadCaptor.getValue().write(0);
promise.trySuccess();
verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise));
}
@Test
public void headersWriteShouldHalfClosePushStream() throws Exception {
mockSendFlowControlledWriteEverything();
mockFutureAddListener(true);
when(stream.state()).thenReturn(RESERVED_LOCAL).thenReturn(HALF_CLOSED_LOCAL);
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise)))
.thenReturn(future);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(stream).open(true);
promise.trySuccess();
verify(lifecycleManager).closeLocalSide(eq(stream), eq(promise));
}
private void mockSendFlowControlledWriteEverything() {
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
FlowControlled flowControlled = (FlowControlled) invocationOnMock.getArguments()[2];
flowControlled.write(Integer.MAX_VALUE);
flowControlled.writeComplete();
return null;
}
}).when(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), payloadCaptor.capture());
}
private void mockFutureAddListener(boolean success) { private void mockFutureAddListener(boolean success) {
when(future.isSuccess()).thenReturn(success); when(future.isSuccess()).thenReturn(success);
if (!success) { if (!success) {

View File

@ -22,6 +22,10 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -31,6 +35,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.Http2FrameWriter.Configuration; import io.netty.handler.codec.http2.Http2FrameWriter.Configuration;
import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled;
import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap; import io.netty.util.collection.IntObjectMap;
@ -40,7 +45,10 @@ import java.util.List;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** /**
* Tests for {@link DefaultHttp2RemoteFlowController}. * Tests for {@link DefaultHttp2RemoteFlowController}.
@ -987,6 +995,149 @@ public class DefaultHttp2RemoteFlowControllerTest {
streamableBytesForTree(streamD)); streamableBytesForTree(streamD));
} }
@Test
public void flowControlledWriteThrowsAnException() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite();
final Http2Stream stream = stream(STREAM_A);
doAnswer(new Answer<Void>() {
public Void answer(InvocationOnMock invocationOnMock) {
stream.closeLocalSide();
return null;
}
}).when(flowControlled).error(any(Throwable.class));
int windowBefore = window(STREAM_A);
controller.sendFlowControlled(ctx, stream, flowControlled);
verify(flowControlled, times(3)).write(anyInt());
verify(flowControlled).error(any(Throwable.class));
verify(flowControlled, never()).writeComplete();
assertEquals(90, windowBefore - window(STREAM_A));
}
@Test
public void flowControlledWriteAndErrorThrowAnException() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite();
final Http2Stream stream = stream(STREAM_A);
doAnswer(new Answer<Void>() {
public Void answer(InvocationOnMock invocationOnMock) {
throw new RuntimeException("error failed");
}
}).when(flowControlled).error(any(Throwable.class));
int windowBefore = window(STREAM_A);
boolean exceptionThrown = false;
try {
controller.sendFlowControlled(ctx, stream, flowControlled);
} catch (RuntimeException e) {
exceptionThrown = true;
} finally {
assertTrue(exceptionThrown);
}
verify(flowControlled, times(3)).write(anyInt());
verify(flowControlled).error(any(Throwable.class));
verify(flowControlled, never()).writeComplete();
assertEquals(90, windowBefore - window(STREAM_A));
}
@Test
public void flowControlledWriteCompleteThrowsAnException() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled =
Mockito.mock(Http2RemoteFlowController.FlowControlled.class);
when(flowControlled.size()).thenReturn(100);
when(flowControlled.write(anyInt())).thenAnswer(new Answer<Boolean>() {
private int invocationCount;
@Override
public Boolean answer(InvocationOnMock invocationOnMock) throws Throwable {
switch(invocationCount) {
case 0:
when(flowControlled.size()).thenReturn(50);
invocationCount = 1;
return true;
case 1:
when(flowControlled.size()).thenReturn(20);
invocationCount = 2;
return true;
default:
when(flowControlled.size()).thenReturn(0);
return false;
}
}
});
final Http2Stream stream = stream(STREAM_A);
doAnswer(new Answer<Void>() {
public Void answer(InvocationOnMock invocationOnMock) {
throw new RuntimeException("writeComplete failed");
}
}).when(flowControlled).writeComplete();
int windowBefore = window(STREAM_A);
try {
controller.sendFlowControlled(ctx, stream, flowControlled);
} catch (Exception e) {
fail();
}
verify(flowControlled, times(3)).write(anyInt());
verify(flowControlled, never()).error(any(Throwable.class));
verify(flowControlled).writeComplete();
assertEquals(100, windowBefore - window(STREAM_A));
}
@Test
public void closeStreamInFlowControlledError() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled =
Mockito.mock(Http2RemoteFlowController.FlowControlled.class);
final Http2Stream stream = stream(STREAM_A);
when(flowControlled.size()).thenReturn(100);
when(flowControlled.write(anyInt())).thenThrow(new RuntimeException("write failed"));
doAnswer(new Answer<Void>() {
public Void answer(InvocationOnMock invocationOnMock) {
stream.close();
return null;
}
}).when(flowControlled).error(any(Throwable.class));
controller.sendFlowControlled(ctx, stream, flowControlled);
verify(flowControlled).write(anyInt());
verify(flowControlled).error(any(Throwable.class));
verify(flowControlled, never()).writeComplete();
}
private static Http2RemoteFlowController.FlowControlled mockedFlowControlledThatThrowsOnWrite() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled =
Mockito.mock(Http2RemoteFlowController.FlowControlled.class);
when(flowControlled.size()).thenReturn(100);
when(flowControlled.write(anyInt())).thenAnswer(new Answer<Boolean>() {
private int invocationCount;
@Override
public Boolean answer(InvocationOnMock invocationOnMock) throws Throwable {
switch(invocationCount) {
case 0:
when(flowControlled.size()).thenReturn(50);
invocationCount = 1;
return true;
case 1:
when(flowControlled.size()).thenReturn(20);
invocationCount = 2;
return true;
default:
when(flowControlled.size()).thenReturn(10);
throw new RuntimeException("Write failed");
}
}
});
return flowControlled;
}
private static int calculateStreamSizeSum(IntObjectMap<Integer> streamSizes, List<Integer> streamIds) { private static int calculateStreamSizeSum(IntObjectMap<Integer> streamSizes, List<Integer> streamIds) {
int sum = 0; int sum = 0;
for (Integer streamId : streamIds) { for (Integer streamId : streamIds) {
@ -1049,6 +1200,10 @@ public class DefaultHttp2RemoteFlowControllerTest {
this.t = t; this.t = t;
} }
@Override
public void writeComplete() {
}
@Override @Override
public boolean write(int allowedBytes) { public boolean write(int allowedBytes) {
if (allowedBytes <= 0 && currentSize != 0) { if (allowedBytes <= 0 && currentSize != 0) {