Allowing inbound HTTP/2 frames after sending GOAWAY

Motivation:

If the client closes, a GOWAY is sent with a lastKnownStream of zero (since the remote side never created a stream). If there is still an exchange in progress, inbound frames for streams created by the client will be ignored because our ignore logic doesn't check to see if the stream was created by the remote endpoint. Frames for streams created by the local endpoint should continue to come through after sending GOAWAY.

Modifications:

Changed the decoder's streamCreatedAfterGoAwaySent logic to properly ensure that the stream was created remotely.

Result:

We now propertly process frames received after sending GOAWAY.
This commit is contained in:
nmittler 2015-05-03 08:49:49 -07:00
parent 60a94f0c5f
commit 4ae8bdc6ec
8 changed files with 108 additions and 40 deletions

View File

@ -147,12 +147,12 @@ public class DefaultHttp2Connection implements Http2Connection {
@Override
public boolean goAwayReceived() {
return localEndpoint.lastKnownStream >= 0;
return localEndpoint.lastStreamKnownByPeer >= 0;
}
@Override
public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) {
localEndpoint.lastKnownStream(lastKnownStream);
localEndpoint.lastStreamKnownByPeer(lastKnownStream);
for (int i = 0; i < listeners.size(); ++i) {
try {
listeners.get(i).onGoAwayReceived(lastKnownStream, errorCode, debugData);
@ -178,12 +178,12 @@ public class DefaultHttp2Connection implements Http2Connection {
@Override
public boolean goAwaySent() {
return remoteEndpoint.lastKnownStream >= 0;
return remoteEndpoint.lastStreamKnownByPeer >= 0;
}
@Override
public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) {
remoteEndpoint.lastKnownStream(lastKnownStream);
remoteEndpoint.lastStreamKnownByPeer(lastKnownStream);
for (int i = 0; i < listeners.size(); ++i) {
try {
listeners.get(i).onGoAwaySent(lastKnownStream, errorCode, debugData);
@ -834,7 +834,7 @@ public class DefaultHttp2Connection implements Http2Connection {
private final boolean server;
private int nextStreamId;
private int lastStreamCreated;
private int lastKnownStream = -1;
private int lastStreamKnownByPeer = -1;
private boolean pushToAllowed = true;
private F flowController;
private int maxActiveStreams;
@ -989,12 +989,12 @@ public class DefaultHttp2Connection implements Http2Connection {
}
@Override
public int lastKnownStream() {
return lastKnownStream;
public int lastStreamKnownByPeer() {
return lastStreamKnownByPeer;
}
private void lastKnownStream(int lastKnownStream) {
this.lastKnownStream = lastKnownStream;
private void lastStreamKnownByPeer(int lastKnownStream) {
this.lastStreamKnownByPeer = lastKnownStream;
}
@Override
@ -1013,10 +1013,10 @@ public class DefaultHttp2Connection implements Http2Connection {
}
private void checkNewStreamAllowed(int streamId) throws Http2Exception {
if (goAwayReceived() && streamId > localEndpoint.lastKnownStream()) {
if (goAwayReceived() && streamId > localEndpoint.lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Cannot create stream %d since this endpoint has received a " +
"GOAWAY frame with last stream id %d.", streamId,
localEndpoint.lastKnownStream());
localEndpoint.lastStreamKnownByPeer());
}
if (streamId < 0) {
throw new Http2NoMoreStreamIdsException();

View File

@ -527,10 +527,22 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
onUnknownFrame0(ctx, frameType, streamId, flags, payload);
}
/**
* Helper method for determining whether or not to ignore inbound frames. A stream is considered to be created
* after a go away is sent if the following conditions hold:
* <p/>
* <ul>
* <li>A {@code GOAWAY} must have been sent by the local endpoint</li>
* <li>The {@code streamId} must identify a legitimate stream id for the remote endpoint to be creating</li>
* <li>{@code streamId} is greater than the Last Known Stream ID which was sent by the local endpoint
* in the last {@code GOAWAY} frame</li>
* </ul>
* <p/>
*/
private boolean streamCreatedAfterGoAwaySent(int streamId) {
// Ignore inbound frames after a GOAWAY was sent and the stream id is greater than
// the last stream id set in the GOAWAY frame.
return connection.goAwaySent() && streamId > connection.remote().lastKnownStream();
Http2Connection.Endpoint<?> remote = connection.remote();
return connection.goAwaySent() && remote.isValidStreamId(streamId) &&
streamId > remote.lastStreamKnownByPeer();
}
private void verifyStreamMayHaveExisted(int streamId) throws Http2Exception {

View File

@ -188,7 +188,7 @@ public interface Http2Connection {
* <li>The connection is marked as going away.</li>
* </ul>
* <p>
* This method differs from {@link #createdStreamId(int)} because the initial state of the stream will be
* This method differs from {@link #createIdleStream(int)} because the initial state of the stream will be
* Immediately set before notifying {@link Listener}s. The state transition is sensitive to {@code halfClosed}
* and is defined by {@link Http2Stream#open(boolean)}.
* @param streamId The ID of the stream
@ -260,7 +260,7 @@ public interface Http2Connection {
* If a GOAWAY was received for this endpoint, this will be the last stream ID from the
* GOAWAY frame. Otherwise, this will be {@code -1}.
*/
int lastKnownStream();
int lastStreamKnownByPeer();
/**
* Gets the flow controller for this endpoint.

View File

@ -556,10 +556,10 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
final ByteBuf debugData, ChannelPromise promise) {
try {
final Http2Connection connection = connection();
if (connection.goAwaySent() && connection.remote().lastKnownStream() < lastStreamId) {
if (connection.goAwaySent() && lastStreamId > connection.remote().lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
connection.remote().lastKnownStream(),
connection.remote().lastStreamKnownByPeer(),
lastStreamId);
}
connection.goAwaySent(lastStreamId, errorCode, debugData);

View File

@ -35,7 +35,6 @@ import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -65,9 +64,9 @@ import java.util.concurrent.atomic.AtomicInteger;
* Tests for {@link DefaultHttp2ConnectionDecoder}.
*/
public class DefaultHttp2ConnectionDecoderTest {
private static final int STREAM_ID = 1;
private static final int STREAM_ID = 3;
private static final int PUSH_STREAM_ID = 2;
private static final int STREAM_DEPENDENCY_ID = 3;
private static final int STREAM_DEPENDENCY_ID = 5;
private Http2ConnectionDecoder decoder;
@ -169,8 +168,9 @@ public class DefaultHttp2ConnectionDecoderTest {
}
@Test
public void dataReadAfterGoAwayShouldApplyFlowControl() throws Exception {
when(connection.goAwaySent()).thenReturn(true);
public void dataReadAfterGoAwaySentShouldApplyFlowControl() throws Exception {
mockGoAwaySent();
final ByteBuf data = dummyData();
int padding = 10;
int processedBytes = data.readableBytes() + padding;
@ -187,6 +187,26 @@ public class DefaultHttp2ConnectionDecoderTest {
}
}
@Test
public void dataReadAfterGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception {
mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint();
final ByteBuf data = dummyData();
int padding = 10;
int processedBytes = data.readableBytes() + padding;
mockFlowControl(processedBytes);
try {
decode().onDataRead(ctx, STREAM_ID, data, padding, true);
verify(localFlow).receiveFlowControlledFrame(eq(ctx), eq(stream), eq(data), eq(padding), eq(true));
verify(localFlow).consumeBytes(eq(ctx), eq(stream), eq(processedBytes));
// Verify that the event was absorbed and not propagated to the observer.
verify(listener).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
} finally {
data.release();
}
}
@Test(expected = Http2Exception.class)
public void dataReadForUnknownStreamShouldApplyFlowControlAndFail() throws Exception {
when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false);
@ -262,10 +282,10 @@ public class DefaultHttp2ConnectionDecoderTest {
}
@Test
public void dataReadAfterGoAwayForStreamInInvalidStateShouldIgnore() throws Exception {
public void dataReadAfterGoAwaySentForStreamInInvalidStateShouldIgnore() throws Exception {
// Throw an exception when checking stream state.
when(stream.state()).thenReturn(Http2Stream.State.CLOSED);
when(connection.goAwaySent()).thenReturn(true);
mockGoAwaySent();
final ByteBuf data = dummyData();
try {
decode().onDataRead(ctx, STREAM_ID, data, 10, true);
@ -459,13 +479,21 @@ public class DefaultHttp2ConnectionDecoderTest {
}
@Test
public void pushPromiseReadAfterGoAwayShouldBeIgnored() throws Exception {
when(connection.goAwaySent()).thenReturn(true);
public void pushPromiseReadAfterGoAwaySentShouldBeIgnored() throws Exception {
mockGoAwaySent();
decode().onPushPromiseRead(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0);
verify(remote, never()).reservePushStream(anyInt(), any(Http2Stream.class));
verify(listener, never()).onPushPromiseRead(eq(ctx), anyInt(), anyInt(), any(Http2Headers.class), anyInt());
}
@Test
public void pushPromiseReadAfterGoAwayShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception {
mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint();
decode().onPushPromiseRead(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0);
verify(remote).reservePushStream(anyInt(), any(Http2Stream.class));
verify(listener).onPushPromiseRead(eq(ctx), anyInt(), anyInt(), any(Http2Headers.class), anyInt());
}
@Test(expected = Http2Exception.class)
public void pushPromiseReadForUnknownStreamShouldThrow() throws Exception {
when(connection.stream(STREAM_ID)).thenReturn(null);
@ -481,13 +509,21 @@ public class DefaultHttp2ConnectionDecoderTest {
}
@Test
public void priorityReadAfterGoAwayShouldBeIgnored() throws Exception {
when(connection.goAwaySent()).thenReturn(true);
public void priorityReadAfterGoAwaySentShouldBeIgnored() throws Exception {
mockGoAwaySent();
decode().onPriorityRead(ctx, STREAM_ID, 0, (short) 255, true);
verify(stream, never()).setPriority(anyInt(), anyShort(), anyBoolean());
verify(listener, never()).onPriorityRead(eq(ctx), anyInt(), anyInt(), anyShort(), anyBoolean());
}
@Test
public void priorityReadAfterGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception {
mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint();
decode().onPriorityRead(ctx, STREAM_ID, 0, (short) 255, true);
verify(stream).setPriority(anyInt(), anyShort(), anyBoolean());
verify(listener).onPriorityRead(eq(ctx), anyInt(), anyInt(), anyShort(), anyBoolean());
}
@Test
public void priorityReadForUnknownStreamShouldBeIgnored() throws Exception {
when(connection.stream(STREAM_ID)).thenReturn(null);
@ -521,13 +557,21 @@ public class DefaultHttp2ConnectionDecoderTest {
}
@Test
public void windowUpdateReadAfterGoAwayShouldBeIgnored() throws Exception {
when(connection.goAwaySent()).thenReturn(true);
public void windowUpdateReadAfterGoAwaySentShouldBeIgnored() throws Exception {
mockGoAwaySent();
decode().onWindowUpdateRead(ctx, STREAM_ID, 10);
verify(remoteFlow, never()).incrementWindowSize(eq(ctx), any(Http2Stream.class), anyInt());
verify(listener, never()).onWindowUpdateRead(eq(ctx), anyInt(), anyInt());
}
@Test
public void windowUpdateReadAfterGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception {
mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint();
decode().onWindowUpdateRead(ctx, STREAM_ID, 10);
verify(remoteFlow).incrementWindowSize(eq(ctx), any(Http2Stream.class), anyInt());
verify(listener).onWindowUpdateRead(eq(ctx), anyInt(), anyInt());
}
@Test(expected = Http2Exception.class)
public void windowUpdateReadForUnknownStreamShouldThrow() throws Exception {
when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false);
@ -651,4 +695,16 @@ public class DefaultHttp2ConnectionDecoderTest {
}).when(listener).onDataRead(any(ChannelHandlerContext.class), anyInt(),
any(ByteBuf.class), anyInt(), anyBoolean());
}
private void mockGoAwaySent() {
when(connection.goAwaySent()).thenReturn(true);
when(remote.isValidStreamId(STREAM_ID)).thenReturn(true);
when(remote.lastStreamKnownByPeer()).thenReturn(0);
}
private void mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() {
when(connection.goAwaySent()).thenReturn(true);
when(remote.isValidStreamId(STREAM_ID)).thenReturn(false);
when(remote.lastStreamKnownByPeer()).thenReturn(0);
}
}

View File

@ -517,7 +517,7 @@ public class DefaultHttp2ConnectionEncoderTest {
@Test
public void canWriteDataFrameAfterGoAwaySent() {
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastKnownStream()).thenReturn(0);
when(remote.lastStreamKnownByPeer()).thenReturn(0);
ByteBuf data = mock(ByteBuf.class);
encoder.writeData(ctx, STREAM_ID, data, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
@ -526,7 +526,7 @@ public class DefaultHttp2ConnectionEncoderTest {
@Test
public void canWriteHeaderFrameAfterGoAwaySent() {
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastKnownStream()).thenReturn(0);
when(remote.lastStreamKnownByPeer()).thenReturn(0);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
}
@ -534,7 +534,7 @@ public class DefaultHttp2ConnectionEncoderTest {
@Test
public void canWriteDataFrameAfterGoAwayReceived() {
when(connection.goAwayReceived()).thenReturn(true);
when(local.lastKnownStream()).thenReturn(STREAM_ID);
when(local.lastStreamKnownByPeer()).thenReturn(STREAM_ID);
ByteBuf data = mock(ByteBuf.class);
encoder.writeData(ctx, STREAM_ID, data, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
@ -543,7 +543,7 @@ public class DefaultHttp2ConnectionEncoderTest {
@Test
public void canWriteHeaderFrameAfterGoAwayReceived() {
when(connection.goAwayReceived()).thenReturn(true);
when(local.lastKnownStream()).thenReturn(STREAM_ID);
when(local.lastStreamKnownByPeer()).thenReturn(STREAM_ID);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(remoteFlow).sendFlowControlled(eq(ctx), eq(stream), any(FlowControlled.class));
}

View File

@ -93,10 +93,10 @@ public class DefaultHttp2ConnectionTest {
assertEquals(State.OPEN, stream1.state());
assertEquals(State.CLOSED, stream2.state());
assertEquals(State.OPEN, remoteStream.state());
assertEquals(3, client.local().lastKnownStream());
assertEquals(3, client.local().lastStreamKnownByPeer());
assertEquals(5, client.local().lastStreamCreated());
// The remote endpoint must not be affected by a received GOAWAY frame.
assertEquals(-1, client.remote().lastKnownStream());
assertEquals(-1, client.remote().lastStreamKnownByPeer());
assertEquals(State.OPEN, remoteStream.state());
}
@ -111,10 +111,10 @@ public class DefaultHttp2ConnectionTest {
assertEquals(State.OPEN, stream1.state());
assertEquals(State.CLOSED, stream2.state());
assertEquals(3, server.remote().lastKnownStream());
assertEquals(3, server.remote().lastStreamKnownByPeer());
assertEquals(5, server.remote().lastStreamCreated());
// The local endpoint must not be affected by a sent GOAWAY frame.
assertEquals(-1, server.local().lastKnownStream());
assertEquals(-1, server.local().lastStreamKnownByPeer());
assertEquals(State.OPEN, localStream.state());
}

View File

@ -351,7 +351,7 @@ public class Http2ConnectionHandlerTest {
assertFalse(promise.isDone());
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastKnownStream()).thenReturn(STREAM_ID);
when(remote.lastStreamKnownByPeer()).thenReturn(STREAM_ID);
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());