HTTP/2 limit header accumulated size

Motivation:
The Http2FrameListener requires that the Http2FrameReader accumulate ByteBuf objects. There is currently no way to limit the amount of bytes that is accumulated.

Motiviation:
- DefaultHttp2FrameReader supports maxHeaderSize which will fail fast as soon as the maximum size is exceeded.
- DefaultHttp2HeadersDecoder will respect the return value of endHeaderBlock() and fail if the max size is exceeded.

Result:
Frames which carry header data now respect a maximum number of bytes that can be accumulated.
This commit is contained in:
Scott Mitchell 2015-06-17 17:33:09 -07:00
parent ba84a596e2
commit ecacd11b06
7 changed files with 109 additions and 18 deletions

View File

@ -14,20 +14,21 @@
*/ */
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_MAX_FRAME_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_MAX_FRAME_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid;
import static io.netty.handler.codec.http2.Http2CodecUtil.readUnsignedInt;
import static io.netty.handler.codec.http2.Http2Error.ENHANCE_YOUR_CALM;
import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR; import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR; import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid;
import static io.netty.handler.codec.http2.Http2CodecUtil.readUnsignedInt;
import static io.netty.handler.codec.http2.Http2Exception.streamError;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2Exception.streamError;
import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION; import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION;
import static io.netty.handler.codec.http2.Http2FrameTypes.DATA; import static io.netty.handler.codec.http2.Http2FrameTypes.DATA;
import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY; import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY;
@ -421,9 +422,8 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder(); final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder();
hdrBlockBuilder.addFragment(fragment, ctx.alloc(), endOfHeaders); hdrBlockBuilder.addFragment(fragment, ctx.alloc(), endOfHeaders);
if (endOfHeaders) { if (endOfHeaders) {
listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), streamDependency,
streamDependency, weight, exclusive, padding, headersFlags.endOfStream()); weight, exclusive, padding, headersFlags.endOfStream());
close();
} }
} }
}; };
@ -449,7 +449,6 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
if (endOfHeaders) { if (endOfHeaders) {
listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), padding, listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), padding,
headersFlags.endOfStream()); headersFlags.endOfStream());
close();
} }
} }
}; };
@ -519,10 +518,8 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
Http2FrameListener listener) throws Http2Exception { Http2FrameListener listener) throws Http2Exception {
headersBlockBuilder().addFragment(fragment, ctx.alloc(), endOfHeaders); headersBlockBuilder().addFragment(fragment, ctx.alloc(), endOfHeaders);
if (endOfHeaders) { if (endOfHeaders) {
Http2Headers headers = headersBlockBuilder().headers(); listener.onPushPromiseRead(ctx, pushPromiseStreamId, promisedStreamId,
listener.onPushPromiseRead(ctx, pushPromiseStreamId, promisedStreamId, headers, headersBlockBuilder().headers(), padding);
padding);
close();
} }
} }
}; };
@ -626,6 +623,16 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
protected class HeadersBlockBuilder { protected class HeadersBlockBuilder {
private ByteBuf headerBlock; private ByteBuf headerBlock;
/**
* The local header size maximum has been exceeded while accumulating bytes.
* @throws Http2Exception A connection error indicating too much data has been received.
*/
private void headerSizeExceeded() throws Http2Exception {
close();
throw connectionError(ENHANCE_YOUR_CALM, "Header size exceeded max allowed size (%d)",
headersDecoder.configuration().maxHeaderSize());
}
/** /**
* Adds a fragment to the block. * Adds a fragment to the block.
* *
@ -635,8 +642,11 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
* This is used for an optimization for when the first fragment is the full * This is used for an optimization for when the first fragment is the full
* block. In that case, the buffer is used directly without copying. * block. In that case, the buffer is used directly without copying.
*/ */
final void addFragment(ByteBuf fragment, ByteBufAllocator alloc, boolean endOfHeaders) { final void addFragment(ByteBuf fragment, ByteBufAllocator alloc, boolean endOfHeaders) throws Http2Exception {
if (headerBlock == null) { if (headerBlock == null) {
if (fragment.readableBytes() > headersDecoder.configuration().maxHeaderSize()) {
headerSizeExceeded();
}
if (endOfHeaders) { if (endOfHeaders) {
// Optimization - don't bother copying, just use the buffer as-is. Need // Optimization - don't bother copying, just use the buffer as-is. Need
// to retain since we release when the header block is built. // to retain since we release when the header block is built.
@ -647,6 +657,10 @@ public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSize
} }
return; return;
} }
if (headersDecoder.configuration().maxHeaderSize() - fragment.readableBytes() <
headerBlock.readableBytes()) {
headerSizeExceeded();
}
if (headerBlock.isWritable(fragment.readableBytes())) { if (headerBlock.isWritable(fragment.readableBytes())) {
// The buffer can hold the requested bytes, just write it directly. // The buffer can hold the requested bytes, just write it directly.
headerBlock.writeBytes(fragment); headerBlock.writeBytes(fragment);

View File

@ -18,6 +18,7 @@ package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_TABLE_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_TABLE_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_HEADER_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_HEADER_SIZE;
import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR;
import static io.netty.handler.codec.http2.Http2Error.ENHANCE_YOUR_CALM;
import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.connectionError;
@ -32,6 +33,7 @@ import com.twitter.hpack.Decoder;
import com.twitter.hpack.HeaderListener; import com.twitter.hpack.HeaderListener;
public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2HeadersDecoder.Configuration { public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2HeadersDecoder.Configuration {
private final int maxHeaderSize;
private final Decoder decoder; private final Decoder decoder;
private final Http2HeaderTable headerTable; private final Http2HeaderTable headerTable;
@ -40,8 +42,12 @@ public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2Hea
} }
public DefaultHttp2HeadersDecoder(int maxHeaderSize, int maxHeaderTableSize) { public DefaultHttp2HeadersDecoder(int maxHeaderSize, int maxHeaderTableSize) {
if (maxHeaderSize <= 0) {
throw new IllegalArgumentException("maxHeaderSize must be positive: " + maxHeaderSize);
}
decoder = new Decoder(maxHeaderSize, maxHeaderTableSize); decoder = new Decoder(maxHeaderSize, maxHeaderTableSize);
headerTable = new Http2HeaderTableDecoder(); headerTable = new Http2HeaderTableDecoder();
this.maxHeaderSize = maxHeaderSize;
} }
@Override @Override
@ -49,11 +55,24 @@ public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2Hea
return headerTable; return headerTable;
} }
@Override
public int maxHeaderSize() {
return maxHeaderSize;
}
@Override @Override
public Configuration configuration() { public Configuration configuration() {
return this; return this;
} }
/**
* Respond to headers block resulting in the maximum header size being exceeded.
* @throws Http2Exception If we can not recover from the truncation.
*/
protected void maxHeaderSizeExceeded() throws Http2Exception {
throw connectionError(ENHANCE_YOUR_CALM, "Header size exceeded max allowed bytes (%d)", maxHeaderSize);
}
@Override @Override
public Http2Headers decodeHeaders(ByteBuf headerBlock) throws Http2Exception { public Http2Headers decodeHeaders(ByteBuf headerBlock) throws Http2Exception {
InputStream in = new ByteBufInputStream(headerBlock); InputStream in = new ByteBufInputStream(headerBlock);
@ -67,9 +86,8 @@ public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2Hea
}; };
decoder.decode(in, listener); decoder.decode(in, listener);
boolean truncated = decoder.endHeaderBlock(); if (decoder.endHeaderBlock()) {
if (truncated) { maxHeaderSizeExceeded();
// TODO: what's the right thing to do here?
} }
if (headers.size() > headerTable.maxHeaderListSize()) { if (headers.size() > headerTable.maxHeaderListSize()) {

View File

@ -29,6 +29,11 @@ public interface Http2HeadersDecoder {
* Access the Http2HeaderTable for this {@link Http2HeadersDecoder} * Access the Http2HeaderTable for this {@link Http2HeadersDecoder}
*/ */
Http2HeaderTable headerTable(); Http2HeaderTable headerTable();
/**
* Get the maximum number of bytes that is allowed before truncation occurs.
*/
int maxHeaderSize();
} }
/** /**

View File

@ -15,15 +15,20 @@
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_HEADER_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT;
import static io.netty.handler.codec.http2.Http2TestUtil.as; import static io.netty.handler.codec.http2.Http2TestUtil.as;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString; import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyShort;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -33,8 +38,10 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.util.ByteString;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -291,6 +298,20 @@ public class DefaultHttp2FrameIOTest {
eq(true)); eq(true));
} }
@Test
public void headersThatAreTooBigShouldFail() throws Exception {
Http2Headers headers = headersOfSize(DEFAULT_MAX_HEADER_SIZE + 1);
writer.writeHeaders(ctx, 1, headers, 2, (short) 3, true, 0xFF, true, promise);
try {
reader.readFrame(ctx, buffer, listener);
fail();
} catch (Http2Exception e) {
verify(listener, never()).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(),
anyBoolean());
}
}
@Test @Test
public void emptypushPromiseShouldRoundtrip() throws Exception { public void emptypushPromiseShouldRoundtrip() throws Exception {
Http2Headers headers = EmptyHttp2Headers.INSTANCE; Http2Headers headers = EmptyHttp2Headers.INSTANCE;
@ -357,4 +378,13 @@ public class DefaultHttp2FrameIOTest {
} }
return headers; return headers;
} }
private Http2Headers headersOfSize(final int minSize) {
final ByteString singleByte = new ByteString(new byte[]{0});
DefaultHttp2Headers headers = new DefaultHttp2Headers();
for (int size = 0; size < minSize; size += 2) {
headers.add(singleByte, singleByte);
}
return headers;
}
} }

View File

@ -15,11 +15,13 @@
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_HEADER_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE;
import static io.netty.handler.codec.http2.Http2TestUtil.as; import static io.netty.handler.codec.http2.Http2TestUtil.as;
import static io.netty.handler.codec.http2.Http2TestUtil.randomBytes; import static io.netty.handler.codec.http2.Http2TestUtil.randomBytes;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -55,6 +57,17 @@ public class DefaultHttp2HeadersDecoderTest {
} }
} }
@Test(expected = Http2Exception.class)
public void testExceedHeaderSize() throws Exception {
ByteBuf buf = encode(randomBytes(DEFAULT_MAX_HEADER_SIZE), randomBytes(1));
try {
decoder.decodeHeaders(buf);
fail();
} finally {
buf.release();
}
}
private static byte[] b(String string) { private static byte[] b(String string) {
return string.getBytes(UTF_8); return string.getBytes(UTF_8);
} }

View File

@ -1189,6 +1189,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite();
final Http2Stream stream = stream(STREAM_A); final Http2Stream stream = stream(STREAM_A);
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) { public Void answer(InvocationOnMock invocationOnMock) {
stream.closeLocalSide(); stream.closeLocalSide();
return null; return null;
@ -1212,6 +1213,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite();
final Http2Stream stream = stream(STREAM_A); final Http2Stream stream = stream(STREAM_A);
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) { public Void answer(InvocationOnMock invocationOnMock) {
throw new RuntimeException("error failed"); throw new RuntimeException("error failed");
} }
@ -1257,6 +1259,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
final Http2Stream stream = stream(STREAM_A); final Http2Stream stream = stream(STREAM_A);
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) { public Void answer(InvocationOnMock invocationOnMock) {
throw new RuntimeException("writeComplete failed"); throw new RuntimeException("writeComplete failed");
} }
@ -1286,6 +1289,7 @@ public class DefaultHttp2RemoteFlowControllerTest {
when(flowControlled.size()).thenReturn(100); when(flowControlled.size()).thenReturn(100);
doThrow(new RuntimeException("write failed")).when(flowControlled).write(anyInt()); doThrow(new RuntimeException("write failed")).when(flowControlled).write(anyInt());
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) { public Void answer(InvocationOnMock invocationOnMock) {
stream.close(); stream.close();
return null; return null;

View File

@ -70,7 +70,14 @@ final class Http2TestUtil {
* Returns a byte array filled with random data. * Returns a byte array filled with random data.
*/ */
public static byte[] randomBytes() { public static byte[] randomBytes() {
byte[] data = new byte[100]; return randomBytes(100);
}
/**
* Returns a byte array filled with random data.
*/
public static byte[] randomBytes(int size) {
byte[] data = new byte[size];
new Random().nextBytes(data); new Random().nextBytes(data);
return data; return data;
} }