Http/2 RST_STREAM frame robustness

Motivation:
The Http2ConnectionHandler writeRstStream method allows RST_STREAM frames to be sent when we do not know about the stream and after a RST_STREAM frame has already been sent.  This may lead to sending frames when we should not according to the HTTP/2 spec. There is also the potential to notify the closeListener multiple times if the closeStream method is called multiple times.

Modifications:
- Prevent RST_STREAM from being sent if we don't know about the stream, or if we already sent the RST_STREAM.
- Prevent the closeListener from being notified multiple times.

Result:
More robust writeRstStream logic in boundary conditions.
This commit is contained in:
Scott Mitchell 2015-03-14 12:58:30 -07:00
parent 4ff551baa2
commit bd66dca418
2 changed files with 91 additions and 11 deletions

View File

@ -15,7 +15,6 @@
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_STREAM_ID; import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_STREAM_ID;
import static io.netty.handler.codec.http2.Http2CodecUtil.PING_FRAME_PAYLOAD_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf;
import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception; import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception;
import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
@ -24,7 +23,6 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Exception.isStreamError; import static io.netty.handler.codec.http2.Http2Exception.isStreamError;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
@ -279,6 +277,10 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
// If this connection is closing and there are no longer any // If this connection is closing and there are no longer any
// active streams, close after the current operation completes. // active streams, close after the current operation completes.
if (closeListener != null && connection().numActiveStreams() == 0) { if (closeListener != null && connection().numActiveStreams() == 0) {
ChannelFutureListener closeListener = Http2ConnectionHandler.this.closeListener;
// This method could be called multiple times
// and we don't want to notify the closeListener multiple times
Http2ConnectionHandler.this.closeListener = null;
closeListener.operationComplete(future); closeListener.operationComplete(future);
} }
} }
@ -339,16 +341,32 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
* Writes a {@code RST_STREAM} frame to the remote endpoint and updates the connection state appropriately. * Writes a {@code RST_STREAM} frame to the remote endpoint and updates the connection state appropriately.
*/ */
@Override @Override
public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, public ChannelFuture writeRstStream(final ChannelHandlerContext ctx, int streamId, long errorCode,
ChannelPromise promise) { final ChannelPromise promise) {
Http2Stream stream = connection().stream(streamId); final Http2Stream stream = connection().stream(streamId);
if (stream == null || stream.isResetSent()) {
// Don't write a RST_STREAM frame if we are not aware of the stream, or if we have already written one.
return promise.setSuccess();
}
ChannelFuture future = frameWriter().writeRstStream(ctx, streamId, errorCode, promise); ChannelFuture future = frameWriter().writeRstStream(ctx, streamId, errorCode, promise);
ctx.flush(); ctx.flush();
if (stream != null) { // Synchronously set the resetSent flag to prevent any subsequent calls
// from resulting in multiple reset frames being sent.
stream.resetSent(); stream.resetSent();
future.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
closeStream(stream, promise); closeStream(stream, promise);
} else {
// The connection will be closed and so no need to change the resetSent flag to false.
onConnectionError(ctx, future.cause(), null);
} }
}
});
return future; return future;
} }

View File

@ -18,23 +18,32 @@ package io.netty.handler.codec.http2;
import static io.netty.buffer.Unpooled.copiedBuffer; import static io.netty.buffer.Unpooled.copiedBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.util.concurrent.GenericFutureListener;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -53,8 +62,10 @@ import org.mockito.stubbing.Answer;
*/ */
public class Http2ConnectionHandlerTest { public class Http2ConnectionHandlerTest {
private static final int STREAM_ID = 1; private static final int STREAM_ID = 1;
private static final int NON_EXISTANT_STREAM_ID = 13;
private Http2ConnectionHandler handler; private Http2ConnectionHandler handler;
private ChannelPromise promise;
@Mock @Mock
private Http2Connection connection; private Http2Connection connection;
@ -71,8 +82,6 @@ public class Http2ConnectionHandlerTest {
@Mock @Mock
private Channel channel; private Channel channel;
private ChannelPromise promise;
@Mock @Mock
private ChannelFuture future; private ChannelFuture future;
@ -110,6 +119,8 @@ public class Http2ConnectionHandlerTest {
when(connection.remote()).thenReturn(remote); when(connection.remote()).thenReturn(remote);
when(connection.local()).thenReturn(local); when(connection.local()).thenReturn(local);
when(connection.activeStreams()).thenReturn(Collections.singletonList(stream)); when(connection.activeStreams()).thenReturn(Collections.singletonList(stream));
when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null);
when(connection.stream(STREAM_ID)).thenReturn(stream);
when(stream.open(anyBoolean())).thenReturn(stream); when(stream.open(anyBoolean())).thenReturn(stream);
when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future); when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future);
when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
@ -199,4 +210,55 @@ public class Http2ConnectionHandlerTest {
verify(encoder).close(); verify(encoder).close();
verify(decoder).close(); verify(decoder).close();
} }
@Test
public void writeRstOnNonExistantStreamShouldSucceed() throws Exception {
handler = newHandler();
handler.writeRstStream(ctx, NON_EXISTANT_STREAM_ID, STREAM_CLOSED.code(), promise);
verify(frameWriter, never())
.writeRstStream(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ChannelPromise.class));
assertTrue(promise.isDone());
assertTrue(promise.isSuccess());
assertNull(promise.cause());
}
@Test
public void writeRstOnClosedStreamShouldSucceed() throws Exception {
handler = newHandler();
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
anyLong(), any(ChannelPromise.class))).thenReturn(future);
when(stream.state()).thenReturn(CLOSED);
// The stream is "closed" but is still known about by the connection (connection().stream(..)
// will return the stream). We should still write a RST_STREAM frame in this scenario.
handler.writeRstStream(ctx, STREAM_ID, STREAM_CLOSED.code(), promise);
verify(frameWriter).writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class));
}
@SuppressWarnings("unchecked")
@Test
public void closeListenerShouldBeNotifiedOnlyOneTime() throws Exception {
handler = newHandler();
when(connection.activeStreams()).thenReturn(Arrays.asList(stream));
when(connection.numActiveStreams()).thenReturn(1);
when(future.isDone()).thenReturn(true);
when(future.isSuccess()).thenReturn(true);
doAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
ChannelFutureListener listener = (ChannelFutureListener) args[0];
// Simulate that all streams have become inactive by the time the future completes
when(connection.activeStreams()).thenReturn(Collections.<Http2Stream>emptyList());
when(connection.numActiveStreams()).thenReturn(0);
// Simulate the future being completed
listener.operationComplete(future);
return future;
}
}).when(future).addListener(any(GenericFutureListener.class));
handler.close(ctx, promise);
handler.closeStream(stream, future);
// Simulate another stream close call being made after the context should already be closed
handler.closeStream(stream, future);
verify(ctx, times(1)).close(any(ChannelPromise.class));
}
} }