Move location of where oversized headers that don't exceed the go away limit is done

so that the check occurs after all headers have been read
This commit is contained in:
Christopher Exell 2017-01-24 11:44:55 -08:00 committed by Scott Mitchell
parent a416b79d86
commit d509365974
2 changed files with 44 additions and 3 deletions

View File

@ -295,10 +295,13 @@ public final class Decoder {
default:
throw new Error("should not reach here state: " + state);
}
}
if (headersLength > maxHeaderListSize) {
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
// we have read all of our headers, and not exceeded maxHeaderListSizeGoAway see if we have
// exceeded our actual maxHeaderListSize. This must be done here to prevent dynamic table
// corruption
if (headersLength > maxHeaderListSize) {
headerListSizeExceeded(streamId, maxHeaderListSize, true);
}
}

View File

@ -33,16 +33,20 @@ package io.netty.handler.codec.http2.internal.hpack;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.Http2Exception;
import io.netty.handler.codec.http2.Http2Headers;
import org.junit.Before;
import org.junit.Test;
import static io.netty.handler.codec.http2.Http2HeadersEncoder.NEVER_SENSITIVE;
import static io.netty.handler.codec.http2.internal.hpack.Decoder.decodeULE128;
import static io.netty.util.AsciiString.EMPTY_STRING;
import static io.netty.util.AsciiString.of;
import static java.lang.Integer.MAX_VALUE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
@ -419,4 +423,38 @@ public class DecoderTest {
}
decode(sb.toString());
}
@Test
public void testDecodeLargerThanMaxHeaderListSizeButSmallerThanMaxHeaderListSizeUpdatesDynamicTable()
throws Http2Exception {
ByteBuf in = Unpooled.buffer(200);
try {
decoder.setMaxHeaderListSize(100, 200);
Encoder encoder = new Encoder(true);
// encode headers that are slightly larger than maxHeaderListSize
// but smaller than maxHeaderListSizeGoAway
Http2Headers toEncode = new DefaultHttp2Headers();
toEncode.add("test_1", "1");
toEncode.add("test_2", "2");
toEncode.add("long", String.format("%0100d", 0).replace('0', 'A'));
toEncode.add("test_3", "3");
encoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE);
// decode the headers, we should get an exception, but
// the decoded headers object should contain all of the headers
Http2Headers decoded = new DefaultHttp2Headers();
try {
decoder.decode(1, in, decoded);
fail();
} catch (Http2Exception e) {
assertTrue(e instanceof Http2Exception.HeaderListSizeException);
}
assertEquals(4, decoded.size());
assertTrue(decoded.contains("test_3"));
} finally {
in.release();
}
}
}