diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java index 505335fb05..96aa2a70c3 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -784,6 +784,13 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http // Don't write a RST_STREAM frame if we have already written one. return promise.setSuccess(); } + // Synchronously set the resetSent flag to prevent any subsequent calls + // from resulting in multiple reset frames being sent. + // + // This needs to be done before we notify the promise as the promise may have a listener attached that + // call resetStream(...) again. + stream.resetSent(); + final ChannelFuture future; // If the remote peer is not aware of the steam, then we are not allowed to send a RST_STREAM // https://tools.ietf.org/html/rfc7540#section-6.4. @@ -793,11 +800,6 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http } else { future = frameWriter().writeRstStream(ctx, stream.id(), errorCode, promise); } - - // Synchronously set the resetSent flag to prevent any subsequent calls - // from resulting in multiple reset frames being sent. - stream.resetSent(); - if (future.isDone()) { processRstStreamWriteResult(ctx, stream, future); } else { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java index 1d055c6ed0..0143edc194 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java @@ -41,6 +41,7 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -49,9 +50,11 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.buffer.Unpooled.copiedBuffer; import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; +import static io.netty.handler.codec.http2.Http2Error.CANCEL; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; @@ -768,6 +771,51 @@ public class Http2ConnectionHandlerTest { verify(executor, never()).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); } + @Test + public void writeMultipleRstFramesForSameStream() throws Exception { + handler = newHandler(); + when(stream.id()).thenReturn(STREAM_ID); + + final AtomicBoolean resetSent = new AtomicBoolean(); + when(stream.resetSent()).then(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock invocationOnMock) { + resetSent.set(true); + return stream; + } + }); + when(stream.isResetSent()).then(new Answer() { + @Override + public Boolean answer(InvocationOnMock invocationOnMock) { + return resetSent.get(); + } + }); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + ChannelPromise promise = invocationOnMock.getArgument(3); + return promise.setSuccess(); + } + }); + + ChannelPromise promise = + new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + final ChannelPromise promise2 = + new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + handler.resetStream(ctx, STREAM_ID, STREAM_CLOSED.code(), promise2); + } + }); + + handler.resetStream(ctx, STREAM_ID, CANCEL.code(), promise); + verify(frameWriter).writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class)); + assertTrue(promise.isSuccess()); + assertTrue(promise2.isSuccess()); + } + private void writeRstStreamUsingVoidPromise(int streamId) throws Exception { handler = newHandler(); final Throwable cause = new RuntimeException("fake exception");