diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java index 8a41acc069..10da347ef3 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java @@ -214,8 +214,8 @@ public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) throws Http2Exception { - connection.goAwayReceived(lastStreamId, errorCode, debugData); listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + connection.goAwayReceived(lastStreamId, errorCode, debugData); } void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java index 4ef4dc6306..f9dee00179 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java @@ -52,6 +52,8 @@ import java.io.ByteArrayOutputStream; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static io.netty.buffer.Unpooled.EMPTY_BUFFER; @@ -945,47 +947,33 @@ public class Http2ConnectionRoundtripTest { } @Test - public void createStreamSynchronouslyAfterGoAwayReceivedShouldFailLocally() throws Exception { + public void listenerIsNotifiedOfGoawayBeforeStreamsAreRemovedFromTheConnection() throws Exception { bootstrapEnv(1, 1, 2, 1, 1); - final CountDownLatch clientGoAwayLatch = new CountDownLatch(1); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - clientGoAwayLatch.countDown(); - return null; - } - }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); - // We want both sides to do graceful shutdown during the test. setClientGracefulShutdownTime(10000); setServerGracefulShutdownTime(10000); - final Http2Headers headers = dummyHeaders(); - final AtomicReference clientWriteAfterGoAwayFutureRef = new AtomicReference(); - final CountDownLatch clientWriteAfterGoAwayLatch = new CountDownLatch(1); + final AtomicReference clientStream3State = new AtomicReference(); + final CountDownLatch clientGoAwayLatch = new CountDownLatch(1); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - ChannelFuture f = http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, (short) 16, false, 0, - true, newPromise()); - clientWriteAfterGoAwayFutureRef.set(f); - f.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - clientWriteAfterGoAwayLatch.countDown(); - } - }); - http2Client.flush(ctx()); + clientStream3State.set(http2Client.connection().stream(3).state()); + clientGoAwayLatch.countDown(); return null; } }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); + // Create a single stream by sending a HEADERS frame to the server. + final Http2Headers headers = dummyHeaders(); runInChannel(clientChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 1, headers, 0, (short) 16, false, 0, + false, newPromise()); http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, - true, newPromise()); + false, newPromise()); http2Client.flush(ctx()); } }); @@ -998,24 +986,38 @@ public class Http2ConnectionRoundtripTest { runInChannel(serverChannel, new Http2Runnable() { @Override public void run() throws Http2Exception { - http2Server.encoder().writeGoAway(serverCtx(), 3, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise()); + http2Server.encoder().writeGoAway(serverCtx(), 1, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise()); http2Server.flush(serverCtx()); } }); - // Wait for the client's write operation to complete. - assertTrue(clientWriteAfterGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + // wait for the client to receive the GO_AWAY. + assertTrue(clientGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + verify(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(1), eq(NO_ERROR.code()), + any(ByteBuf.class)); + assertEquals(Http2Stream.State.OPEN, clientStream3State.get()); - ChannelFuture clientWriteAfterGoAwayFuture = clientWriteAfterGoAwayFutureRef.get(); - assertNotNull(clientWriteAfterGoAwayFuture); - Throwable clientCause = clientWriteAfterGoAwayFuture.cause(); - assertThat(clientCause, is(instanceOf(Http2Exception.StreamException.class))); - assertEquals(Http2Error.REFUSED_STREAM.code(), ((Http2Exception.StreamException) clientCause).error().code()); + // Make sure that stream 3 has been closed which is true if it's gone. + final CountDownLatch probeStreamCount = new CountDownLatch(1); + final AtomicBoolean stream3Exists = new AtomicBoolean(); + final AtomicInteger streamCount = new AtomicInteger(); + runInChannel(this.clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + stream3Exists.set(http2Client.connection().stream(3) != null); + streamCount.set(http2Client.connection().numActiveStreams()); + probeStreamCount.countDown(); + } + }); + // The stream should be closed right after + assertTrue(probeStreamCount.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertEquals(1, streamCount.get()); + assertFalse(stream3Exists.get()); // Wait for the server to receive a GO_AWAY, but this is expected to timeout! assertFalse(goAwayLatch.await(1, SECONDS)); verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), - any(ByteBuf.class)); + any(ByteBuf.class)); // Shutdown shouldn't wait for the server to close streams setClientGracefulShutdownTime(0);