Server returns status 431 on header size errors

Motivation:

Currently clients attempting to send headers that are too large recieve
a RST frame. This makes it harder than needed for implementations on top
of netty to handle this in a graceful way.

Modifications:

When the Decoder throws a StreamError of type FRAME_SIZE_ERROR, the
Http2ConnectionHandler will now attempt to send an Http2Header with
status 431 and endOfStream=true

Result:

Implementations now do not have to subclass parts of netty to handle
431s
This commit is contained in:
Nikolaj Hald Nielsen 2016-11-23 11:32:41 -08:00 committed by Scott Mitchell
parent 89e93968ac
commit cd458f10bc
10 changed files with 261 additions and 17 deletions

View File

@ -374,6 +374,7 @@ public class DefaultHttp2Connection implements Http2Connection {
private DefaultStream parent;
private IntObjectMap<DefaultStream> children = IntCollections.emptyMap();
private boolean resetSent;
private boolean headersSent;
DefaultStream(int id, State state) {
this.id = id;
@ -401,6 +402,17 @@ public class DefaultHttp2Connection implements Http2Connection {
return this;
}
@Override
public Http2Stream headersSent() {
headersSent = true;
return this;
}
@Override
public boolean isHeadersSent() {
return headersSent;
}
@Override
public final <V> V setProperty(PropertyKey key, V value) {
return properties.add(verifyKey(key), value);
@ -804,6 +816,16 @@ public class DefaultHttp2Connection implements Http2Connection {
public Http2Stream closeRemoteSide() {
throw new UnsupportedOperationException();
}
@Override
public Http2Stream headersSent() {
throw new UnsupportedOperationException();
}
@Override
public boolean isHeadersSent() {
throw new UnsupportedOperationException();
}
}
/**

View File

@ -185,13 +185,19 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
};
promise = promise.unvoid().addListener(closeStreamLocalListener);
}
return frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
ChannelFuture future = frameWriter.writeHeaders(ctx, streamId, headers, streamDependency,
weight, exclusive, padding, endOfStream, promise);
// Synchronously set the headersSent flag to ensure that we do not subsequently write
// other headers containing pseudo-header fields.
stream.headersSent();
return future;
} else {
// Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames.
flowController.addFlowControlled(stream,
new FlowControlledHeaders(stream, headers, streamDependency, weight, exclusive, padding,
endOfStream, promise));
stream.headersSent();
return promise;
}
} catch (Http2NoMoreStreamIdsException e) {

View File

@ -673,7 +673,7 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
*/
private void headerSizeExceeded() throws Http2Exception {
close();
headerListSizeExceeded(streamId, headersDecoder.configuration().headerTable().maxHeaderListSize());
headerListSizeExceeded(streamId, headersDecoder.configuration().headerTable().maxHeaderListSize(), true);
}
/**

View File

@ -30,7 +30,7 @@ import io.netty.util.internal.UnstableApi;
import static io.netty.buffer.Unpooled.directBuffer;
import static io.netty.buffer.Unpooled.unreleasableBuffer;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.streamError;
import static io.netty.handler.codec.http2.Http2Exception.headerListSizeError;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.lang.Math.max;
import static java.lang.Math.min;
@ -223,8 +223,10 @@ public final class Http2CodecUtil {
return max(0, min(state.pendingBytes(), state.windowSize()));
}
public static void headerListSizeExceeded(int streamId, long maxHeaderListSize) throws Http2Exception {
throw streamError(streamId, PROTOCOL_ERROR, "Header size exceeded max allowed size (%d)", maxHeaderListSize);
public static void headerListSizeExceeded(int streamId, long maxHeaderListSize,
boolean onDecode) throws Http2Exception {
throw headerListSizeError(streamId, PROTOCOL_ERROR, onDecode, "Header size exceeded max " +
"allowed size (%d)", maxHeaderListSize);
}
static void writeFrameHeaderInternal(ByteBuf out, int payloadLength, byte type,

View File

@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
import io.netty.handler.codec.http2.Http2Exception.StreamException;
import io.netty.util.concurrent.ScheduledFuture;
@ -65,6 +66,9 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2ConnectionHandler.class);
private static final Http2Headers headersTooLarge = new DefaultHttp2Headers().status(
HttpResponseStatus.REQUEST_HEADER_FIELDS_TOO_LARGE.codeAsText());
private final Http2ConnectionDecoder decoder;
private final Http2ConnectionEncoder encoder;
private final Http2Settings initialSettings;
@ -600,7 +604,51 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
*/
protected void onStreamError(ChannelHandlerContext ctx, @SuppressWarnings("unused") Throwable cause,
StreamException http2Ex) {
resetStream(ctx, http2Ex.streamId(), http2Ex.error().code(), ctx.newPromise());
final int streamId = http2Ex.streamId();
Http2Stream stream = connection().stream(streamId);
//if this is caused by reading headers that are too large, send a header with status 431
if (http2Ex instanceof Http2Exception.HeaderListSizeException &&
((Http2Exception.HeaderListSizeException) http2Ex).duringDecode() &&
connection().isServer()) {
// NOTE We have to check to make sure that a stream exists before we send our reply.
// We likely always create the stream below as the stream isn't created until the
// header block is completely processed.
// The case of a streamId referring to a stream which was already closed is handled
// by createStream and will land us in the catch block below
if (stream == null) {
try {
stream = encoder.connection().remote().createStream(streamId, true);
} catch (Http2Exception e) {
resetUnknownStream(ctx, streamId, http2Ex.error().code(), ctx.newPromise());
return;
}
}
// ensure that we have not already sent headers on this stream
if (stream != null && !stream.isHeadersSent()) {
handleServerHeaderDecodeSizeError(ctx, stream);
}
}
if (stream == null) {
resetUnknownStream(ctx, streamId, http2Ex.error().code(), ctx.newPromise());
} else {
resetStream(ctx, stream, http2Ex.error().code(), ctx.newPromise());
}
}
/**
* Notifies client that this server has received headers that are larger than what it is
* willing to accept. Override to change behavior.
*
* @param ctx the channel context
* @param stream the Http2Stream on which the header was received
*/
protected void handleServerHeaderDecodeSizeError(ChannelHandlerContext ctx, Http2Stream stream) {
encoder().writeHeaders(ctx, stream.id(), headersTooLarge, 0, true, ctx.newPromise());
}
protected Http2FrameWriter frameWriter() {
@ -631,12 +679,17 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override
public ChannelFuture resetStream(final ChannelHandlerContext ctx, int streamId, long errorCode,
ChannelPromise promise) {
promise = promise.unvoid();
final Http2Stream stream = connection().stream(streamId);
if (stream == null) {
return resetUnknownStream(ctx, streamId, errorCode, promise);
return resetUnknownStream(ctx, streamId, errorCode, promise.unvoid());
}
return resetStream(ctx, stream, errorCode, promise);
}
private ChannelFuture resetStream(final ChannelHandlerContext ctx, final Http2Stream stream,
long errorCode, ChannelPromise promise) {
promise = promise.unvoid();
if (stream.isResetSent()) {
// Don't write a RST_STREAM frame if we have already written one.
return promise.setSuccess();
@ -646,7 +699,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
// We cannot write RST_STREAM frames on IDLE streams https://tools.ietf.org/html/rfc7540#section-6.4.
future = promise.setSuccess();
} else {
future = frameWriter().writeRstStream(ctx, streamId, errorCode, promise);
future = frameWriter().writeRstStream(ctx, stream.id(), errorCode, promise);
}
// Synchronously set the resetSent flag to prevent any subsequent calls

View File

@ -149,6 +149,28 @@ public class Http2Exception extends Exception {
new StreamException(id, error, String.format(fmt, args), cause);
}
/**
* A specific stream error resulting from failing to decode headers that exceeds the max header size list.
* If the {@code id} is not {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a
* {@link Http2Exception.StreamException} will be returned. Otherwise the error is considered a
* connection error and a {@link Http2Exception} is returned.
* @param id The stream id for which the error is isolated to.
* @param error The type of error as defined by the HTTP/2 specification.
* @param onDecode Whether this error was caught while decoding headers
* @param fmt String with the content and format for the additional debug data.
* @param args Objects which fit into the format defined by {@code fmt}.
* @return If the {@code id} is not
* {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a {@link HeaderListSizeException}
* will be returned. Otherwise the error is considered a connection error and a {@link Http2Exception} is
* returned.
*/
public static Http2Exception headerListSizeError(int id, Http2Error error, boolean onDecode,
String fmt, Object... args) {
return CONNECTION_STREAM_ID == id ?
Http2Exception.connectionError(error, fmt, args) :
new HeaderListSizeException(id, error, String.format(fmt, args), onDecode);
}
/**
* Check if an exception is isolated to a single stream or the entire connection.
* @param e The exception to check.
@ -210,7 +232,7 @@ public class Http2Exception extends Exception {
/**
* Represents an exception that can be isolated to a single stream (as opposed to the entire connection).
*/
public static final class StreamException extends Http2Exception {
public static class StreamException extends Http2Exception {
private static final long serialVersionUID = 602472544416984384L;
private final int streamId;
@ -229,6 +251,19 @@ public class Http2Exception extends Exception {
}
}
public static final class HeaderListSizeException extends StreamException {
private final boolean decode;
HeaderListSizeException(int streamId, Http2Error error, String message, boolean decode) {
super(streamId, error, message);
this.decode = decode;
}
public boolean duringDecode() {
return decode;
}
}
/**
* Provides the ability to handle multiple stream exceptions with one throw statement.
*/

View File

@ -183,4 +183,14 @@ public interface Http2Stream {
* @return The stream before iteration stopped or {@code null} if iteration went past the end.
*/
Http2Stream forEachChild(Http2StreamVisitor visitor) throws Http2Exception;
/**
* Indicates that headers has been sent to the remote on this stream.
*/
Http2Stream headersSent();
/**
* Indicates whether or not headers was sent to the remote endpoint.
*/
boolean isHeadersSent();
}

View File

@ -207,7 +207,7 @@ public final class Decoder {
state = READ_LITERAL_HEADER_NAME_LENGTH;
} else {
if (index > maxHeaderListSize - headersLength) {
headerListSizeExceeded(streamId, maxHeaderListSize);
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
nameLength = index;
state = READ_LITERAL_HEADER_NAME;
@ -219,7 +219,7 @@ public final class Decoder {
nameLength = decodeULE128(in, index);
if (nameLength > maxHeaderListSize - headersLength) {
headerListSizeExceeded(streamId, maxHeaderListSize);
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
state = READ_LITERAL_HEADER_NAME;
break;
@ -251,7 +251,7 @@ public final class Decoder {
default:
// Check new header size against max header size
if ((long) index + nameLength > maxHeaderListSize - headersLength) {
headerListSizeExceeded(streamId, maxHeaderListSize);
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
valueLength = index;
state = READ_LITERAL_HEADER_VALUE;
@ -265,7 +265,7 @@ public final class Decoder {
// Check new header size against max header size
if ((long) valueLength + nameLength > maxHeaderListSize - headersLength) {
headerListSizeExceeded(streamId, maxHeaderListSize);
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
state = READ_LITERAL_HEADER_VALUE;
break;
@ -403,7 +403,7 @@ public final class Decoder {
long headersLength) throws Http2Exception {
headersLength += name.length() + value.length();
if (headersLength > maxHeaderListSize) {
headerListSizeExceeded(streamId, maxHeaderListSize);
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
headers.add(name, value);
return headersLength;

View File

@ -124,7 +124,7 @@ public final class Encoder {
// overflow.
headerSize += currHeaderSize;
if (headerSize > maxHeaderListSize) {
headerListSizeExceeded(streamId, maxHeaderListSize);
headerListSizeExceeded(streamId, maxHeaderListSize, false);
}
encodeHeader(out, name, value, sensitivityDetector.isSensitive(name, value), currHeaderSize);
}

View File

@ -25,6 +25,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor;
@ -47,6 +48,7 @@ import static io.netty.buffer.Unpooled.copiedBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise;
@ -306,6 +308,118 @@ public class Http2ConnectionHandlerTest {
captor.getValue().release();
}
@Test
public void serverShouldSend431OnHeaderSizeErrorWhenDecodingInitialHeaders() throws Exception {
int padding = 0;
handler = newHandler();
Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, Http2Error.PROTOCOL_ERROR,
"Header size exceeded max allowed size 8196", true);
when(stream.id()).thenReturn(STREAM_ID);
when(connection.isServer()).thenReturn(true);
when(stream.isHeadersSent()).thenReturn(false);
when(remote.lastStreamCreated()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
eq(Http2Error.PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future);
handler.exceptionCaught(ctx, e);
ArgumentCaptor<Http2Headers> captor = ArgumentCaptor.forClass(Http2Headers.class);
verify(encoder).writeHeaders(eq(ctx), eq(STREAM_ID),
captor.capture(), eq(padding), eq(true), eq(promise));
Http2Headers headers = captor.getValue();
assertEquals(HttpResponseStatus.REQUEST_HEADER_FIELDS_TOO_LARGE.codeAsText(), headers.status());
verify(frameWriter).writeRstStream(ctx, STREAM_ID, Http2Error.PROTOCOL_ERROR.code(), promise);
}
@Test
public void serverShouldNeverSend431HeaderSizeErrorWhenEncoding() throws Exception {
int padding = 0;
handler = newHandler();
Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, Http2Error.PROTOCOL_ERROR,
"Header size exceeded max allowed size 8196", false);
when(stream.id()).thenReturn(STREAM_ID);
when(connection.isServer()).thenReturn(true);
when(stream.isHeadersSent()).thenReturn(false);
when(remote.lastStreamCreated()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
eq(Http2Error.PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future);
handler.exceptionCaught(ctx, e);
verify(encoder, never()).writeHeaders(eq(ctx), eq(STREAM_ID),
any(Http2Headers.class), eq(padding), eq(true), eq(promise));
verify(frameWriter).writeRstStream(ctx, STREAM_ID, Http2Error.PROTOCOL_ERROR.code(), promise);
}
@Test
public void clientShouldNeverSend431WhenHeadersAreTooLarge() throws Exception {
int padding = 0;
handler = newHandler();
Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, Http2Error.PROTOCOL_ERROR,
"Header size exceeded max allowed size 8196", true);
when(stream.id()).thenReturn(STREAM_ID);
when(connection.isServer()).thenReturn(false);
when(stream.isHeadersSent()).thenReturn(false);
when(remote.lastStreamCreated()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
eq(Http2Error.PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future);
handler.exceptionCaught(ctx, e);
verify(encoder, never()).writeHeaders(eq(ctx), eq(STREAM_ID),
any(Http2Headers.class), eq(padding), eq(true), eq(promise));
verify(frameWriter).writeRstStream(ctx, STREAM_ID, Http2Error.PROTOCOL_ERROR.code(), promise);
}
@Test
public void serverShouldNeverSend431IfHeadersAlreadySent() throws Exception {
int padding = 0;
handler = newHandler();
Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, Http2Error.PROTOCOL_ERROR,
"Header size exceeded max allowed size 8196", true);
when(stream.id()).thenReturn(STREAM_ID);
when(connection.isServer()).thenReturn(true);
when(stream.isHeadersSent()).thenReturn(true);
when(remote.lastStreamCreated()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
eq(Http2Error.PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future);
handler.exceptionCaught(ctx, e);
verify(encoder, never()).writeHeaders(eq(ctx), eq(STREAM_ID),
any(Http2Headers.class), eq(padding), eq(true), eq(promise));
verify(frameWriter).writeRstStream(ctx, STREAM_ID, Http2Error.PROTOCOL_ERROR.code(), promise);
}
@Test
public void serverShouldCreateStreamIfNeededBeforeSending431() throws Exception {
int padding = 0;
handler = newHandler();
Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, Http2Error.PROTOCOL_ERROR,
"Header size exceeded max allowed size 8196", true);
when(connection.stream(STREAM_ID)).thenReturn(null);
when(remote.createStream(STREAM_ID, true)).thenReturn(stream);
when(stream.id()).thenReturn(STREAM_ID);
when(connection.isServer()).thenReturn(true);
when(stream.isHeadersSent()).thenReturn(false);
when(remote.lastStreamCreated()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
eq(Http2Error.PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future);
handler.exceptionCaught(ctx, e);
verify(remote).createStream(STREAM_ID, true);
verify(encoder).writeHeaders(eq(ctx), eq(STREAM_ID),
any(Http2Headers.class), eq(padding), eq(true), eq(promise));
verify(frameWriter).writeRstStream(ctx, STREAM_ID, Http2Error.PROTOCOL_ERROR.code(), promise);
}
@Test
public void encoderAndDecoderAreClosedOnChannelInactive() throws Exception {
handler = newHandler();
@ -328,6 +442,7 @@ public class Http2ConnectionHandlerTest {
@Test
public void writeRstOnClosedStreamShouldSucceed() throws Exception {
handler = newHandler();
when(stream.id()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID),
anyLong(), any(ChannelPromise.class))).thenReturn(future);
when(stream.state()).thenReturn(CLOSED);
@ -499,6 +614,7 @@ public class Http2ConnectionHandlerTest {
private void writeRstStreamUsingVoidPromise(int streamId) throws Exception {
handler = newHandler();
final Throwable cause = new RuntimeException("fake exception");
when(stream.id()).thenReturn(STREAM_ID);
when(frameWriter.writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override