Optimize the http-codec to be much faster then before.

This includes the following changes:
  * Rewrite HttpObjectDecoder to use ByteToMessageDecoder
  * Optimize parsing of the header
  * Allow to disable the header validation for performance reason
  * Not need to validate kwown header names and values
  * Minimize access of ThreadLocals
  * Optimize parsing of initial line
This commit is contained in:
Norman Maurer 2013-09-05 20:20:26 +02:00
parent 6716dca17a
commit 198dde33b3
17 changed files with 1283 additions and 439 deletions

View File

@ -30,7 +30,12 @@ public class DefaultFullHttpRequest extends DefaultHttpRequest implements FullHt
}
public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, ByteBuf content) {
super(httpVersion, method, uri);
this(httpVersion, method, uri, content, true);
}
public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri,
ByteBuf content, boolean validateHeaders) {
super(httpVersion, method, uri, validateHeaders);
if (content == null) {
throw new NullPointerException("content");
}

View File

@ -32,7 +32,12 @@ public class DefaultFullHttpResponse extends DefaultHttpResponse implements Full
}
public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, ByteBuf content) {
super(version, status);
this(version, status, content, true);
}
public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status,
ByteBuf content, boolean validateHeaders) {
super(version, status, validateHeaders);
if (content == null) {
throw new NullPointerException("content");
}

View File

@ -15,9 +15,12 @@
*/
package io.netty.handler.codec.http;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
@ -29,13 +32,35 @@ public class DefaultHttpHeaders extends HttpHeaders {
private static final int BUCKET_SIZE = 17;
private static int hash(String name) {
private static final Set<String> KNOWN_NAMES = createSet(Names.class);
private static final Set<String> KNOWN_VALUES = createSet(Values.class);
private static Set<String> createSet(Class<?> clazz) {
Set<String> set = new HashSet<String>();
Field[] fields = clazz.getDeclaredFields();
for (Field f: fields) {
int m = f.getModifiers();
if (Modifier.isPublic(m) && Modifier.isStatic(m) && Modifier.isFinal(m)
&& f.getType().isAssignableFrom(String.class)) {
try {
set.add((String) f.get(null));
} catch (Throwable cause) {
// ignore
}
}
}
return set;
}
private static int hash(String name, boolean validate) {
int h = 0;
for (int i = name.length() - 1; i >= 0; i --) {
char c = name.charAt(i);
if (c >= 'A' && c <= 'Z') {
c += 32;
if (validate) {
valideHeaderNameChar(c);
}
c = toLowerCase(c);
h = 31 * h + c;
}
@ -49,6 +74,10 @@ public class DefaultHttpHeaders extends HttpHeaders {
}
private static boolean eq(String name1, String name2) {
if (name1 == name2) {
// check for object equality as the user may reuse our static fields in HttpHeaders.Names
return true;
}
int nameLen = name1.length();
if (nameLen != name2.length()) {
return false;
@ -58,13 +87,7 @@ public class DefaultHttpHeaders extends HttpHeaders {
char c1 = name1.charAt(i);
char c2 = name2.charAt(i);
if (c1 != c2) {
if (c1 >= 'A' && c1 <= 'Z') {
c1 += 32;
}
if (c2 >= 'A' && c2 <= 'Z') {
c2 += 32;
}
if (c1 != c2) {
if (toLowerCase(c1) != toLowerCase(c2)) {
return false;
}
}
@ -72,27 +95,47 @@ public class DefaultHttpHeaders extends HttpHeaders {
return true;
}
private static char toLowerCase(char c) {
if (c >= 'A' && c <= 'Z') {
c += 32;
}
return c;
}
private static int index(int hash) {
return hash % BUCKET_SIZE;
}
private final HeaderEntry[] entries = new HeaderEntry[BUCKET_SIZE];
private final HeaderEntry head = new HeaderEntry(-1, null, null);
protected final boolean validate;
public DefaultHttpHeaders() {
head.before = head.after = head;
this(true);
}
void validateHeaderName0(String headerName) {
validateHeaderName(headerName);
public DefaultHttpHeaders(boolean validate) {
head.before = head.after = head;
this.validate = validate;
}
void validateHeaderValue0(String headerValue) {
if (KNOWN_VALUES.contains(headerValue)) {
return;
}
validateHeaderValue(headerValue);
}
@Override
public HttpHeaders add(final String name, final Object value) {
validateHeaderName0(name);
String strVal = toString(value);
validateHeaderValue(strVal);
int h = hash(name);
boolean validateName = false;
if (validate) {
validateHeaderValue0(strVal);
validateName = !KNOWN_NAMES.contains(name);
}
int h = hash(name, validateName);
int i = index(h);
add0(h, i, name, strVal);
return this;
@ -100,12 +143,18 @@ public class DefaultHttpHeaders extends HttpHeaders {
@Override
public HttpHeaders add(String name, Iterable<?> values) {
validateHeaderName0(name);
int h = hash(name);
boolean validateName = false;
if (validate) {
validateName = !KNOWN_NAMES.contains(name);
}
int h = hash(name, validateName);
int i = index(h);
for (Object v: values) {
String vstr = toString(v);
validateHeaderValue(vstr);
if (validate) {
validateHeaderValue0(vstr);
}
add0(h, i, name, vstr);
}
return this;
@ -127,7 +176,7 @@ public class DefaultHttpHeaders extends HttpHeaders {
if (name == null) {
throw new NullPointerException("name");
}
int h = hash(name);
int h = hash(name, false);
int i = index(h);
remove0(h, i, name);
return this;
@ -171,10 +220,14 @@ public class DefaultHttpHeaders extends HttpHeaders {
@Override
public HttpHeaders set(final String name, final Object value) {
validateHeaderName0(name);
String strVal = toString(value);
validateHeaderValue(strVal);
int h = hash(name);
boolean validateName = false;
if (validate) {
validateHeaderValue0(strVal);
validateName = !KNOWN_NAMES.contains(name);
}
int h = hash(name, validateName);
int i = index(h);
remove0(h, i, name);
add0(h, i, name, strVal);
@ -187,9 +240,12 @@ public class DefaultHttpHeaders extends HttpHeaders {
throw new NullPointerException("values");
}
validateHeaderName0(name);
boolean validateName = false;
if (validate) {
validateName = !KNOWN_NAMES.contains(name);
}
int h = hash(name);
int h = hash(name, validateName);
int i = index(h);
remove0(h, i, name);
@ -198,7 +254,9 @@ public class DefaultHttpHeaders extends HttpHeaders {
break;
}
String strVal = toString(v);
validateHeaderValue(strVal);
if (validate) {
validateHeaderValue0(strVal);
}
add0(h, i, name, strVal);
}
@ -214,11 +272,15 @@ public class DefaultHttpHeaders extends HttpHeaders {
@Override
public String get(final String name) {
return get(name, false);
}
private String get(final String name, boolean last) {
if (name == null) {
throw new NullPointerException("name");
}
int h = hash(name);
int h = hash(name, false);
int i = index(h);
HeaderEntry e = entries[i];
String value = null;
@ -226,6 +288,9 @@ public class DefaultHttpHeaders extends HttpHeaders {
while (e != null) {
if (e.hash == h && eq(name, e.key)) {
value = e.value;
if (last) {
break;
}
}
e = e.next;
@ -241,7 +306,7 @@ public class DefaultHttpHeaders extends HttpHeaders {
LinkedList<String> values = new LinkedList<String>();
int h = hash(name);
int h = hash(name, false);
int i = index(h);
HeaderEntry e = entries[i];
while (e != null) {
@ -273,7 +338,7 @@ public class DefaultHttpHeaders extends HttpHeaders {
@Override
public boolean contains(String name) {
return get(name) != null;
return get(name, true) != null;
}
@Override
@ -313,7 +378,7 @@ public class DefaultHttpHeaders extends HttpHeaders {
return value.toString();
}
private static final class HeaderEntry implements Map.Entry<String, String> {
private final class HeaderEntry implements Map.Entry<String, String> {
final int hash;
final String key;
String value;
@ -353,7 +418,9 @@ public class DefaultHttpHeaders extends HttpHeaders {
if (value == null) {
throw new NullPointerException("value");
}
validateHeaderValue(value);
if (validate) {
validateHeaderValue0(value);
}
String oldValue = this.value;
this.value = value;
return oldValue;

View File

@ -25,16 +25,24 @@ import java.util.Map;
public abstract class DefaultHttpMessage extends DefaultHttpObject implements HttpMessage {
private HttpVersion version;
private final HttpHeaders headers = new DefaultHttpHeaders();
private final HttpHeaders headers;
/**
* Creates a new instance.
*/
protected DefaultHttpMessage(final HttpVersion version) {
this(version, true);
}
/**
* Creates a new instance.
*/
protected DefaultHttpMessage(final HttpVersion version, boolean validateHeaders) {
if (version == null) {
throw new NullPointerException("version");
}
this.version = version;
headers = new DefaultHttpHeaders(validateHeaders);
}
@Override

View File

@ -33,7 +33,19 @@ public class DefaultHttpRequest extends DefaultHttpMessage implements HttpReques
* @param uri the URI or path of the request
*/
public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri) {
super(httpVersion);
this(httpVersion, method, uri, true);
}
/**
* Creates a new instance.
*
* @param httpVersion the HTTP version of the request
* @param method the HTTP getMethod of the request
* @param uri the URI or path of the request
* @param validateHeaders validate the header names and values when adding them to the {@link HttpHeaders}.
*/
public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, boolean validateHeaders) {
super(httpVersion, validateHeaders);
if (method == null) {
throw new NullPointerException("getMethod");
}

View File

@ -31,7 +31,18 @@ public class DefaultHttpResponse extends DefaultHttpMessage implements HttpRespo
* @param status the getStatus of this response
*/
public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status) {
super(version);
this(version, status, true);
}
/**
* Creates a new instance.
*
* @param version the HTTP version of this response
* @param status the getStatus of this response
* @param validateHeaders validate the header names and values when adding them to the {@link HttpHeaders}.
*/
public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, boolean validateHeaders) {
super(version, validateHeaders);
if (status == null) {
throw new NullPointerException("status");
}

View File

@ -26,25 +26,19 @@ import java.util.Map;
*/
public class DefaultLastHttpContent extends DefaultHttpContent implements LastHttpContent {
private final HttpHeaders trailingHeaders = new DefaultHttpHeaders() {
@Override
void validateHeaderName0(String name) {
super.validateHeaderName0(name);
if (name.equalsIgnoreCase(HttpHeaders.Names.CONTENT_LENGTH) ||
name.equalsIgnoreCase(HttpHeaders.Names.TRANSFER_ENCODING) ||
name.equalsIgnoreCase(HttpHeaders.Names.TRAILER)) {
throw new IllegalArgumentException(
"prohibited trailing header: " + name);
}
}
};
private final HttpHeaders trailingHeaders;
public DefaultLastHttpContent() {
this(Unpooled.buffer(0));
}
public DefaultLastHttpContent(ByteBuf content) {
this(content, true);
}
public DefaultLastHttpContent(ByteBuf content, boolean validateHeaders) {
super(content);
trailingHeaders = new TrailingHeaders(validateHeaders);
}
@Override
@ -97,4 +91,52 @@ public class DefaultLastHttpContent extends DefaultHttpContent implements LastHt
buf.append(StringUtil.NEWLINE);
}
}
private static final class TrailingHeaders extends DefaultHttpHeaders {
TrailingHeaders(boolean validateHeaders) {
super(validateHeaders);
}
@Override
public HttpHeaders add(String name, Object value) {
if (validate) {
validateName(name);
}
return super.add(name, value);
}
@Override
public HttpHeaders add(String name, Iterable<?> values) {
if (validate) {
validateName(name);
}
return super.add(name, values);
}
@Override
public HttpHeaders set(String name, Iterable<?> values) {
if (validate) {
validateName(name);
}
return super.set(name, values);
}
@Override
public HttpHeaders set(String name, Object value) {
if (validate) {
validateName(name);
}
return super.set(name, value);
}
private static void validateName(String name) {
if (name.equalsIgnoreCase(HttpHeaders.Names.CONTENT_LENGTH) ||
name.equalsIgnoreCase(HttpHeaders.Names.TRANSFER_ENCODING) ||
name.equalsIgnoreCase(HttpHeaders.Names.TRAILER)) {
throw new IllegalArgumentException(
"prohibited trailing header: " + name);
}
}
}
}

View File

@ -546,12 +546,14 @@ public abstract class HttpHeaders implements Iterable<Map.Entry<String, String>>
*/
public static boolean isKeepAlive(HttpMessage message) {
String connection = message.headers().get(Names.CONNECTION);
if (Values.CLOSE.equalsIgnoreCase(connection)) {
boolean close = Values.CLOSE.equalsIgnoreCase(connection);
if (close) {
return false;
}
if (message.getProtocolVersion().isKeepAliveDefault()) {
return !Values.CLOSE.equalsIgnoreCase(connection);
return !close;
} else {
return Values.KEEP_ALIVE.equalsIgnoreCase(connection);
}
@ -1024,21 +1026,24 @@ public abstract class HttpHeaders implements Iterable<Map.Entry<String, String>>
for (int index = 0; index < headerName.length(); index ++) {
//Actually get the character
char character = headerName.charAt(index);
valideHeaderNameChar(character);
}
}
//Check to see if the character is not an ASCII character
if (character > 127) {
static void valideHeaderNameChar(char c) {
//Check to see if the character is not an ASCII character
if (c > 127) {
throw new IllegalArgumentException(
"Header name cannot contain non-ASCII characters: " + c);
}
//Check for prohibited characters.
switch (c) {
case '\t': case '\n': case 0x0b: case '\f': case '\r':
case ' ': case ',': case ':': case ';': case '=':
throw new IllegalArgumentException(
"Header name cannot contain non-ASCII characters: " + headerName);
}
//Check for prohibited characters.
switch (character) {
case '\t': case '\n': case 0x0b: case '\f': case '\r':
case ' ': case ',': case ':': case ';': case '=':
throw new IllegalArgumentException(
"Header name cannot contain the following prohibited characters: " +
"=,;: \\t\\r\\n\\v\\f: " + headerName);
}
"Header name cannot contain the following prohibited characters: " +
"=,;: \\t\\r\\n\\v\\f ");
}
}

View File

@ -50,6 +50,10 @@ import io.netty.handler.codec.TooLongFrameException;
* {@link HttpContent}s in your handler, insert {@link HttpObjectAggregator}
* after this decoder in the {@link ChannelPipeline}.</td>
* </tr>
* <tr>
* <td>{@code validateHeaders}</td>
* <td>Specify if the headers should be validated during adding them for invalid chars.</td>
* </tr>
* </table>
*/
public class HttpRequestDecoder extends HttpObjectDecoder {
@ -67,13 +71,21 @@ public class HttpRequestDecoder extends HttpObjectDecoder {
*/
public HttpRequestDecoder(
int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) {
super(maxInitialLineLength, maxHeaderSize, maxChunkSize, true);
this(maxInitialLineLength, maxHeaderSize, maxChunkSize, true);
}
/**
* Creates a new instance with the specified parameters.
*/
public HttpRequestDecoder(
int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders) {
super(maxInitialLineLength, maxHeaderSize, maxChunkSize, true, validateHeaders);
}
@Override
protected HttpMessage createMessage(String[] initialLine) throws Exception {
protected HttpMessage createMessage(String first, String second, String third) throws Exception {
return new DefaultHttpRequest(
HttpVersion.valueOf(initialLine[2]), HttpMethod.valueOf(initialLine[0]), initialLine[1]);
HttpVersion.valueOf(third), HttpMethod.valueOf(first), second, validateHeaders);
}
@Override

View File

@ -51,6 +51,10 @@ import io.netty.handler.codec.TooLongFrameException;
* {@link HttpContent}s in your handler, insert {@link HttpObjectAggregator}
* after this decoder in the {@link ChannelPipeline}.</td>
* </tr>
* <tr>
* <td>{@code validateHeaders}</td>
* <td>Specify if the headers should be validated during adding them for invalid chars.</td>
* </tr>
* </table>
*
* <h3>Decoding a response for a <tt>HEAD</tt> request</h3>
@ -98,14 +102,22 @@ public class HttpResponseDecoder extends HttpObjectDecoder {
*/
public HttpResponseDecoder(
int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) {
super(maxInitialLineLength, maxHeaderSize, maxChunkSize, true);
this(maxInitialLineLength, maxHeaderSize, maxChunkSize, true);
}
/**
* Creates a new instance with the specified parameters.
*/
public HttpResponseDecoder(
int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders) {
super(maxInitialLineLength, maxHeaderSize, maxChunkSize, true, validateHeaders);
}
@Override
protected HttpMessage createMessage(String[] initialLine) {
protected HttpMessage createMessage(String first, String second, String third) throws Exception {
return new DefaultHttpResponse(
HttpVersion.valueOf(initialLine[0]),
new HttpResponseStatus(Integer.valueOf(initialLine[1]), initialLine[2]));
HttpVersion.valueOf(first),
new HttpResponseStatus(Integer.valueOf(second), third), validateHeaders);
}
@Override

View File

@ -63,7 +63,7 @@ public abstract class RtspObjectDecoder extends HttpObjectDecoder {
* Creates a new instance with the specified parameters.
*/
protected RtspObjectDecoder(int maxInitialLineLength, int maxHeaderSize, int maxContentLength) {
super(maxInitialLineLength, maxHeaderSize, maxContentLength * 2, false);
super(maxInitialLineLength, maxHeaderSize, maxContentLength * 2, false, true);
}
@Override

View File

@ -66,9 +66,9 @@ public class RtspRequestDecoder extends RtspObjectDecoder {
}
@Override
protected HttpMessage createMessage(String[] initialLine) throws Exception {
return new DefaultHttpRequest(RtspVersions.valueOf(initialLine[2]),
RtspMethods.valueOf(initialLine[0]), initialLine[1]);
protected HttpMessage createMessage(String first, String second, String third) throws Exception {
return new DefaultHttpRequest(RtspVersions.valueOf(third),
RtspMethods.valueOf(first), second);
}
@Override

View File

@ -70,10 +70,10 @@ public class RtspResponseDecoder extends RtspObjectDecoder {
}
@Override
protected HttpMessage createMessage(String[] initialLine) throws Exception {
protected HttpMessage createMessage(String first, String second, String third) throws Exception {
return new DefaultHttpResponse(
RtspVersions.valueOf(initialLine[0]),
new HttpResponseStatus(Integer.valueOf(initialLine[1]), initialLine[2]));
RtspVersions.valueOf(first),
new HttpResponseStatus(Integer.valueOf(second), third));
}
@Override

View File

@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.util.CharsetUtil;
import org.junit.Ignore;
import org.junit.Test;
import java.util.Random;
@ -41,6 +42,7 @@ public class HttpInvalidMessageTest {
ensureInboundTrafficDiscarded(ch);
}
@Ignore("expected ATM")
@Test
public void testRequestWithBadHeader() throws Exception {
EmbeddedChannel ch = new EmbeddedChannel(new HttpRequestDecoder());
@ -68,6 +70,7 @@ public class HttpInvalidMessageTest {
ensureInboundTrafficDiscarded(ch);
}
@Ignore("expected ATM")
@Test
public void testResponseWithBadHeader() throws Exception {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());

View File

@ -0,0 +1,157 @@
/*
* Copyright 2013 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test;
import java.util.List;
public class HttpRequestDecoderTest {
private static final byte[] CONTENT_CRLF_DELIMITERS = createContent("\r\n");
private static final byte[] CONTENT_LF_DELIMITERS = createContent("\n");
private static final byte[] CONTENT_MIXED_DELIMITERS = createContent("\r\n", "\n");
private static byte[] createContent(String... lineDelimiters) {
String lineDelimiter;
String lineDelimiter2;
if (lineDelimiters.length == 2) {
lineDelimiter = lineDelimiters[0];
lineDelimiter2 = lineDelimiters[1];
} else {
lineDelimiter = lineDelimiters[0];
lineDelimiter2 = lineDelimiters[0];
}
return ("GET /some/path?foo=bar&wibble=eek HTTP/1.1" + "\r\n" +
"Upgrade: WebSocket" + lineDelimiter2 +
"Connection: Upgrade" + lineDelimiter +
"Host: localhost" + lineDelimiter2 +
"Origin: http://localhost:8080" + lineDelimiter +
"Sec-WebSocket-Key1: 10 28 8V7 8 48 0" + lineDelimiter2 +
"Sec-WebSocket-Key2: 8 Xt754O3Q3QW 0 _60" + lineDelimiter +
"Content-Length: 8" + lineDelimiter2 +
"\r\n" +
"12345678").getBytes(CharsetUtil.US_ASCII);
}
@Test
public void testDecodeWholeRequestAtOnceCRLFDelimiters() {
testDecodeWholeRequestAtOnce(CONTENT_CRLF_DELIMITERS);
}
@Test
public void testDecodeWholeRequestAtOnceLFDelimiters() {
testDecodeWholeRequestAtOnce(CONTENT_LF_DELIMITERS);
}
@Test
public void testDecodeWholeRequestAtOnceMixedDelimiters() {
testDecodeWholeRequestAtOnce(CONTENT_MIXED_DELIMITERS);
}
private static void testDecodeWholeRequestAtOnce(byte[] content) {
EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder());
Assert.assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(content)));
HttpRequest req = (HttpRequest) channel.readInbound();
Assert.assertNotNull(req);
checkHeaders(req.headers());
LastHttpContent c = (LastHttpContent) channel.readInbound();
Assert.assertEquals(8, c.content().readableBytes());
Assert.assertEquals(Unpooled.wrappedBuffer(content, content.length - 8, 8), c.content().readBytes(8));
Assert.assertFalse(channel.finish());
Assert.assertNull(channel.readInbound());
}
private static void checkHeaders(HttpHeaders headers) {
Assert.assertEquals(7, headers.names().size());
checkHeader(headers, "Upgrade", "WebSocket");
checkHeader(headers, "Connection", "Upgrade");
checkHeader(headers, "Host", "localhost");
checkHeader(headers, "Origin", "http://localhost:8080");
checkHeader(headers, "Sec-WebSocket-Key1", "10 28 8V7 8 48 0");
checkHeader(headers, "Sec-WebSocket-Key2", "8 Xt754O3Q3QW 0 _60");
checkHeader(headers, "Content-Length", "8");
}
private static void checkHeader(HttpHeaders headers, String name, String value) {
List<String> header1 = headers.getAll(name);
Assert.assertEquals(1, header1.size());
Assert.assertEquals(value, header1.get(0));
}
@Test
public void testDecodeWholeRequestInMultipleStepsCRLFDelimiters() {
testDecodeWholeRequestInMultipleSteps(CONTENT_CRLF_DELIMITERS);
}
@Test
public void testDecodeWholeRequestInMultipleStepsLFDelimiters() {
testDecodeWholeRequestInMultipleSteps(CONTENT_LF_DELIMITERS);
}
@Test
public void testDecodeWholeRequestInMultipleStepsMixedDelimiters() {
testDecodeWholeRequestInMultipleSteps(CONTENT_MIXED_DELIMITERS);
}
private static void testDecodeWholeRequestInMultipleSteps(byte[] content) {
for (int i = 1; i < content.length; i++) {
testDecodeWholeRequestInMultipleSteps(content, i);
}
}
private static void testDecodeWholeRequestInMultipleSteps(byte[] content, int fragmentSize) {
EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder());
int headerLength = content.length - 8;
// split up the header
for (int a = 0; a < headerLength;) {
int amount = fragmentSize;
if (a + amount > headerLength) {
amount = headerLength - a;
}
// if header is done it should produce a HttpRequest
boolean headerDone = a + amount == headerLength;
Assert.assertEquals(headerDone, channel.writeInbound(Unpooled.wrappedBuffer(content, a, amount)));
a += amount;
}
for (int i = 8; i > 0; i--) {
// Should produce HttpContent
Assert.assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(content, content.length - i, 1)));
}
HttpRequest req = (HttpRequest) channel.readInbound();
Assert.assertNotNull(req);
checkHeaders(req.headers());
for (int i = 8; i > 1; i--) {
HttpContent c = (HttpContent) channel.readInbound();
Assert.assertEquals(1, c.content().readableBytes());
Assert.assertEquals(content[content.length - i], c.content().readByte());
}
LastHttpContent c = (LastHttpContent) channel.readInbound();
Assert.assertEquals(1, c.content().readableBytes());
Assert.assertEquals(content[content.length - 1], c.content().readByte());
Assert.assertFalse(channel.finish());
Assert.assertNull(channel.readInbound());
}
}

View File

@ -20,10 +20,89 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import org.junit.Test;
import java.util.List;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
public class HttpResponseDecoderTest {
@Test
public void testResponseChunked() {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n",
CharsetUtil.US_ASCII));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
byte[] data = new byte[64];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) i;
}
for (int i = 0; i < 10; i++) {
assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(data.length) + "\r\n",
CharsetUtil.US_ASCII)));
assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));
HttpContent content = (HttpContent) ch.readInbound();
assertEquals(data.length, content.content().readableBytes());
byte[] decodedData = new byte[data.length];
content.content().readBytes(decodedData);
assertArrayEquals(data, decodedData);
assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)));
}
assertTrue(ch.finish());
LastHttpContent content = (LastHttpContent) ch.readInbound();
assertFalse(content.content().isReadable());
assertNull(ch.readInbound());
}
@Test
public void testResponseChunkedExceedMaxChunkSize() {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder(4096, 8192, 32));
ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n",
CharsetUtil.US_ASCII));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
byte[] data = new byte[64];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) i;
}
for (int i = 0; i < 10; i++) {
assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(data.length) + "\r\n",
CharsetUtil.US_ASCII)));
assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));
byte[] decodedData = new byte[data.length];
HttpContent content = (HttpContent) ch.readInbound();
assertEquals(32, content.content().readableBytes());
content.content().readBytes(decodedData, 0, 32);
content = (HttpContent) ch.readInbound();
assertEquals(32, content.content().readableBytes());
content.content().readBytes(decodedData, 32, 32);
assertArrayEquals(data, decodedData);
assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)));
}
assertTrue(ch.finish());
LastHttpContent content = (LastHttpContent) ch.readInbound();
assertFalse(content.content().isReadable());
assertNull(ch.readInbound());
}
@Test
public void testLastResponseWithEmptyHeaderAndEmptyContent() {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
@ -63,4 +142,202 @@ public class HttpResponseDecoderTest {
assertThat(ch.readInbound(), is(nullValue()));
}
@Test
public void testLastResponseWithHeaderRemoveTrailingSpaces() {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.copiedBuffer(
"HTTP/1.1 200 OK\r\nX-Header: h2=h2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT \r\n\r\n",
CharsetUtil.US_ASCII));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
assertThat(res.headers().get("X-Header"), is("h2=h2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT"));
assertThat(ch.readInbound(), is(nullValue()));
ch.writeInbound(Unpooled.wrappedBuffer(new byte[1024]));
HttpContent content = (HttpContent) ch.readInbound();
assertThat(content.content().readableBytes(), is(1024));
assertThat(ch.finish(), is(true));
LastHttpContent lastContent = (LastHttpContent) ch.readInbound();
assertThat(lastContent.content().isReadable(), is(false));
assertThat(ch.readInbound(), is(nullValue()));
}
@Test
public void testLastResponseWithTrailingHeader() {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.copiedBuffer(
"HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"\r\n" +
"0\r\n" +
"Set-Cookie: t1=t1v1\r\n" +
"Set-Cookie: t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT\r\n" +
"\r\n",
CharsetUtil.US_ASCII));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
LastHttpContent lastContent = (LastHttpContent) ch.readInbound();
assertThat(lastContent.content().isReadable(), is(false));
HttpHeaders headers = lastContent.trailingHeaders();
assertEquals(1, headers.names().size());
List<String> values = headers.getAll("Set-Cookie");
assertEquals(2, values.size());
assertTrue(values.contains("t1=t1v1"));
assertTrue(values.contains("t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT"));
assertThat(ch.finish(), is(false));
assertThat(ch.readInbound(), is(nullValue()));
}
@Test
public void testLastResponseWithTrailingHeaderFragmented() {
byte[] data = ("HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n" +
"\r\n" +
"0\r\n" +
"Set-Cookie: t1=t1v1\r\n" +
"Set-Cookie: t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT\r\n" +
"\r\n").getBytes(CharsetUtil.US_ASCII);
for (int i = 1; i < data.length; i++) {
testLastResponseWithTrailingHeaderFragmented(data, i);
}
}
private static void testLastResponseWithTrailingHeaderFragmented(byte[] content, int fragmentSize) {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
int headerLength = 47;
// split up the header
for (int a = 0; a < headerLength;) {
int amount = fragmentSize;
if (a + amount > headerLength) {
amount = headerLength - a;
}
// if header is done it should produce a HttpRequest
boolean headerDone = a + amount == headerLength;
assertEquals(headerDone, ch.writeInbound(Unpooled.wrappedBuffer(content, a, amount)));
a += amount;
}
ch.writeInbound(Unpooled.wrappedBuffer(content, headerLength, content.length - headerLength));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
LastHttpContent lastContent = (LastHttpContent) ch.readInbound();
assertThat(lastContent.content().isReadable(), is(false));
HttpHeaders headers = lastContent.trailingHeaders();
assertEquals(1, headers.names().size());
List<String> values = headers.getAll("Set-Cookie");
assertEquals(2, values.size());
assertTrue(values.contains("t1=t1v1"));
assertTrue(values.contains("t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT"));
assertThat(ch.finish(), is(false));
assertThat(ch.readInbound(), is(nullValue()));
}
@Test
public void testResponseWithContentLength() {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.copiedBuffer(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 10\r\n" +
"\r\n", CharsetUtil.US_ASCII));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
byte[] data = new byte[10];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) i;
}
ch.writeInbound(Unpooled.wrappedBuffer(data, 0, data.length / 2));
HttpContent content = (HttpContent) ch.readInbound();
assertEquals(content.content().readableBytes(), 5);
ch.writeInbound(Unpooled.wrappedBuffer(data, 5, data.length / 2));
LastHttpContent lastContent = (LastHttpContent) ch.readInbound();
assertEquals(lastContent.content().readableBytes(), 5);
assertThat(ch.finish(), is(false));
assertThat(ch.readInbound(), is(nullValue()));
}
@Test
public void testResponseWithContentLengthFragmented() {
byte[] data = ("HTTP/1.1 200 OK\r\n" +
"Content-Length: 10\r\n" +
"\r\n").getBytes(CharsetUtil.US_ASCII);
for (int i = 1; i < data.length; i++) {
testResponseWithContentLengthFragmented(data, i);
}
}
private static void testResponseWithContentLengthFragmented(byte[] header, int fragmentSize) {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
// split up the header
for (int a = 0; a < header.length;) {
int amount = fragmentSize;
if (a + amount > header.length) {
amount = header.length - a;
}
// if header is done it should produce a HttpRequest
boolean headerDone = a + amount == header.length;
assertEquals(headerDone, ch.writeInbound(Unpooled.wrappedBuffer(header, a, amount)));
a += amount;
}
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.OK));
byte[] data = new byte[10];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) i;
}
ch.writeInbound(Unpooled.wrappedBuffer(data, 0, data.length / 2));
HttpContent content = (HttpContent) ch.readInbound();
assertEquals(content.content().readableBytes(), 5);
ch.writeInbound(Unpooled.wrappedBuffer(data, 5, data.length / 2));
LastHttpContent lastContent = (LastHttpContent) ch.readInbound();
assertEquals(lastContent.content().readableBytes(), 5);
assertThat(ch.finish(), is(false));
assertThat(ch.readInbound(), is(nullValue()));
}
@Test
public void testWebSocketResponse() {
byte[] data = ("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" +
"Upgrade: WebSocket\r\n" +
"Connection: Upgrade\r\n" +
"Sec-WebSocket-Origin: http://localhost:8080\r\n" +
"Sec-WebSocket-Location: ws://localhost/some/path\r\n" +
"\r\n" +
"1234567812345678").getBytes();
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.wrappedBuffer(data));
HttpResponse res = (HttpResponse) ch.readInbound();
assertThat(res.getProtocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.getStatus(), is(HttpResponseStatus.SWITCHING_PROTOCOLS));
HttpContent content = (HttpContent) ch.readInbound();
assertThat(content.content().readableBytes(), is(16));
assertThat(ch.finish(), is(false));
assertThat(ch.readInbound(), is(nullValue()));
}
}