diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java index 078c079891..9c680e6a99 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java @@ -44,6 +44,7 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.getPseudoHeader; import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat; import static io.netty.util.AsciiString.EMPTY_STRING; @@ -119,24 +120,21 @@ final class HpackDecoder { * This method assumes the entire header block is contained in {@code in}. */ public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean validateHeaders) throws Http2Exception { - Http2HeadersSink sink = new Http2HeadersSink(headers, maxHeaderListSize); - decode(in, sink, validateHeaders); + Http2HeadersSink sink = new Http2HeadersSink(streamId, headers, maxHeaderListSize, validateHeaders); + decode(in, sink); - // we have read all of our headers. See if we have exceeded our maxHeaderListSize. We must - // delay throwing until this point to prevent dynamic table corruption - if (sink.exceededMaxLength()) { - headerListSizeExceeded(streamId, maxHeaderListSize, true); - } + // Now that we've read all of our headers we can perform the validation steps. We must + // delay throwing until this point to prevent dynamic table corruption. + sink.finish(); } - private void decode(ByteBuf in, Sink sink, boolean validateHeaders) throws Http2Exception { + private void decode(ByteBuf in, Sink sink) throws Http2Exception { int index = 0; int nameLength = 0; int valueLength = 0; byte state = READ_HEADER_REPRESENTATION; boolean huffmanEncoded = false; CharSequence name = null; - HeaderType headerType = null; IndexType indexType = IndexType.NONE; while (in.isReadable()) { switch (state) { @@ -157,7 +155,6 @@ final class HpackDecoder { break; default: HpackHeaderField indexedHeader = getIndexedHeader(index); - headerType = validate(indexedHeader.name, headerType, validateHeaders); sink.appendToHeaderList(indexedHeader.name, indexedHeader.value); } } else if ((b & 0x40) == 0x40) { @@ -174,7 +171,6 @@ final class HpackDecoder { default: // Index was stored as the prefix name = readName(index); - headerType = validate(name, headerType, validateHeaders); nameLength = name.length(); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; } @@ -201,7 +197,6 @@ final class HpackDecoder { default: // Index was stored as the prefix name = readName(index); - headerType = validate(name, headerType, validateHeaders); nameLength = name.length(); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; } @@ -215,7 +210,6 @@ final class HpackDecoder { case READ_INDEXED_HEADER: HpackHeaderField indexedHeader = getIndexedHeader(decodeULE128(in, index)); - headerType = validate(indexedHeader.name, headerType, validateHeaders); sink.appendToHeaderList(indexedHeader.name, indexedHeader.value); state = READ_HEADER_REPRESENTATION; break; @@ -223,7 +217,6 @@ final class HpackDecoder { case READ_INDEXED_HEADER_NAME: // Header Name matches an entry in the Header Table name = readName(decodeULE128(in, index)); - headerType = validate(name, headerType, validateHeaders); nameLength = name.length(); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; break; @@ -254,7 +247,6 @@ final class HpackDecoder { } name = readStringLiteral(in, nameLength, huffmanEncoded); - headerType = validate(name, headerType, validateHeaders); state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; break; @@ -268,7 +260,6 @@ final class HpackDecoder { state = READ_LITERAL_HEADER_VALUE_LENGTH; break; case 0: - headerType = validate(name, headerType, validateHeaders); insertHeader(sink, name, EMPTY_STRING, indexType); state = READ_HEADER_REPRESENTATION; break; @@ -293,7 +284,6 @@ final class HpackDecoder { } CharSequence value = readStringLiteral(in, valueLength, huffmanEncoded); - headerType = validate(name, headerType, validateHeaders); insertHeader(sink, name, value, indexType); state = READ_HEADER_REPRESENTATION; break; @@ -327,7 +317,7 @@ final class HpackDecoder { } /** - * @deprecated use {@link #setmaxHeaderListSize(long)}; {@code maxHeaderListSizeGoAway} is + * @deprecated use {@link #setMaxHeaderListSize(long)}; {@code maxHeaderListSizeGoAway} is * ignored */ @Deprecated @@ -385,26 +375,23 @@ final class HpackDecoder { hpackDynamicTable.setCapacity(dynamicTableSize); } - private HeaderType validate(CharSequence name, HeaderType previousHeaderType, - final boolean validateHeaders) throws Http2Exception { - if (!validateHeaders) { - return null; - } - + private static HeaderType validate(int streamId, CharSequence name, + HeaderType previousHeaderType) throws Http2Exception { if (hasPseudoHeaderFormat(name)) { if (previousHeaderType == HeaderType.REGULAR_HEADER) { - throw connectionError(PROTOCOL_ERROR, "Pseudo-header field '%s' found after regular header.", name); + throw streamError(streamId, PROTOCOL_ERROR, + "Pseudo-header field '%s' found after regular header.", name); } final Http2Headers.PseudoHeaderName pseudoHeader = getPseudoHeader(name); if (pseudoHeader == null) { - throw connectionError(PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name); + throw streamError(streamId, PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name); } final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ? HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER; if (previousHeaderType != null && currentHeaderType != previousHeaderType) { - throw connectionError(PROTOCOL_ERROR, "Mix of request and response pseudo-headers."); + throw streamError(streamId, PROTOCOL_ERROR, "Mix of request and response pseudo-headers."); } return currentHeaderType; @@ -435,8 +422,7 @@ final class HpackDecoder { throw INDEX_HEADER_ILLEGAL_INDEX_VALUE; } - private void insertHeader(Sink sink, CharSequence name, CharSequence value, - IndexType indexType) throws Http2Exception { + private void insertHeader(Sink sink, CharSequence name, CharSequence value, IndexType indexType) { sink.appendToHeaderList(name, value); switch (indexType) { @@ -529,32 +515,55 @@ final class HpackDecoder { private interface Sink { void appendToHeaderList(CharSequence name, CharSequence value); + void finish() throws Http2Exception; } private static final class Http2HeadersSink implements Sink { private final Http2Headers headers; private final long maxHeaderListSize; + private final int streamId; + private final boolean validate; private long headersLength; private boolean exceededMaxLength; + private HeaderType previousType; + private Http2Exception validationException; - public Http2HeadersSink(Http2Headers headers, long maxHeaderListSize) { + public Http2HeadersSink(int streamId, Http2Headers headers, long maxHeaderListSize, boolean validate) { this.headers = headers; this.maxHeaderListSize = maxHeaderListSize; + this.streamId = streamId; + this.validate = validate; + } + + @Override + public void finish() throws Http2Exception { + if (exceededMaxLength) { + headerListSizeExceeded(streamId, maxHeaderListSize, true); + } else if (validationException != null) { + throw validationException; + } } @Override public void appendToHeaderList(CharSequence name, CharSequence value) { headersLength += HpackHeaderField.sizeOf(name, value); - if (headersLength > maxHeaderListSize) { - exceededMaxLength = true; - } - if (!exceededMaxLength) { - headers.add(name, value); - } - } + exceededMaxLength |= headersLength > maxHeaderListSize; - public boolean exceededMaxLength() { - return exceededMaxLength; + if (exceededMaxLength || validationException != null) { + // We don't store the header since we've already failed validation requirements. + return; + } + + if (validate) { + try { + previousType = HpackDecoder.validate(streamId, name, previousType); + } catch (Http2Exception ex) { + validationException = ex; + return; + } + } + + headers.add(name, value); } } } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java index ba0117553e..994fef6f37 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java @@ -544,7 +544,7 @@ public class HpackDecoderTest { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); @@ -588,7 +588,7 @@ public class HpackDecoderTest { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); @@ -608,7 +608,7 @@ public class HpackDecoderTest { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); @@ -628,10 +628,47 @@ public class HpackDecoderTest { Http2Headers decoded = new DefaultHttp2Headers(); - expectedException.expect(Http2Exception.class); + expectedException.expect(Http2Exception.StreamException.class); hpackDecoder.decode(1, in, decoded, true); } finally { in.release(); } } + + @Test + public void failedValidationDoesntCorruptHpack() throws Exception { + ByteBuf in1 = Unpooled.buffer(200); + ByteBuf in2 = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add(":method", "GET"); + toEncode.add(":status", "200"); + toEncode.add("foo", "bar"); + hpackEncoder.encodeHeaders(1, in1, toEncode, NEVER_SENSITIVE); + + Http2Headers decoded = new DefaultHttp2Headers(); + + try { + hpackDecoder.decode(1, in1, decoded, true); + fail("Should have thrown a StreamException"); + } catch (Http2Exception.StreamException expected) { + assertEquals(1, expected.streamId()); + } + + // Do it again, this time without validation, to make sure the HPACK state is still sane. + decoded.clear(); + hpackEncoder.encodeHeaders(1, in2, toEncode, NEVER_SENSITIVE); + hpackDecoder.decode(1, in2, decoded, false); + + assertEquals(3, decoded.size()); + assertEquals("GET", decoded.method().toString()); + assertEquals("200", decoded.status().toString()); + assertEquals("bar", decoded.get("foo").toString()); + } finally { + in1.release(); + in2.release(); + } + } }