DefaultHttp2ConnectionEncoder#writeHeaders shouldn't send GO_AWAY if stream is closed

Motivation:
DefaultHttp2ConnectionEncoder#writeHeaders attempts to find a stream object, and if one doesn't exist it tries to create one. However in the event that the local endpoint has received a RST_STREAM frame before writing the response headers we attempt to create a stream. Since this stream ID is for the incorrect endpoint we then generate a GO_AWAY for what appears to be a protocol error, but can instead be failed locally.

Modifications:
- Just fail the local promise in the above situation instead of sending a GO_AWAY

Result:
Less severe consequences if the server asynchronously sends headers after a RST_STREAM has been received.
Fixes https://github.com/netty/netty/issues/6906.
This commit is contained in:
Scott Mitchell 2017-06-27 15:27:42 -04:00
parent 07a641900c
commit bc46a99eaa
2 changed files with 83 additions and 7 deletions

View File

@ -126,8 +126,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
// Allowed sending DATA frames in these states.
break;
default:
throw new IllegalStateException(String.format(
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + stream.state());
}
} catch (Throwable e) {
data.release();
@ -153,7 +152,15 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
try {
Http2Stream stream = connection.stream(streamId);
if (stream == null) {
stream = connection.local().createStream(streamId, endOfStream);
try {
stream = connection.local().createStream(streamId, endOfStream);
} catch (Http2Exception cause) {
if (connection.remote().mayHaveCreatedStream(streamId)) {
promise.tryFailure(new IllegalStateException("Stream no longer exists: " + streamId, cause));
return promise;
}
throw cause;
}
} else {
switch (stream.state()) {
case RESERVED_LOCAL:
@ -164,8 +171,8 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
// Allowed sending headers in these states.
break;
default:
throw new IllegalStateException(String.format(
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " +
stream.state());
}
}

View File

@ -50,12 +50,16 @@ import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@ -67,9 +71,9 @@ import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyShort;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -502,6 +506,71 @@ public class Http2ConnectionRoundtripTest {
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
}
@Test
public void headersWriteForPeerStreamWhichWasResetShouldNotGoAway() throws Exception {
bootstrapEnv(1, 1, 1, 0);
final CountDownLatch serverGotRstLatch = new CountDownLatch(1);
final CountDownLatch serverWriteHeadersLatch = new CountDownLatch(1);
final AtomicReference<Throwable> serverWriteHeadersCauseRef = new AtomicReference<Throwable>();
final Http2Headers headers = dummyHeaders();
final int streamId = 3;
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Client.encoder().writeHeaders(ctx(), streamId, headers, CONNECTION_STREAM_ID,
DEFAULT_PRIORITY_WEIGHT, false, 0, false, newPromise());
http2Client.encoder().writeRstStream(ctx(), streamId, Http2Error.CANCEL.code(), newPromise());
http2Client.flush(ctx());
}
});
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
if (streamId == (Integer) invocationOnMock.getArgument(1)) {
serverGotRstLatch.countDown();
}
return null;
}
}).when(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(streamId), anyLong());
assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertTrue(serverGotRstLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(streamId), eq(headers), anyInt(),
anyShort(), anyBoolean(), anyInt(), eq(false));
// Now have the server attempt to send a headers frame simulating some asynchronous work.
runInChannel(serverConnectedChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeHeaders(serverCtx(), streamId, headers, 0, true, serverNewPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
serverWriteHeadersCauseRef.set(future.cause());
serverWriteHeadersLatch.countDown();
}
});
http2Server.flush(serverCtx());
}
});
assertTrue(serverWriteHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
Throwable serverWriteHeadersCause = serverWriteHeadersCauseRef.get();
assertNotNull(serverWriteHeadersCause);
assertThat(serverWriteHeadersCauseRef.get(), not(instanceOf(Http2Exception.class)));
// Server should receive a RST_STREAM for stream 3.
verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
}
@Test
public void http2ExceptionInPipelineShouldCloseConnection() throws Exception {
bootstrapEnv(1, 1, 2, 1);