Correctly guard against multiple RST frames for the same stream (#9811)

Motivation:

Http2ConnectionHandler tries to guard against sending multiple RST frames for the same stream. Unfortunally the code is not 100 % correct as it only updates the state after it calls write. This may lead to the situation of have an extra RST frame slip through if the second write for the RST frame is done from a listener that is attached to the promise.

Modifications:

- Update state before calling write
- Add unit test

Result:

Only ever send one RST frame per stream
This commit is contained in:
Norman Maurer 2019-11-27 06:54:19 +01:00
parent 154e40bcea
commit 621af5ce27
2 changed files with 55 additions and 5 deletions

View File

@ -777,6 +777,13 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
// Don't write a RST_STREAM frame if we have already written one. // Don't write a RST_STREAM frame if we have already written one.
return promise.setSuccess(); return promise.setSuccess();
} }
// Synchronously set the resetSent flag to prevent any subsequent calls
// from resulting in multiple reset frames being sent.
//
// This needs to be done before we notify the promise as the promise may have a listener attached that
// call resetStream(...) again.
stream.resetSent();
final ChannelFuture future; final ChannelFuture future;
// If the remote peer is not aware of the steam, then we are not allowed to send a RST_STREAM // If the remote peer is not aware of the steam, then we are not allowed to send a RST_STREAM
// https://tools.ietf.org/html/rfc7540#section-6.4. // https://tools.ietf.org/html/rfc7540#section-6.4.
@ -786,11 +793,6 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
} else { } else {
future = frameWriter().writeRstStream(ctx, stream.id(), errorCode, promise); future = frameWriter().writeRstStream(ctx, stream.id(), errorCode, promise);
} }
// Synchronously set the resetSent flag to prevent any subsequent calls
// from resulting in multiple reset frames being sent.
stream.resetSent();
if (future.isDone()) { if (future.isDone()) {
processRstStreamWriteResult(ctx, stream, future); processRstStreamWriteResult(ctx, stream, future);
} else { } else {

View File

@ -41,6 +41,7 @@ import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers; import org.mockito.ArgumentMatchers;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -48,9 +49,11 @@ import org.mockito.stubbing.Answer;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static io.netty.buffer.Unpooled.copiedBuffer; import static io.netty.buffer.Unpooled.copiedBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf;
import static io.netty.handler.codec.http2.Http2Error.CANCEL;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
@ -736,6 +739,51 @@ public class Http2ConnectionHandlerTest {
verify(executor, never()).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); verify(executor, never()).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class));
} }
@Test
public void writeMultipleRstFramesForSameStream() throws Exception {
handler = newHandler();
when(stream.id()).thenReturn(STREAM_ID);
final AtomicBoolean resetSent = new AtomicBoolean();
when(stream.resetSent()).then(new Answer<Http2Stream>() {
@Override
public Http2Stream answer(InvocationOnMock invocationOnMock) {
resetSent.set(true);
return stream;
}
});
when(stream.isResetSent()).then(new Answer<Boolean>() {
@Override
public Boolean answer(InvocationOnMock invocationOnMock) {
return resetSent.get();
}
});
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = invocationOnMock.getArgument(3);
return promise.setSuccess();
}
});
ChannelPromise promise =
new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
final ChannelPromise promise2 =
new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
handler.resetStream(ctx, STREAM_ID, STREAM_CLOSED.code(), promise2);
}
});
handler.resetStream(ctx, STREAM_ID, CANCEL.code(), promise);
verify(frameWriter).writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class));
assertTrue(promise.isSuccess());
assertTrue(promise2.isSuccess());
}
private void writeRstStreamUsingVoidPromise(int streamId) throws Exception { private void writeRstStreamUsingVoidPromise(int streamId) throws Exception {
handler = newHandler(); handler = newHandler();
final Throwable cause = new RuntimeException("fake exception"); final Throwable cause = new RuntimeException("fake exception");