HpackDecoder treats invalid pseudo-headers as stream level errors

Motivation:

The HTTP/2 spec dictates that invalid pseudo-headers should cause the
request/response to be treated as malformed (8.1.2.1), and the recourse
for that is to treat the situation as a stream error of type
PROTOCOL_ERROR (8.1.2.6). However, we're treating them as a connection
error with the connection being immediately torn down and the HPACK
state potentially being corrupted.

Modifications:

The HpackDecoder now throws a StreamException for validation failures
and throwing is deffered until the end of of the decode phase to ensure
that the HPACK state isn't corrupted by returning early.

Result:

Behavior more closely aligned with the HTTP/2 spec.

Fixes #8043.
This commit is contained in:
Bryce Anderson 2018-06-22 10:07:08 -06:00 committed by Norman Maurer
parent 5e42e758be
commit 8f01259833
2 changed files with 89 additions and 43 deletions

View File

@ -44,6 +44,7 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded
import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Exception.streamError;
import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.getPseudoHeader; import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.getPseudoHeader;
import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat; import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat;
import static io.netty.util.AsciiString.EMPTY_STRING; import static io.netty.util.AsciiString.EMPTY_STRING;
@ -119,24 +120,21 @@ final class HpackDecoder {
* This method assumes the entire header block is contained in {@code in}. * This method assumes the entire header block is contained in {@code in}.
*/ */
public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean validateHeaders) throws Http2Exception { public void decode(int streamId, ByteBuf in, Http2Headers headers, boolean validateHeaders) throws Http2Exception {
Http2HeadersSink sink = new Http2HeadersSink(headers, maxHeaderListSize); Http2HeadersSink sink = new Http2HeadersSink(streamId, headers, maxHeaderListSize, validateHeaders);
decode(in, sink, validateHeaders); decode(in, sink);
// we have read all of our headers. See if we have exceeded our maxHeaderListSize. We must // Now that we've read all of our headers we can perform the validation steps. We must
// delay throwing until this point to prevent dynamic table corruption // delay throwing until this point to prevent dynamic table corruption.
if (sink.exceededMaxLength()) { sink.finish();
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
} }
private void decode(ByteBuf in, Sink sink, boolean validateHeaders) throws Http2Exception { private void decode(ByteBuf in, Sink sink) throws Http2Exception {
int index = 0; int index = 0;
int nameLength = 0; int nameLength = 0;
int valueLength = 0; int valueLength = 0;
byte state = READ_HEADER_REPRESENTATION; byte state = READ_HEADER_REPRESENTATION;
boolean huffmanEncoded = false; boolean huffmanEncoded = false;
CharSequence name = null; CharSequence name = null;
HeaderType headerType = null;
IndexType indexType = IndexType.NONE; IndexType indexType = IndexType.NONE;
while (in.isReadable()) { while (in.isReadable()) {
switch (state) { switch (state) {
@ -157,7 +155,6 @@ final class HpackDecoder {
break; break;
default: default:
HpackHeaderField indexedHeader = getIndexedHeader(index); HpackHeaderField indexedHeader = getIndexedHeader(index);
headerType = validate(indexedHeader.name, headerType, validateHeaders);
sink.appendToHeaderList(indexedHeader.name, indexedHeader.value); sink.appendToHeaderList(indexedHeader.name, indexedHeader.value);
} }
} else if ((b & 0x40) == 0x40) { } else if ((b & 0x40) == 0x40) {
@ -174,7 +171,6 @@ final class HpackDecoder {
default: default:
// Index was stored as the prefix // Index was stored as the prefix
name = readName(index); name = readName(index);
headerType = validate(name, headerType, validateHeaders);
nameLength = name.length(); nameLength = name.length();
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
} }
@ -201,7 +197,6 @@ final class HpackDecoder {
default: default:
// Index was stored as the prefix // Index was stored as the prefix
name = readName(index); name = readName(index);
headerType = validate(name, headerType, validateHeaders);
nameLength = name.length(); nameLength = name.length();
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
} }
@ -215,7 +210,6 @@ final class HpackDecoder {
case READ_INDEXED_HEADER: case READ_INDEXED_HEADER:
HpackHeaderField indexedHeader = getIndexedHeader(decodeULE128(in, index)); HpackHeaderField indexedHeader = getIndexedHeader(decodeULE128(in, index));
headerType = validate(indexedHeader.name, headerType, validateHeaders);
sink.appendToHeaderList(indexedHeader.name, indexedHeader.value); sink.appendToHeaderList(indexedHeader.name, indexedHeader.value);
state = READ_HEADER_REPRESENTATION; state = READ_HEADER_REPRESENTATION;
break; break;
@ -223,7 +217,6 @@ final class HpackDecoder {
case READ_INDEXED_HEADER_NAME: case READ_INDEXED_HEADER_NAME:
// Header Name matches an entry in the Header Table // Header Name matches an entry in the Header Table
name = readName(decodeULE128(in, index)); name = readName(decodeULE128(in, index));
headerType = validate(name, headerType, validateHeaders);
nameLength = name.length(); nameLength = name.length();
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
break; break;
@ -254,7 +247,6 @@ final class HpackDecoder {
} }
name = readStringLiteral(in, nameLength, huffmanEncoded); name = readStringLiteral(in, nameLength, huffmanEncoded);
headerType = validate(name, headerType, validateHeaders);
state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX; state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
break; break;
@ -268,7 +260,6 @@ final class HpackDecoder {
state = READ_LITERAL_HEADER_VALUE_LENGTH; state = READ_LITERAL_HEADER_VALUE_LENGTH;
break; break;
case 0: case 0:
headerType = validate(name, headerType, validateHeaders);
insertHeader(sink, name, EMPTY_STRING, indexType); insertHeader(sink, name, EMPTY_STRING, indexType);
state = READ_HEADER_REPRESENTATION; state = READ_HEADER_REPRESENTATION;
break; break;
@ -293,7 +284,6 @@ final class HpackDecoder {
} }
CharSequence value = readStringLiteral(in, valueLength, huffmanEncoded); CharSequence value = readStringLiteral(in, valueLength, huffmanEncoded);
headerType = validate(name, headerType, validateHeaders);
insertHeader(sink, name, value, indexType); insertHeader(sink, name, value, indexType);
state = READ_HEADER_REPRESENTATION; state = READ_HEADER_REPRESENTATION;
break; break;
@ -327,7 +317,7 @@ final class HpackDecoder {
} }
/** /**
* @deprecated use {@link #setmaxHeaderListSize(long)}; {@code maxHeaderListSizeGoAway} is * @deprecated use {@link #setMaxHeaderListSize(long)}; {@code maxHeaderListSizeGoAway} is
* ignored * ignored
*/ */
@Deprecated @Deprecated
@ -385,26 +375,23 @@ final class HpackDecoder {
hpackDynamicTable.setCapacity(dynamicTableSize); hpackDynamicTable.setCapacity(dynamicTableSize);
} }
private HeaderType validate(CharSequence name, HeaderType previousHeaderType, private static HeaderType validate(int streamId, CharSequence name,
final boolean validateHeaders) throws Http2Exception { HeaderType previousHeaderType) throws Http2Exception {
if (!validateHeaders) {
return null;
}
if (hasPseudoHeaderFormat(name)) { if (hasPseudoHeaderFormat(name)) {
if (previousHeaderType == HeaderType.REGULAR_HEADER) { if (previousHeaderType == HeaderType.REGULAR_HEADER) {
throw connectionError(PROTOCOL_ERROR, "Pseudo-header field '%s' found after regular header.", name); throw streamError(streamId, PROTOCOL_ERROR,
"Pseudo-header field '%s' found after regular header.", name);
} }
final Http2Headers.PseudoHeaderName pseudoHeader = getPseudoHeader(name); final Http2Headers.PseudoHeaderName pseudoHeader = getPseudoHeader(name);
if (pseudoHeader == null) { if (pseudoHeader == null) {
throw connectionError(PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name); throw streamError(streamId, PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name);
} }
final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ? final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ?
HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER; HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER;
if (previousHeaderType != null && currentHeaderType != previousHeaderType) { if (previousHeaderType != null && currentHeaderType != previousHeaderType) {
throw connectionError(PROTOCOL_ERROR, "Mix of request and response pseudo-headers."); throw streamError(streamId, PROTOCOL_ERROR, "Mix of request and response pseudo-headers.");
} }
return currentHeaderType; return currentHeaderType;
@ -435,8 +422,7 @@ final class HpackDecoder {
throw INDEX_HEADER_ILLEGAL_INDEX_VALUE; throw INDEX_HEADER_ILLEGAL_INDEX_VALUE;
} }
private void insertHeader(Sink sink, CharSequence name, CharSequence value, private void insertHeader(Sink sink, CharSequence name, CharSequence value, IndexType indexType) {
IndexType indexType) throws Http2Exception {
sink.appendToHeaderList(name, value); sink.appendToHeaderList(name, value);
switch (indexType) { switch (indexType) {
@ -529,32 +515,55 @@ final class HpackDecoder {
private interface Sink { private interface Sink {
void appendToHeaderList(CharSequence name, CharSequence value); void appendToHeaderList(CharSequence name, CharSequence value);
void finish() throws Http2Exception;
} }
private static final class Http2HeadersSink implements Sink { private static final class Http2HeadersSink implements Sink {
private final Http2Headers headers; private final Http2Headers headers;
private final long maxHeaderListSize; private final long maxHeaderListSize;
private final int streamId;
private final boolean validate;
private long headersLength; private long headersLength;
private boolean exceededMaxLength; private boolean exceededMaxLength;
private HeaderType previousType;
private Http2Exception validationException;
public Http2HeadersSink(Http2Headers headers, long maxHeaderListSize) { public Http2HeadersSink(int streamId, Http2Headers headers, long maxHeaderListSize, boolean validate) {
this.headers = headers; this.headers = headers;
this.maxHeaderListSize = maxHeaderListSize; this.maxHeaderListSize = maxHeaderListSize;
this.streamId = streamId;
this.validate = validate;
}
@Override
public void finish() throws Http2Exception {
if (exceededMaxLength) {
headerListSizeExceeded(streamId, maxHeaderListSize, true);
} else if (validationException != null) {
throw validationException;
}
} }
@Override @Override
public void appendToHeaderList(CharSequence name, CharSequence value) { public void appendToHeaderList(CharSequence name, CharSequence value) {
headersLength += HpackHeaderField.sizeOf(name, value); headersLength += HpackHeaderField.sizeOf(name, value);
if (headersLength > maxHeaderListSize) { exceededMaxLength |= headersLength > maxHeaderListSize;
exceededMaxLength = true;
}
if (!exceededMaxLength) {
headers.add(name, value);
}
}
public boolean exceededMaxLength() { if (exceededMaxLength || validationException != null) {
return exceededMaxLength; // We don't store the header since we've already failed validation requirements.
return;
}
if (validate) {
try {
previousType = HpackDecoder.validate(streamId, name, previousType);
} catch (Http2Exception ex) {
validationException = ex;
return;
}
}
headers.add(name, value);
} }
} }
} }

View File

@ -544,7 +544,7 @@ public class HpackDecoderTest {
Http2Headers decoded = new DefaultHttp2Headers(); Http2Headers decoded = new DefaultHttp2Headers();
expectedException.expect(Http2Exception.class); expectedException.expect(Http2Exception.StreamException.class);
hpackDecoder.decode(1, in, decoded, true); hpackDecoder.decode(1, in, decoded, true);
} finally { } finally {
in.release(); in.release();
@ -588,7 +588,7 @@ public class HpackDecoderTest {
Http2Headers decoded = new DefaultHttp2Headers(); Http2Headers decoded = new DefaultHttp2Headers();
expectedException.expect(Http2Exception.class); expectedException.expect(Http2Exception.StreamException.class);
hpackDecoder.decode(1, in, decoded, true); hpackDecoder.decode(1, in, decoded, true);
} finally { } finally {
in.release(); in.release();
@ -608,7 +608,7 @@ public class HpackDecoderTest {
Http2Headers decoded = new DefaultHttp2Headers(); Http2Headers decoded = new DefaultHttp2Headers();
expectedException.expect(Http2Exception.class); expectedException.expect(Http2Exception.StreamException.class);
hpackDecoder.decode(1, in, decoded, true); hpackDecoder.decode(1, in, decoded, true);
} finally { } finally {
in.release(); in.release();
@ -628,10 +628,47 @@ public class HpackDecoderTest {
Http2Headers decoded = new DefaultHttp2Headers(); Http2Headers decoded = new DefaultHttp2Headers();
expectedException.expect(Http2Exception.class); expectedException.expect(Http2Exception.StreamException.class);
hpackDecoder.decode(1, in, decoded, true); hpackDecoder.decode(1, in, decoded, true);
} finally { } finally {
in.release(); in.release();
} }
} }
@Test
public void failedValidationDoesntCorruptHpack() throws Exception {
ByteBuf in1 = Unpooled.buffer(200);
ByteBuf in2 = Unpooled.buffer(200);
try {
HpackEncoder hpackEncoder = new HpackEncoder(true);
Http2Headers toEncode = new DefaultHttp2Headers();
toEncode.add(":method", "GET");
toEncode.add(":status", "200");
toEncode.add("foo", "bar");
hpackEncoder.encodeHeaders(1, in1, toEncode, NEVER_SENSITIVE);
Http2Headers decoded = new DefaultHttp2Headers();
try {
hpackDecoder.decode(1, in1, decoded, true);
fail("Should have thrown a StreamException");
} catch (Http2Exception.StreamException expected) {
assertEquals(1, expected.streamId());
}
// Do it again, this time without validation, to make sure the HPACK state is still sane.
decoded.clear();
hpackEncoder.encodeHeaders(1, in2, toEncode, NEVER_SENSITIVE);
hpackDecoder.decode(1, in2, decoded, false);
assertEquals(3, decoded.size());
assertEquals("GET", decoded.method().toString());
assertEquals("200", decoded.status().toString());
assertEquals("bar", decoded.get("foo").toString());
} finally {
in1.release();
in2.release();
}
}
} }