DefaultHttp2ConnectionEncoder#writeHeaders shouldn't send GO_AWAY if stream is closed
Motivation: DefaultHttp2ConnectionEncoder#writeHeaders attempts to find a stream object, and if one doesn't exist it tries to create one. However in the event that the local endpoint has received a RST_STREAM frame before writing the response headers we attempt to create a stream. Since this stream ID is for the incorrect endpoint we then generate a GO_AWAY for what appears to be a protocol error, but can instead be failed locally. Modifications: - Just fail the local promise in the above situation instead of sending a GO_AWAY Result: Less severe consequences if the server asynchronously sends headers after a RST_STREAM has been received. Fixes https://github.com/netty/netty/issues/6906.
This commit is contained in:
parent
07a641900c
commit
bc46a99eaa
@ -126,8 +126,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
|
||||
// Allowed sending DATA frames in these states.
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException(String.format(
|
||||
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
|
||||
throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + stream.state());
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
data.release();
|
||||
@ -153,7 +152,15 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
|
||||
try {
|
||||
Http2Stream stream = connection.stream(streamId);
|
||||
if (stream == null) {
|
||||
stream = connection.local().createStream(streamId, endOfStream);
|
||||
try {
|
||||
stream = connection.local().createStream(streamId, endOfStream);
|
||||
} catch (Http2Exception cause) {
|
||||
if (connection.remote().mayHaveCreatedStream(streamId)) {
|
||||
promise.tryFailure(new IllegalStateException("Stream no longer exists: " + streamId, cause));
|
||||
return promise;
|
||||
}
|
||||
throw cause;
|
||||
}
|
||||
} else {
|
||||
switch (stream.state()) {
|
||||
case RESERVED_LOCAL:
|
||||
@ -164,8 +171,8 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
|
||||
// Allowed sending headers in these states.
|
||||
break;
|
||||
default:
|
||||
throw new IllegalStateException(String.format(
|
||||
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
|
||||
throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " +
|
||||
stream.state());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,12 +50,16 @@ import java.util.Random;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
|
||||
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
|
||||
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
|
||||
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
|
||||
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
|
||||
import static io.netty.util.CharsetUtil.UTF_8;
|
||||
import static java.util.concurrent.TimeUnit.MILLISECONDS;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.hamcrest.CoreMatchers.instanceOf;
|
||||
import static org.hamcrest.CoreMatchers.not;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
@ -67,9 +71,9 @@ import static org.mockito.Mockito.anyBoolean;
|
||||
import static org.mockito.Mockito.anyInt;
|
||||
import static org.mockito.Mockito.anyLong;
|
||||
import static org.mockito.Mockito.anyShort;
|
||||
import static org.mockito.Mockito.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.eq;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
@ -502,6 +506,71 @@ public class Http2ConnectionRoundtripTest {
|
||||
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void headersWriteForPeerStreamWhichWasResetShouldNotGoAway() throws Exception {
|
||||
bootstrapEnv(1, 1, 1, 0);
|
||||
|
||||
final CountDownLatch serverGotRstLatch = new CountDownLatch(1);
|
||||
final CountDownLatch serverWriteHeadersLatch = new CountDownLatch(1);
|
||||
final AtomicReference<Throwable> serverWriteHeadersCauseRef = new AtomicReference<Throwable>();
|
||||
|
||||
final Http2Headers headers = dummyHeaders();
|
||||
final int streamId = 3;
|
||||
runInChannel(clientChannel, new Http2Runnable() {
|
||||
@Override
|
||||
public void run() throws Http2Exception {
|
||||
http2Client.encoder().writeHeaders(ctx(), streamId, headers, CONNECTION_STREAM_ID,
|
||||
DEFAULT_PRIORITY_WEIGHT, false, 0, false, newPromise());
|
||||
http2Client.encoder().writeRstStream(ctx(), streamId, Http2Error.CANCEL.code(), newPromise());
|
||||
http2Client.flush(ctx());
|
||||
}
|
||||
});
|
||||
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
|
||||
if (streamId == (Integer) invocationOnMock.getArgument(1)) {
|
||||
serverGotRstLatch.countDown();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}).when(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(streamId), anyLong());
|
||||
|
||||
assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
|
||||
assertTrue(serverGotRstLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
|
||||
|
||||
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(streamId), eq(headers), anyInt(),
|
||||
anyShort(), anyBoolean(), anyInt(), eq(false));
|
||||
|
||||
// Now have the server attempt to send a headers frame simulating some asynchronous work.
|
||||
runInChannel(serverConnectedChannel, new Http2Runnable() {
|
||||
@Override
|
||||
public void run() throws Http2Exception {
|
||||
http2Server.encoder().writeHeaders(serverCtx(), streamId, headers, 0, true, serverNewPromise())
|
||||
.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
serverWriteHeadersCauseRef.set(future.cause());
|
||||
serverWriteHeadersLatch.countDown();
|
||||
}
|
||||
});
|
||||
http2Server.flush(serverCtx());
|
||||
}
|
||||
});
|
||||
|
||||
assertTrue(serverWriteHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
|
||||
Throwable serverWriteHeadersCause = serverWriteHeadersCauseRef.get();
|
||||
assertNotNull(serverWriteHeadersCause);
|
||||
assertThat(serverWriteHeadersCauseRef.get(), not(instanceOf(Http2Exception.class)));
|
||||
|
||||
// Server should receive a RST_STREAM for stream 3.
|
||||
verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
|
||||
any(ByteBuf.class));
|
||||
verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
|
||||
any(ByteBuf.class));
|
||||
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void http2ExceptionInPipelineShouldCloseConnection() throws Exception {
|
||||
bootstrapEnv(1, 1, 2, 1);
|
||||
|
Loading…
Reference in New Issue
Block a user