From bc46a99eaa448136ae3a54c248cc3f9dfb8f87e9 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Tue, 27 Jun 2017 15:27:42 -0400 Subject: [PATCH] DefaultHttp2ConnectionEncoder#writeHeaders shouldn't send GO_AWAY if stream is closed Motivation: DefaultHttp2ConnectionEncoder#writeHeaders attempts to find a stream object, and if one doesn't exist it tries to create one. However in the event that the local endpoint has received a RST_STREAM frame before writing the response headers we attempt to create a stream. Since this stream ID is for the incorrect endpoint we then generate a GO_AWAY for what appears to be a protocol error, but can instead be failed locally. Modifications: - Just fail the local promise in the above situation instead of sending a GO_AWAY Result: Less severe consequences if the server asynchronously sends headers after a RST_STREAM has been received. Fixes https://github.com/netty/netty/issues/6906. --- .../http2/DefaultHttp2ConnectionEncoder.java | 17 +++-- .../http2/Http2ConnectionRoundtripTest.java | 73 ++++++++++++++++++- 2 files changed, 83 insertions(+), 7 deletions(-) diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java index bbe8b08c34..f0af13b394 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java @@ -126,8 +126,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { // Allowed sending DATA frames in these states. break; default: - throw new IllegalStateException(String.format( - "Stream %d in unexpected state: %s", stream.id(), stream.state())); + throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + stream.state()); } } catch (Throwable e) { data.release(); @@ -153,7 +152,15 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { try { Http2Stream stream = connection.stream(streamId); if (stream == null) { - stream = connection.local().createStream(streamId, endOfStream); + try { + stream = connection.local().createStream(streamId, endOfStream); + } catch (Http2Exception cause) { + if (connection.remote().mayHaveCreatedStream(streamId)) { + promise.tryFailure(new IllegalStateException("Stream no longer exists: " + streamId, cause)); + return promise; + } + throw cause; + } } else { switch (stream.state()) { case RESERVED_LOCAL: @@ -164,8 +171,8 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder { // Allowed sending headers in these states. break; default: - throw new IllegalStateException(String.format( - "Stream %d in unexpected state: %s", stream.id(), stream.state())); + throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + + stream.state()); } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java index f04655d2a7..8fb1f8a26d 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java @@ -50,12 +50,16 @@ import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2TestUtil.randomString; import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; import static io.netty.util.CharsetUtil.UTF_8; -import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -67,9 +71,9 @@ import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyShort; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -502,6 +506,71 @@ public class Http2ConnectionRoundtripTest { verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); } + @Test + public void headersWriteForPeerStreamWhichWasResetShouldNotGoAway() throws Exception { + bootstrapEnv(1, 1, 1, 0); + + final CountDownLatch serverGotRstLatch = new CountDownLatch(1); + final CountDownLatch serverWriteHeadersLatch = new CountDownLatch(1); + final AtomicReference serverWriteHeadersCauseRef = new AtomicReference(); + + final Http2Headers headers = dummyHeaders(); + final int streamId = 3; + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), streamId, headers, CONNECTION_STREAM_ID, + DEFAULT_PRIORITY_WEIGHT, false, 0, false, newPromise()); + http2Client.encoder().writeRstStream(ctx(), streamId, Http2Error.CANCEL.code(), newPromise()); + http2Client.flush(ctx()); + } + }); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + if (streamId == (Integer) invocationOnMock.getArgument(1)) { + serverGotRstLatch.countDown(); + } + return null; + } + }).when(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(streamId), anyLong()); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(serverGotRstLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(streamId), eq(headers), anyInt(), + anyShort(), anyBoolean(), anyInt(), eq(false)); + + // Now have the server attempt to send a headers frame simulating some asynchronous work. + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeHeaders(serverCtx(), streamId, headers, 0, true, serverNewPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + serverWriteHeadersCauseRef.set(future.cause()); + serverWriteHeadersLatch.countDown(); + } + }); + http2Server.flush(serverCtx()); + } + }); + + assertTrue(serverWriteHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + Throwable serverWriteHeadersCause = serverWriteHeadersCauseRef.get(); + assertNotNull(serverWriteHeadersCause); + assertThat(serverWriteHeadersCauseRef.get(), not(instanceOf(Http2Exception.class))); + + // Server should receive a RST_STREAM for stream 3. + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + @Test public void http2ExceptionInPipelineShouldCloseConnection() throws Exception { bootstrapEnv(1, 1, 2, 1);