HTTP/2 goaway connection state update sequencing (#8080)

Motivation:
The Http2Connection state is updated by the DefaultHttp2ConnectionDecoder after the frame listener is notified of the goaway frame. If the listener sends a frame synchronously this means the connection state will not know about the goaway it just received and we may send frames that are not allowed on the connection. This may also mean a stream object is created but it may never get taken out of the stream map unless some other event occurs (e.g. timeout).

Modifications:
- The Http2Connection state should be updated before the listener is notified of the goaway
- The Http2Connection state modification and validation should be self contained when processing a goaway instead of partially in the decoder.

Result:
No more creating streams and sending frames after a goaway has been sent or received.
This commit is contained in:
Scott Mitchell 2018-07-03 19:51:16 -07:00 committed by GitHub
parent 7f5e77484c
commit 804d8434dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 200 additions and 90 deletions

View File

@ -25,7 +25,6 @@ import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.UnaryPromiseNotifier;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.UnstableApi;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -225,7 +224,12 @@ public class DefaultHttp2Connection implements Http2Connection {
}
@Override
public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) {
public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) throws Http2Exception {
if (localEndpoint.lastStreamKnownByPeer() >= 0 && localEndpoint.lastStreamKnownByPeer() < lastKnownStream) {
throw connectionError(PROTOCOL_ERROR, "lastStreamId MUST NOT increase. Current value: %d new value: %d",
localEndpoint.lastStreamKnownByPeer(), lastKnownStream);
}
localEndpoint.lastStreamKnownByPeer(lastKnownStream);
for (int i = 0; i < listeners.size(); ++i) {
try {
@ -235,19 +239,7 @@ public class DefaultHttp2Connection implements Http2Connection {
}
}
try {
forEachActiveStream(new Http2StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
if (stream.id() > lastKnownStream && localEndpoint.isValidStreamId(stream.id())) {
stream.close();
}
return true;
}
});
} catch (Http2Exception e) {
PlatformDependent.throwException(e);
}
closeStreamsGreaterThanLastKnownStreamId(lastKnownStream, localEndpoint);
}
@Override
@ -256,7 +248,20 @@ public class DefaultHttp2Connection implements Http2Connection {
}
@Override
public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) {
public boolean goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) throws Http2Exception {
if (remoteEndpoint.lastStreamKnownByPeer() >= 0) {
// Protect against re-entrancy. Could happen if writing the frame fails, and error handling
// treating this is a connection handler and doing a graceful shutdown...
if (lastKnownStream == remoteEndpoint.lastStreamKnownByPeer()) {
return false;
}
if (lastKnownStream > remoteEndpoint.lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
remoteEndpoint.lastStreamKnownByPeer(), lastKnownStream);
}
}
remoteEndpoint.lastStreamKnownByPeer(lastKnownStream);
for (int i = 0; i < listeners.size(); ++i) {
try {
@ -266,19 +271,21 @@ public class DefaultHttp2Connection implements Http2Connection {
}
}
try {
closeStreamsGreaterThanLastKnownStreamId(lastKnownStream, remoteEndpoint);
return true;
}
private void closeStreamsGreaterThanLastKnownStreamId(final int lastKnownStream,
final DefaultEndpoint<?> endpoint) throws Http2Exception {
forEachActiveStream(new Http2StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
if (stream.id() > lastKnownStream && remoteEndpoint.isValidStreamId(stream.id())) {
if (stream.id() > lastKnownStream && endpoint.isValidStreamId(stream.id())) {
stream.close();
}
return true;
}
});
} catch (Http2Exception e) {
PlatformDependent.throwException(e);
}
}
/**
@ -863,10 +870,10 @@ public class DefaultHttp2Connection implements Http2Connection {
private void checkNewStreamAllowed(int streamId, State state) throws Http2Exception {
assert state != IDLE;
if (goAwayReceived() && streamId > localEndpoint.lastStreamKnownByPeer()) {
if (lastStreamKnownByPeer >= 0 && streamId > lastStreamKnownByPeer) {
throw streamError(streamId, REFUSED_STREAM,
"Cannot create stream %d since this endpoint has received a GOAWAY frame with last stream id %d.",
streamId, localEndpoint.lastStreamKnownByPeer());
"Cannot create stream %d greater than Last-Stream-ID %d from GOAWAY.",
streamId, lastStreamKnownByPeer);
}
if (!isValidStreamId(streamId)) {
if (streamId < 0) {

View File

@ -158,12 +158,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
throws Http2Exception {
if (connection.goAwayReceived() && connection.local().lastStreamKnownByPeer() < lastStreamId) {
throw connectionError(PROTOCOL_ERROR, "lastStreamId MUST NOT increase. Current value: %d new value: %d",
connection.local().lastStreamKnownByPeer(), lastStreamId);
}
listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
connection.goAwayReceived(lastStreamId, errorCode, debugData);
listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
}
void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
@ -535,12 +531,18 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder {
throw streamError(streamId, STREAM_CLOSED, "Received %s frame for an unknown stream %d",
frameName, streamId);
} else if (stream.isResetSent() || streamCreatedAfterGoAwaySent(streamId)) {
// If we have sent a reset stream it is assumed the stream will be closed after the write completes.
// If we have not sent a reset, but the stream was created after a GoAway this is not supported by
// DefaultHttp2Connection and if a custom Http2Connection is used it is assumed the lifetime is managed
// elsewhere so we don't close the stream or otherwise modify the stream's state.
if (logger.isInfoEnabled()) {
logger.info("{} ignoring {} frame for stream {} {}", ctx.channel(), frameName,
stream.isResetSent() ? "RST_STREAM sent." :
("Stream created after GOAWAY sent. Last known stream by peer " +
connection.remote().lastStreamKnownByPeer()));
}
return true;
}
return false;

View File

@ -326,8 +326,15 @@ public interface Http2Connection {
/**
* Indicates that a {@code GOAWAY} was received from the remote endpoint and sets the last known stream.
* @param lastKnownStream The Last-Stream-ID in the
* <a href="https://tools.ietf.org/html/rfc7540#section-6.8">GOAWAY</a> frame.
* @param errorCode the Error Code in the
* <a href="https://tools.ietf.org/html/rfc7540#section-6.8">GOAWAY</a> frame.
* @param message The Additional Debug Data in the
* <a href="https://tools.ietf.org/html/rfc7540#section-6.8">GOAWAY</a> frame. Note that reference count ownership
* belongs to the caller (ownership is not transferred to this method).
*/
void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf message);
void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf message) throws Http2Exception;
/**
* Indicates whether or not a {@code GOAWAY} was sent to the remote endpoint.
@ -335,7 +342,15 @@ public interface Http2Connection {
boolean goAwaySent();
/**
* Indicates that a {@code GOAWAY} was sent to the remote endpoint and sets the last known stream.
* Updates the local state of this {@link Http2Connection} as a result of a {@code GOAWAY} to send to the remote
* endpoint.
* @param lastKnownStream The Last-Stream-ID in the
* <a href="https://tools.ietf.org/html/rfc7540#section-6.8">GOAWAY</a> frame.
* @param errorCode the Error Code in the
* <a href="https://tools.ietf.org/html/rfc7540#section-6.8">GOAWAY</a> frame.
* <a href="https://tools.ietf.org/html/rfc7540#section-6.8">GOAWAY</a> frame. Note that reference count ownership
* belongs to the caller (ownership is not transferred to this method).
* @return {@code true} if the corresponding {@code GOAWAY} frame should be sent to the remote endpoint.
*/
void goAwaySent(int lastKnownStream, long errorCode, ByteBuf message);
boolean goAwaySent(int lastKnownStream, long errorCode, ByteBuf message) throws Http2Exception;
}

View File

@ -794,25 +794,19 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override
public ChannelFuture goAway(final ChannelHandlerContext ctx, final int lastStreamId, final long errorCode,
final ByteBuf debugData, ChannelPromise promise) {
try {
promise = promise.unvoid();
final Http2Connection connection = connection();
if (connection().goAwaySent()) {
// Protect against re-entrancy. Could happen if writing the frame fails, and error handling
// treating this is a connection handler and doing a graceful shutdown...
if (lastStreamId == connection().remote().lastStreamKnownByPeer()) {
// Release the data and notify the promise
try {
if (!connection.goAwaySent(lastStreamId, errorCode, debugData)) {
debugData.release();
return promise.setSuccess();
promise.trySuccess();
return promise;
}
if (lastStreamId > connection.remote().lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
connection.remote().lastStreamKnownByPeer(), lastStreamId);
} catch (Throwable cause) {
debugData.release();
promise.tryFailure(cause);
return promise;
}
}
connection.goAwaySent(lastStreamId, errorCode, debugData);
// Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and
// result in an IllegalRefCountException.
@ -831,10 +825,6 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
}
return future;
} catch (Throwable cause) { // Make sure to catch Throwable because we are doing a retain() in this method.
debugData.release();
return promise.setFailure(cause);
}
}
/**

View File

@ -674,13 +674,6 @@ public class DefaultHttp2ConnectionDecoderTest {
verify(listener).onRstStreamRead(eq(ctx), anyInt(), anyLong());
}
@Test(expected = Http2Exception.class)
public void goawayIncreasedLastStreamIdShouldThrow() throws Exception {
when(local.lastStreamKnownByPeer()).thenReturn(1);
when(connection.goAwayReceived()).thenReturn(true);
decode().onGoAwayRead(ctx, 3, 2L, EMPTY_BUFFER);
}
@Test(expected = Http2Exception.class)
public void rstStreamReadForUnknownStreamShouldThrow() throws Exception {
when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false);

View File

@ -770,7 +770,7 @@ public class DefaultHttp2ConnectionEncoderTest {
}
@Test
public void canWriteHeaderFrameAfterGoAwayReceived() {
public void canWriteHeaderFrameAfterGoAwayReceived() throws Http2Exception {
writeAllFlowControlledFrames();
goAwayReceived(STREAM_ID);
ChannelPromise promise = newPromise();
@ -803,11 +803,11 @@ public class DefaultHttp2ConnectionEncoderTest {
return connection.stream(streamId);
}
private void goAwayReceived(int lastStreamId) {
private void goAwayReceived(int lastStreamId) throws Http2Exception {
connection.goAwayReceived(lastStreamId, 0, EMPTY_BUFFER);
}
private void goAwaySent(int lastStreamId) {
private void goAwaySent(int lastStreamId) throws Http2Exception {
connection.goAwaySent(lastStreamId, 0, EMPTY_BUFFER);
}

View File

@ -47,8 +47,8 @@ import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
@ -423,11 +423,29 @@ public class DefaultHttp2ConnectionTest {
}
@Test(expected = Http2Exception.class)
public void goAwayReceivedShouldDisallowCreation() throws Http2Exception {
public void goAwayReceivedShouldDisallowLocalCreation() throws Http2Exception {
server.goAwayReceived(0, 1L, Unpooled.EMPTY_BUFFER);
server.local().createStream(3, true);
}
@Test
public void goAwayReceivedShouldAllowRemoteCreation() throws Http2Exception {
server.goAwayReceived(0, 1L, Unpooled.EMPTY_BUFFER);
server.remote().createStream(3, true);
}
@Test(expected = Http2Exception.class)
public void goAwaySentShouldDisallowRemoteCreation() throws Http2Exception {
server.goAwaySent(0, 1L, Unpooled.EMPTY_BUFFER);
server.remote().createStream(2, true);
}
@Test
public void goAwaySentShouldAllowLocalCreation() throws Http2Exception {
server.goAwaySent(0, 1L, Unpooled.EMPTY_BUFFER);
server.local().createStream(2, true);
}
@Test
public void closeShouldSucceed() throws Http2Exception {
Http2Stream stream = server.remote().createStream(3, true);

View File

@ -189,6 +189,7 @@ public class Http2ConnectionHandlerTest {
when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null);
when(connection.numActiveStreams()).thenReturn(1);
when(connection.stream(STREAM_ID)).thenReturn(stream);
when(connection.goAwaySent(anyInt(), anyLong(), any(ByteBuf.class))).thenReturn(true);
when(stream.open(anyBoolean())).thenReturn(stream);
when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future);
when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
@ -638,6 +639,12 @@ public class Http2ConnectionHandlerTest {
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastStreamKnownByPeer()).thenReturn(STREAM_ID);
doAnswer(new Answer<Boolean>() {
@Override
public Boolean answer(InvocationOnMock invocationOnMock) {
throw new IllegalStateException();
}
}).when(connection).goAwaySent(anyInt(), anyLong(), any(ByteBuf.class));
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());

View File

@ -889,6 +889,84 @@ public class Http2ConnectionRoundtripTest {
setServerGracefulShutdownTime(0);
}
@Test
public void createStreamSynchronouslyAfterGoAwayReceivedShouldFailLocally() throws Exception {
bootstrapEnv(1, 1, 2, 1, 1);
final CountDownLatch clientGoAwayLatch = new CountDownLatch(1);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
clientGoAwayLatch.countDown();
return null;
}
}).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class));
// We want both sides to do graceful shutdown during the test.
setClientGracefulShutdownTime(10000);
setServerGracefulShutdownTime(10000);
final Http2Headers headers = dummyHeaders();
final AtomicReference<ChannelFuture> clientWriteAfterGoAwayFutureRef = new AtomicReference<ChannelFuture>();
final CountDownLatch clientWriteAfterGoAwayLatch = new CountDownLatch(1);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelFuture f = http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, (short) 16, false, 0,
true, newPromise());
clientWriteAfterGoAwayFutureRef.set(f);
f.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
clientWriteAfterGoAwayLatch.countDown();
}
});
http2Client.flush(ctx());
return null;
}
}).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class));
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0,
true, newPromise());
http2Client.flush(ctx());
}
});
assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
// Server has received the headers, so the stream is open
assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
runInChannel(serverChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeGoAway(serverCtx(), 3, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise());
http2Server.flush(serverCtx());
}
});
// Wait for the client's write operation to complete.
assertTrue(clientWriteAfterGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
ChannelFuture clientWriteAfterGoAwayFuture = clientWriteAfterGoAwayFutureRef.get();
assertNotNull(clientWriteAfterGoAwayFuture);
Throwable clientCause = clientWriteAfterGoAwayFuture.cause();
assertThat(clientCause, is(instanceOf(Http2Exception.StreamException.class)));
assertEquals(Http2Error.REFUSED_STREAM.code(), ((Http2Exception.StreamException) clientCause).error().code());
// Wait for the server to receive a GO_AWAY, but this is expected to timeout!
assertFalse(goAwayLatch.await(1, SECONDS));
verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
// Shutdown shouldn't wait for the server to close streams
setClientGracefulShutdownTime(0);
setServerGracefulShutdownTime(0);
}
@Test
public void flowControlProperlyChunksLargeMessage() throws Exception {
final Http2Headers headers = dummyHeaders();

View File

@ -222,7 +222,7 @@ public class StreamBufferingEncoderTest {
}
@Test
public void bufferingNewStreamFailsAfterGoAwayReceived() {
public void bufferingNewStreamFailsAfterGoAwayReceived() throws Http2Exception {
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(0);
connection.goAwayReceived(1, 8, EMPTY_BUFFER);
@ -235,7 +235,7 @@ public class StreamBufferingEncoderTest {
}
@Test
public void receivingGoAwayFailsBufferedStreams() {
public void receivingGoAwayFailsBufferedStreams() throws Http2Exception {
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(5);