Expose a callback in HttpObjectAggregator to handle oversized messages

- Related: #2211
This commit is contained in:
Chris Mowforth 2014-02-07 16:59:51 +00:00 committed by Trustin Lee
parent dbb2198839
commit 91376263d7
2 changed files with 156 additions and 32 deletions

View File

@ -24,7 +24,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.ReferenceCountUtil;
import java.util.List;
@ -48,12 +48,16 @@ import static io.netty.handler.codec.http.HttpHeaders.*;
*/
public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
public static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;
private static final FullHttpResponse CONTINUE =
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER);
private static final FullHttpResponse CONTINUE = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER);
private static final FullHttpResponse TOO_LARGE = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER);
private final int maxContentLength;
private FullHttpMessage currentMessage;
private boolean tooLongFrameFound;
private boolean is100Continue;
private boolean isKeepAlive;
private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS;
private ChannelHandlerContext ctx;
@ -64,7 +68,8 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
* @param maxContentLength
* the maximum length of the aggregated content.
* If the length of the aggregated content exceeds this value,
* a {@link TooLongFrameException} will be raised.
* {@link #messageTooLong(io.netty.channel.ChannelHandlerContext, HttpObject)}
* will be called.
*/
public HttpObjectAggregator(int maxContentLength) {
if (maxContentLength <= 0) {
@ -117,12 +122,20 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
HttpMessage m = (HttpMessage) msg;
is100Continue = is100ContinueExpected(m);
isKeepAlive = isKeepAlive(m);
// if content length is set, preemptively close if it's too large
if (isContentLengthSet(m)) {
if (getContentLength(m) > maxContentLength) {
// handle oversized message
handleMessageTooLong(ctx, m);
return;
}
}
// Handle the 'Expect: 100-continue' header if necessary.
// TODO: Respond with 413 Request Entity Too Large
// and discard the traffic or close the connection.
// No need to notify the upstream handlers - just log.
// If decoding a response, just throw an exception.
if (is100ContinueExpected(m)) {
if (is100Continue) {
ctx.writeAndFlush(CONTINUE).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
@ -171,13 +184,9 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
CompositeByteBuf content = (CompositeByteBuf) currentMessage.content();
if (content.readableBytes() > maxContentLength - chunk.content().readableBytes()) {
tooLongFrameFound = true;
// release current message to prevent leaks
currentMessage.release();
this.currentMessage = null;
throw new TooLongFrameException("HTTP content length exceeded " + maxContentLength + " bytes.");
// handle oversized message
handleMessageTooLong(ctx, currentMessage);
return;
}
// Append the content of the chunk
@ -217,6 +226,44 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
}
}
private void handleMessageTooLong(ChannelHandlerContext ctx, HttpObject msg) throws Exception {
tooLongFrameFound = true;
currentMessage = null;
messageTooLong(ctx, msg);
}
/**
* Invoked when an incoming request exceeds the maximum content length.
*
* The default behavior returns a {@code 413} HTTP response.
*
* Sub-classes may override this method to change behavior. <em>Note:</em>
* you are responsible for releasing {@code msg} if you override the
* default behavior.
*
* @param ctx the {@link ChannelHandlerContext}
* @param msg the accumulated HTTP message up to this point
*/
protected void messageTooLong(ChannelHandlerContext ctx, HttpObject msg) throws Exception {
// release current message to prevent leaks
ReferenceCountUtil.release(msg);
final ChannelHandlerContext context = ctx;
// send back a 413 and close the connection
ChannelFuture future = ctx.writeAndFlush(TOO_LARGE).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
context.fireExceptionCaught(future.cause());
}
}
});
if (!is100Continue || !isKeepAlive) {
future.addListener(ChannelFutureListener.CLOSE);
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
// release current message if it is not null as it may be a left-over

View File

@ -20,11 +20,11 @@ import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.CharsetUtil;
import org.easymock.EasyMock;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
import java.util.List;
import static io.netty.util.ReferenceCountUtil.*;
@ -105,38 +105,115 @@ public class HttpObjectAggregatorTest {
}
@Test
public void testTooLongFrameException() {
public void testOversizedMessage() {
HttpObjectAggregator aggr = new HttpObjectAggregator(4);
EmbeddedChannel embedder = new EmbeddedChannel(aggr);
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1,
HttpMethod.GET, "http://localhost");
HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)));
HttpContent chunk2 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)));
HttpContent chunk3 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test3", CharsetUtil.US_ASCII)));
HttpContent chunk4 = LastHttpContent.EMPTY_LAST_CONTENT;
HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT;
assertFalse(embedder.writeInbound(message));
assertFalse(embedder.writeInbound(chunk1.copy()));
assertFalse(embedder.writeInbound(chunk2.copy()));
FullHttpResponse response = embedder.readOutbound();
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus());
assertFalse(embedder.isOpen());
try {
embedder.writeInbound(chunk2.copy());
assertFalse(embedder.writeInbound(chunk3.copy()));
fail();
} catch (TooLongFrameException e) {
// expected
} catch (Exception e) {
assertTrue(e instanceof ClosedChannelException);
}
assertFalse(embedder.writeInbound(chunk3.copy()));
assertFalse(embedder.writeInbound(chunk4.copy()));
assertFalse(embedder.finish());
}
@Test
public void testOversizedMessageWithoutKeepAlive() {
// send a HTTP/1.0 request with no keep-alive header
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_0,
HttpMethod.GET, "http://localhost");
HttpHeaders.setContentLength(message, 5);
checkOversizedMessage(message);
}
@Test
public void testOversizedMessageWithContentLength() {
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1,
HttpMethod.GET, "http://localhost");
HttpHeaders.setContentLength(message, 5);
checkOversizedMessage(message);
}
@Test
public void testOversizedMessageWith100Continue() {
HttpObjectAggregator aggr = new HttpObjectAggregator(8);
EmbeddedChannel embedder = new EmbeddedChannel(aggr);
// send an oversized request with 100 continue
HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1,
HttpMethod.GET, "http://localhost");
HttpHeaders.set100ContinueExpected(message);
HttpHeaders.setContentLength(message, 16);
HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("some", CharsetUtil.US_ASCII)));
HttpContent chunk2 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)));
HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT;
assertFalse(embedder.writeInbound(message));
assertFalse(embedder.writeInbound(chunk1.copy()));
try {
embedder.writeInbound(chunk2.copy());
embedder.writeInbound(chunk1);
fail();
} catch (TooLongFrameException e) {
// expected
} catch (AssertionError e) {
assertTrue(embedder.isWritable());
}
assertFalse(embedder.writeInbound(chunk3.copy()));
assertFalse(embedder.writeInbound(chunk4.copy()));
embedder.finish();
HttpResponse response = embedder.readOutbound();
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus());
assertTrue(embedder.isOpen());
// now send a valid request
HttpRequest message2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1,
HttpMethod.GET, "http://localhost");
assertFalse(embedder.writeInbound(message2));
assertFalse(embedder.writeInbound(chunk2));
assertTrue(embedder.writeInbound(chunk3));
FullHttpRequest aggratedMessage = embedder.readInbound();
assertNotNull(aggratedMessage);
assertEquals(chunk2.content().readableBytes() + chunk3.content().readableBytes(),
HttpHeaders.getContentLength(aggratedMessage));
assertFalse(embedder.finish());
}
private static void checkOversizedMessage(HttpMessage message) {
HttpObjectAggregator aggr = new HttpObjectAggregator(4);
EmbeddedChannel embedder = new EmbeddedChannel(aggr);
assertFalse(embedder.writeInbound(message));
HttpResponse response = embedder.readOutbound();
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.getStatus());
assertFalse(embedder.isOpen());
HttpContent chunk1 = releaseLater(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)));
try {
embedder.writeInbound(chunk1.copy());
fail();
} catch (Exception e) {
assertTrue(e instanceof ClosedChannelException);
}
assertFalse(embedder.finish());
}
@Test(expected = IllegalArgumentException.class)