Remote flow controller incorrectly updates stream state

Motivation:

The `DefaultHttp2RemoteFlowController` does not correctly determine `hasFrame` when updating the stream state for the distributor. Adding a check to enforce `hasFrame` when `streamableBytes > 0` causes several test failures.

Modifications:

Modified `DefaultHttp2RemoteFlowController` to simplify the writing logic and to correct the bookkeeping for `hasFrame`.

Result:

The distributors are always called with valid arguments.
This commit is contained in:
nmittler 2015-11-03 15:05:32 -08:00
parent a4ebdd0eca
commit 96f9b0b91b
4 changed files with 84 additions and 102 deletions

View File

@ -315,12 +315,66 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
@Override @Override
int writeAllocatedBytes(int allocated) { int writeAllocatedBytes(int allocated) {
final int initialAllocated = allocated;
int writtenBytes;
// In case an exception is thrown we want to remember it and pass it to cancel(Throwable).
Throwable cause = null;
FlowControlled frame;
try { try {
// Perform the write. assert !writing;
return writeBytes(allocated); writing = true;
// Write the remainder of frames that we are allowed to
boolean writeOccurred = false;
while (!cancelled && (frame = peek()) != null) {
int maxBytes = min(allocated, writableWindow());
if (maxBytes <= 0 && frame.size() > 0) {
// The frame still has data, but the amount of allocated bytes has been exhausted.
// Don't write needless empty frames.
break;
}
writeOccurred = true;
int initialFrameSize = frame.size();
try {
frame.write(ctx, max(0, maxBytes));
if (frame.size() == 0) {
// This frame has been fully written, remove this frame and notify it.
// Since we remove this frame first, we're guaranteed that its error
// method will not be called when we call cancel.
pendingWriteQueue.remove();
frame.writeComplete();
}
} finally {
// Decrement allocated by how much was actually written.
allocated -= initialFrameSize - frame.size();
}
}
if (!writeOccurred) {
// Either there was no frame, or the amount of allocated bytes has been exhausted.
return -1;
}
} catch (Throwable t) {
// Mark the state as cancelled, we'll clear the pending queue via cancel() below.
cancelled = true;
cause = t;
} finally { } finally {
streamByteDistributor.updateStreamableBytes(this); writing = false;
// Make sure we always decrement the flow control windows
// by the bytes written.
writtenBytes = initialAllocated - allocated;
decrementPendingBytes(writtenBytes, false);
decrementFlowControlWindow(writtenBytes);
// If a cancellation occurred while writing, call cancel again to
// clear and error all of the pending writes.
if (cancelled) {
cancel(cause);
}
} }
return writtenBytes;
} }
@Override @Override
@ -354,11 +408,13 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
@Override @Override
void enqueueFrame(FlowControlled frame) { void enqueueFrame(FlowControlled frame) {
incrementPendingBytes(frame.size());
FlowControlled last = pendingWriteQueue.peekLast(); FlowControlled last = pendingWriteQueue.peekLast();
if (last == null || !last.merge(ctx, frame)) { if (last == null || !last.merge(ctx, frame)) {
pendingWriteQueue.offer(frame); pendingWriteQueue.offer(frame);
} }
// This must be called after adding to the queue in order so that hasFrame() is
// updated before updating the stream state.
incrementPendingBytes(frame.size(), true);
} }
@Override @Override
@ -400,89 +456,23 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
streamByteDistributor.updateStreamableBytes(this); streamByteDistributor.updateStreamableBytes(this);
} }
int writeBytes(int bytes) {
if (!hasFrame()) {
return -1;
}
// Check if the first frame is a "writable" frame to get the "-1" return status out of the way
FlowControlled frame = peek();
int maxBytes = min(bytes, writableWindow());
if (maxBytes <= 0 && frame.size() != 0) {
// The frame still has data, but the amount of allocated bytes has been exhausted.
return -1;
}
int originalBytes = bytes;
bytes -= write(frame, maxBytes);
// Write the remainder of frames that we are allowed to
while (hasFrame()) {
frame = peek();
maxBytes = min(bytes, writableWindow());
if (maxBytes <= 0 && frame.size() != 0) {
// The frame still has data, but the amount of allocated bytes has been exhausted.
break;
}
bytes -= write(frame, maxBytes);
}
return originalBytes - bytes;
}
/** /**
* Writes the frame and decrements the stream and connection window sizes. If the frame is in the pending * Increments the number of pending bytes for this node and optionally updates the
* queue, the written bytes are removed from this branch of the priority tree. * {@link StreamByteDistributor}.
*/ */
private int write(FlowControlled frame, int allowedBytes) { private void incrementPendingBytes(int numBytes, boolean updateStreamableBytes) {
int before = frame.size();
int writtenBytes;
// In case an exception is thrown we want to remember it and pass it to cancel(Throwable).
Throwable cause = null;
try {
assert !writing;
// Write the portion of the frame.
writing = true;
frame.write(ctx, max(0, allowedBytes));
if (!cancelled && frame.size() == 0) {
// This frame has been fully written, remove this frame and notify it. Since we remove this frame
// first, we're guaranteed that its error method will not be called when we call cancel.
pendingWriteQueue.remove();
frame.writeComplete();
}
} catch (Throwable t) {
// Mark the state as cancelled, we'll clear the pending queue via cancel() below.
cancelled = true;
cause = t;
} finally {
writing = false;
// Make sure we always decrement the flow control windows
// by the bytes written.
writtenBytes = before - frame.size();
decrementFlowControlWindow(writtenBytes);
decrementPendingBytes(writtenBytes);
// If a cancellation occurred while writing, call cancel again to
// clear and error all of the pending writes.
if (cancelled) {
cancel(cause);
}
}
return writtenBytes;
}
/**
* Increments the number of pending bytes for this node and updates the {@link StreamByteDistributor}.
*/
private void incrementPendingBytes(int numBytes) {
pendingBytes += numBytes; pendingBytes += numBytes;
streamByteDistributor.updateStreamableBytes(this);
monitor.incrementPendingBytes(numBytes); monitor.incrementPendingBytes(numBytes);
if (updateStreamableBytes) {
streamByteDistributor.updateStreamableBytes(this);
}
} }
/** /**
* If this frame is in the pending queue, decrements the number of pending bytes for the stream. * If this frame is in the pending queue, decrements the number of pending bytes for the stream.
*/ */
private void decrementPendingBytes(int bytes) { private void decrementPendingBytes(int bytes, boolean updateStreamableBytes) {
incrementPendingBytes(-bytes); incrementPendingBytes(-bytes, updateStreamableBytes);
} }
/** /**
@ -505,7 +495,7 @@ public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowControll
*/ */
private void writeError(FlowControlled frame, Http2Exception cause) { private void writeError(FlowControlled frame, Http2Exception cause) {
assert ctx != null; assert ctx != null;
decrementPendingBytes(frame.size()); decrementPendingBytes(frame.size(), true);
frame.error(ctx, cause); frame.error(ctx, cause);
} }
} }

View File

@ -346,6 +346,7 @@ public final class PriorityStreamByteDistributor implements StreamByteDistributo
} }
void updateStreamableBytes(int newStreamableBytes, boolean hasFrame) { void updateStreamableBytes(int newStreamableBytes, boolean hasFrame) {
assert hasFrame || newStreamableBytes == 0;
this.hasFrame = hasFrame; this.hasFrame = hasFrame;
int delta = newStreamableBytes - streamableBytes; int delta = newStreamableBytes - streamableBytes;

View File

@ -26,8 +26,10 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -47,7 +49,6 @@ import junit.framework.AssertionFailedError;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -709,7 +710,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
controller.addFlowControlled(stream, flowControlled); controller.addFlowControlled(stream, flowControlled);
controller.writePendingBytes(); controller.writePendingBytes();
verify(flowControlled, times(3)).write(any(ChannelHandlerContext.class), anyInt()); verify(flowControlled, atLeastOnce()).write(any(ChannelHandlerContext.class), anyInt());
verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class));
verify(flowControlled, never()).writeComplete(); verify(flowControlled, never()).writeComplete();
@ -742,7 +743,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
fail(); fail();
} }
verify(flowControlled, times(3)).write(any(ChannelHandlerContext.class), anyInt()); verify(flowControlled, atLeastOnce()).write(any(ChannelHandlerContext.class), anyInt());
verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class));
verify(flowControlled, never()).writeComplete(); verify(flowControlled, never()).writeComplete();
@ -753,7 +754,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
@Test @Test
public void flowControlledWriteCompleteThrowsAnException() throws Exception { public void flowControlledWriteCompleteThrowsAnException() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled = final Http2RemoteFlowController.FlowControlled flowControlled =
Mockito.mock(Http2RemoteFlowController.FlowControlled.class); mock(Http2RemoteFlowController.FlowControlled.class);
final AtomicInteger size = new AtomicInteger(150); final AtomicInteger size = new AtomicInteger(150);
doAnswer(new Answer<Integer>() { doAnswer(new Answer<Integer>() {
@Override @Override
@ -798,7 +799,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
@Test @Test
public void closeStreamInFlowControlledError() throws Exception { public void closeStreamInFlowControlledError() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled = final Http2RemoteFlowController.FlowControlled flowControlled =
Mockito.mock(Http2RemoteFlowController.FlowControlled.class); mock(Http2RemoteFlowController.FlowControlled.class);
final Http2Stream stream = stream(STREAM_A); final Http2Stream stream = stream(STREAM_A);
when(flowControlled.size()).thenReturn(100); when(flowControlled.size()).thenReturn(100);
doThrow(new RuntimeException("write failed")) doThrow(new RuntimeException("write failed"))
@ -922,25 +923,15 @@ public class DefaultHttp2RemoteFlowControllerTest {
private static Http2RemoteFlowController.FlowControlled mockedFlowControlledThatThrowsOnWrite() throws Exception { private static Http2RemoteFlowController.FlowControlled mockedFlowControlledThatThrowsOnWrite() throws Exception {
final Http2RemoteFlowController.FlowControlled flowControlled = final Http2RemoteFlowController.FlowControlled flowControlled =
Mockito.mock(Http2RemoteFlowController.FlowControlled.class); mock(Http2RemoteFlowController.FlowControlled.class);
when(flowControlled.size()).thenReturn(100); when(flowControlled.size()).thenReturn(100);
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
private int invocationCount; private int invocationCount;
@Override @Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable { public Void answer(InvocationOnMock in) throws Throwable {
switch(invocationCount) { // Write most of the bytes and then fail
case 0: when(flowControlled.size()).thenReturn(10);
when(flowControlled.size()).thenReturn(50); throw new RuntimeException("Write failed");
invocationCount = 1;
return null;
case 1:
when(flowControlled.size()).thenReturn(20);
invocationCount = 2;
return null;
default:
when(flowControlled.size()).thenReturn(10);
throw new RuntimeException("Write failed");
}
} }
}).when(flowControlled).write(any(ChannelHandlerContext.class), anyInt()); }).when(flowControlled).write(any(ChannelHandlerContext.class), anyInt());
return flowControlled; return flowControlled;

View File

@ -136,10 +136,10 @@ public class PriorityStreamByteDistributorTest {
doNothing().when(writer).write(same(stream(STREAM_C)), eq(3)); doNothing().when(writer).write(same(stream(STREAM_C)), eq(3));
write(10); write(10);
verifyWrite(atMost(1), STREAM_A, 1); verifyWrite(STREAM_A, 1);
verifyWrite(atMost(1), STREAM_B, 2); verifyWrite(STREAM_B, 2);
verifyWrite(times(2), STREAM_C, 3); verifyWrite(times(2), STREAM_C, 3);
verifyWrite(atMost(1), STREAM_D, 4); verifyWrite(STREAM_D, 4);
} }
/** /**