Overall clean-up of HttpObjectAggregator / Handle oversized response differently

- Related: #2211
This commit is contained in:
Trustin Lee 2014-02-20 11:35:23 -08:00
parent 91376263d7
commit fcc41a62bd
2 changed files with 97 additions and 78 deletions

View File

@ -24,7 +24,10 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.List; import java.util.List;
@ -47,17 +50,27 @@ import static io.netty.handler.codec.http.HttpHeaders.*;
* </pre> * </pre>
*/ */
public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> { public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
public static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;
/**
* @deprecated Will be removed in the next minor version bump.
*/
@Deprecated
public static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024; // TODO: Make it private in the next bump.
private static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpObjectAggregator.class);
private static final FullHttpResponse CONTINUE = new DefaultFullHttpResponse( private static final FullHttpResponse CONTINUE = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER); HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER);
private static final FullHttpResponse TOO_LARGE = new DefaultFullHttpResponse( private static final FullHttpResponse TOO_LARGE = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER); HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER);
static {
TOO_LARGE.headers().set(Names.CONTENT_LENGTH, 0);
}
private final int maxContentLength; private final int maxContentLength;
private FullHttpMessage currentMessage; private FullHttpMessage currentMessage;
private boolean tooLongFrameFound; private boolean handlingOversizedMessage;
private boolean is100Continue;
private boolean isKeepAlive;
private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS; private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS;
private ChannelHandlerContext ctx; private ChannelHandlerContext ctx;
@ -68,7 +81,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
* @param maxContentLength * @param maxContentLength
* the maximum length of the aggregated content. * the maximum length of the aggregated content.
* If the length of the aggregated content exceeds this value, * If the length of the aggregated content exceeds this value,
* {@link #messageTooLong(io.netty.channel.ChannelHandlerContext, HttpObject)} * {@link #handleOversizedMessage(ChannelHandlerContext, HttpMessage)}
* will be called. * will be called.
*/ */
public HttpObjectAggregator(int maxContentLength) { public HttpObjectAggregator(int maxContentLength) {
@ -84,7 +97,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
* Returns the maximum number of components in the cumulation buffer. If the number of * Returns the maximum number of components in the cumulation buffer. If the number of
* the components in the cumulation buffer exceeds this value, the components of the * the components in the cumulation buffer exceeds this value, the components of the
* cumulation buffer are consolidated into a single component, involving memory copies. * cumulation buffer are consolidated into a single component, involving memory copies.
* The default value of this property is {@link #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}. * The default value of this property is {@value #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}.
*/ */
public final int getMaxCumulationBufferComponents() { public final int getMaxCumulationBufferComponents() {
return maxCumulationBufferComponents; return maxCumulationBufferComponents;
@ -94,7 +107,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
* Sets the maximum number of components in the cumulation buffer. If the number of * Sets the maximum number of components in the cumulation buffer. If the number of
* the components in the cumulation buffer exceeds this value, the components of the * the components in the cumulation buffer exceeds this value, the components of the
* cumulation buffer are consolidated into a single component, involving memory copies. * cumulation buffer are consolidated into a single component, involving memory copies.
* The default value of this property is {@link #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS} * The default value of this property is {@value #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}
* and its minimum allowed value is {@code 2}. * and its minimum allowed value is {@code 2}.
*/ */
public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) { public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) {
@ -117,25 +130,22 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
FullHttpMessage currentMessage = this.currentMessage; FullHttpMessage currentMessage = this.currentMessage;
if (msg instanceof HttpMessage) { if (msg instanceof HttpMessage) {
tooLongFrameFound = false; handlingOversizedMessage = false;
assert currentMessage == null; assert currentMessage == null;
HttpMessage m = (HttpMessage) msg; HttpMessage m = (HttpMessage) msg;
is100Continue = is100ContinueExpected(m);
isKeepAlive = isKeepAlive(m);
// if content length is set, preemptively close if it's too large // if content length is set, preemptively close if it's too large
if (isContentLengthSet(m)) { if (isContentLengthSet(m)) {
if (getContentLength(m) > maxContentLength) { if (getContentLength(m) > maxContentLength) {
// handle oversized message // handle oversized message
handleMessageTooLong(ctx, m); invokeHandleOversizedMessage(ctx, m);
return; return;
} }
} }
// Handle the 'Expect: 100-continue' header if necessary. // Handle the 'Expect: 100-continue' header if necessary.
if (is100Continue) { if (is100ContinueExpected(m)) {
ctx.writeAndFlush(CONTINUE).addListener(new ChannelFutureListener() { ctx.writeAndFlush(CONTINUE).addListener(new ChannelFutureListener() {
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
@ -170,7 +180,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
// A streamed message - initialize the cumulative buffer, and wait for incoming chunks. // A streamed message - initialize the cumulative buffer, and wait for incoming chunks.
removeTransferEncodingChunked(currentMessage); removeTransferEncodingChunked(currentMessage);
} else if (msg instanceof HttpContent) { } else if (msg instanceof HttpContent) {
if (tooLongFrameFound) { if (handlingOversizedMessage) {
if (msg instanceof LastHttpContent) { if (msg instanceof LastHttpContent) {
this.currentMessage = null; this.currentMessage = null;
} }
@ -185,7 +195,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
if (content.readableBytes() > maxContentLength - chunk.content().readableBytes()) { if (content.readableBytes() > maxContentLength - chunk.content().readableBytes()) {
// handle oversized message // handle oversized message
handleMessageTooLong(ctx, currentMessage); invokeHandleOversizedMessage(ctx, currentMessage);
return; return;
} }
@ -226,41 +236,56 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
} }
} }
private void handleMessageTooLong(ChannelHandlerContext ctx, HttpObject msg) throws Exception { private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, HttpMessage msg) throws Exception {
tooLongFrameFound = true; handlingOversizedMessage = true;
currentMessage = null; currentMessage = null;
messageTooLong(ctx, msg); try {
handleOversizedMessage(ctx, msg);
} finally {
ReferenceCountUtil.release(msg);
}
} }
/** /**
* Invoked when an incoming request exceeds the maximum content length. * Invoked when an incoming request exceeds the maximum content length.
* *
* The default behavior returns a {@code 413} HTTP response. * The default behavior is:
* * <ul>
* Sub-classes may override this method to change behavior. <em>Note:</em> * <li>Oversized request: Send a {@link HttpResponseStatus#REQUEST_ENTITY_TOO_LARGE} and close the connection
* you are responsible for releasing {@code msg} if you override the * if keep-alive is not enabled.</li>
* default behavior. * <li>Oversized response: Close the connection and raise {@link TooLongFrameException}.</li>
* </ul>
* Sub-classes may override this method to change the default behavior. The specified {@code msg} is released
* once this method returns.
* *
* @param ctx the {@link ChannelHandlerContext} * @param ctx the {@link ChannelHandlerContext}
* @param msg the accumulated HTTP message up to this point * @param msg the accumulated HTTP message up to this point
*/ */
protected void messageTooLong(ChannelHandlerContext ctx, HttpObject msg) throws Exception { @SuppressWarnings("UnusedParameters")
// release current message to prevent leaks protected void handleOversizedMessage(final ChannelHandlerContext ctx, HttpMessage msg) throws Exception {
ReferenceCountUtil.release(msg); if (msg instanceof HttpRequest) {
// send back a 413 and close the connection
final ChannelHandlerContext context = ctx; ChannelFuture future = ctx.writeAndFlush(TOO_LARGE).addListener(new ChannelFutureListener() {
// send back a 413 and close the connection @Override
ChannelFuture future = ctx.writeAndFlush(TOO_LARGE).addListener(new ChannelFutureListener() { public void operationComplete(ChannelFuture future) throws Exception {
@Override if (!future.isSuccess()) {
public void operationComplete(ChannelFuture future) throws Exception { logger.debug("Failed to send a 413 Request Entity Too Large.", future.cause());
if (!future.isSuccess()) { ctx.close();
context.fireExceptionCaught(future.cause()); }
} }
} });
});
if (!is100Continue || !isKeepAlive) { // If the client started to send data already, close because it's impossible to recover.
future.addListener(ChannelFutureListener.CLOSE); // If 'Expect: 100-continue' is missing, close becuase it's impossible to recover.
// If keep-alive is off, no need to leave the connection open.
if (msg instanceof FullHttpMessage || !is100ContinueExpected(msg) || !isKeepAlive(msg)) {
future.addListener(ChannelFutureListener.CLOSE);
}
} else if (msg instanceof HttpResponse) {
ctx.close();
throw new TooLongFrameException("Response entity too large: " + msg);
} else {
throw new IllegalStateException();
} }
} }

View File

@ -20,6 +20,7 @@ import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import org.easymock.EasyMock; import org.easymock.EasyMock;
import org.junit.Test; import org.junit.Test;
@ -38,8 +39,7 @@ public class HttpObjectAggregatorTest {
HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024); HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024);
EmbeddedChannel embedder = new EmbeddedChannel(aggr); EmbeddedChannel embedder = new EmbeddedChannel(aggr);
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
HttpHeaders.setHeader(message, "X-Test", true); HttpHeaders.setHeader(message, "X-Test", true);
HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII));
HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)); HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII));
@ -77,8 +77,7 @@ public class HttpObjectAggregatorTest {
public void testAggregateWithTrailer() { public void testAggregateWithTrailer() {
HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024); HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024);
EmbeddedChannel embedder = new EmbeddedChannel(aggr); EmbeddedChannel embedder = new EmbeddedChannel(aggr);
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
HttpHeaders.setHeader(message, "X-Test", true); HttpHeaders.setHeader(message, "X-Test", true);
HttpHeaders.setTransferEncodingChunked(message); HttpHeaders.setTransferEncodingChunked(message);
HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII));
@ -105,11 +104,9 @@ public class HttpObjectAggregatorTest {
} }
@Test @Test
public void testOversizedMessage() { public void testOversizedRequest() {
HttpObjectAggregator aggr = new HttpObjectAggregator(4); EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(4));
EmbeddedChannel embedder = new EmbeddedChannel(aggr); HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1,
HttpMethod.GET, "http://localhost");
HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)));
HttpContent chunk2 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII))); HttpContent chunk2 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)));
HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT; HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT;
@ -120,6 +117,7 @@ public class HttpObjectAggregatorTest {
FullHttpResponse response = embedder.readOutbound(); FullHttpResponse response = embedder.readOutbound();
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus());
assertEquals("0", response.headers().get(Names.CONTENT_LENGTH));
assertFalse(embedder.isOpen()); assertFalse(embedder.isOpen());
try { try {
@ -133,32 +131,26 @@ public class HttpObjectAggregatorTest {
} }
@Test @Test
public void testOversizedMessageWithoutKeepAlive() { public void testOversizedRequestWithoutKeepAlive() {
// send a HTTP/1.0 request with no keep-alive header // send a HTTP/1.0 request with no keep-alive header
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_0, HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
HttpHeaders.setContentLength(message, 5); HttpHeaders.setContentLength(message, 5);
checkOversizedMessage(message); checkOversizedMessage(message);
} }
@Test @Test
public void testOversizedMessageWithContentLength() { public void testOversizedRequestWithContentLength() {
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
HttpHeaders.setContentLength(message, 5); HttpHeaders.setContentLength(message, 5);
checkOversizedMessage(message); checkOversizedMessage(message);
} }
@Test @Test
public void testOversizedMessageWith100Continue() { public void testOversizedMessageWith100Continue() {
HttpObjectAggregator aggr = new HttpObjectAggregator(8); EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(8));
EmbeddedChannel embedder = new EmbeddedChannel(aggr);
// send an oversized request with 100 continue // send an oversized request with 100 continue
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
HttpHeaders.set100ContinueExpected(message); HttpHeaders.set100ContinueExpected(message);
HttpHeaders.setContentLength(message, 16); HttpHeaders.setContentLength(message, 16);
@ -166,43 +158,46 @@ public class HttpObjectAggregatorTest {
HttpContent chunk2 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); HttpContent chunk2 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)));
HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT; HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT;
// Send a request with 100-continue + large Content-Length header value.
assertFalse(embedder.writeInbound(message)); assertFalse(embedder.writeInbound(message));
try { // The agregator should respond with '413 Request Entity Too Large.'
embedder.writeInbound(chunk1); FullHttpResponse response = embedder.readOutbound();
fail();
} catch (AssertionError e) {
assertTrue(embedder.isWritable());
}
HttpResponse response = embedder.readOutbound();
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus());
assertEquals("0", response.headers().get(Names.CONTENT_LENGTH));
// An ill-behaving client could continue to send data without a respect, and such data should be discarded.
assertFalse(embedder.writeInbound(chunk1));
// The aggregator should not close the connection because keep-alive is on.
assertTrue(embedder.isOpen()); assertTrue(embedder.isOpen());
// now send a valid request // Now send a valid request.
HttpRequest message2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpRequest message2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
assertFalse(embedder.writeInbound(message2)); assertFalse(embedder.writeInbound(message2));
assertFalse(embedder.writeInbound(chunk2)); assertFalse(embedder.writeInbound(chunk2));
assertTrue(embedder.writeInbound(chunk3)); assertTrue(embedder.writeInbound(chunk3));
FullHttpRequest aggratedMessage = embedder.readInbound(); FullHttpRequest fullMsg = embedder.readInbound();
assertNotNull(aggratedMessage); assertNotNull(fullMsg);
assertEquals(chunk2.content().readableBytes() + chunk3.content().readableBytes(), assertEquals(
HttpHeaders.getContentLength(aggratedMessage)); chunk2.content().readableBytes() + chunk3.content().readableBytes(),
HttpHeaders.getContentLength(fullMsg));
assertEquals(HttpHeaders.getContentLength(fullMsg), fullMsg.content().readableBytes());
assertFalse(embedder.finish()); assertFalse(embedder.finish());
} }
private static void checkOversizedMessage(HttpMessage message) { private static void checkOversizedMessage(HttpMessage message) {
HttpObjectAggregator aggr = new HttpObjectAggregator(4); EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(4));
EmbeddedChannel embedder = new EmbeddedChannel(aggr);
assertFalse(embedder.writeInbound(message)); assertFalse(embedder.writeInbound(message));
HttpResponse response = embedder.readOutbound(); HttpResponse response = embedder.readOutbound();
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus());
assertEquals("0", response.headers().get(Names.CONTENT_LENGTH));
assertFalse(embedder.isOpen()); assertFalse(embedder.isOpen());
HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII))); HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)));
@ -241,8 +236,7 @@ public class HttpObjectAggregatorTest {
HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024); HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024);
EmbeddedChannel embedder = new EmbeddedChannel(aggr); EmbeddedChannel embedder = new EmbeddedChannel(aggr);
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost");
HttpMethod.GET, "http://localhost");
HttpHeaders.setHeader(message, "X-Test", true); HttpHeaders.setHeader(message, "X-Test", true);
HttpHeaders.setHeader(message, "Transfer-Encoding", "Chunked"); HttpHeaders.setHeader(message, "Transfer-Encoding", "Chunked");
HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII));