Fix a bug where HttpObjectAggregator fails to send a '100 Continue' response

- Fixes #1742
This commit is contained in:
Trustin Lee 2013-12-16 21:44:44 +09:00
parent f7a3881536
commit 3444c06654
2 changed files with 59 additions and 13 deletions

View File

@ -15,22 +15,22 @@
*/ */
package io.netty.handler.codec.http; package io.netty.handler.codec.http;
import static io.netty.handler.codec.http.HttpHeaders.is100ContinueExpected;
import static io.netty.handler.codec.http.HttpHeaders.removeTransferEncodingChunked;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException; import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import java.util.List; import java.util.List;
import static io.netty.handler.codec.http.HttpHeaders.*;
/** /**
* A {@link ChannelHandler} that aggregates an {@link HttpMessage} * A {@link ChannelHandler} that aggregates an {@link HttpMessage}
* and its following {@link HttpContent}s into a single {@link HttpMessage} with * and its following {@link HttpContent}s into a single {@link HttpMessage} with
@ -49,8 +49,8 @@ import java.util.List;
*/ */
public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> { public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
public static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024; public static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;
private static final ByteBuf CONTINUE = Unpooled.unreleasableBuffer(Unpooled.copiedBuffer( private static final FullHttpResponse CONTINUE =
"HTTP/1.1 100 Continue\r\n\r\n", CharsetUtil.US_ASCII)); new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER);
private final int maxContentLength; private final int maxContentLength;
private FullHttpMessage currentMessage; private FullHttpMessage currentMessage;
@ -109,7 +109,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
} }
@Override @Override
protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out) throws Exception { protected void decode(final ChannelHandlerContext ctx, HttpObject msg, List<Object> out) throws Exception {
FullHttpMessage currentMessage = this.currentMessage; FullHttpMessage currentMessage = this.currentMessage;
if (msg instanceof HttpMessage) { if (msg instanceof HttpMessage) {
@ -124,7 +124,14 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
// No need to notify the upstream handlers - just log. // No need to notify the upstream handlers - just log.
// If decoding a response, just throw an exception. // If decoding a response, just throw an exception.
if (is100ContinueExpected(m)) { if (is100ContinueExpected(m)) {
ctx.writeAndFlush(CONTINUE.duplicate()); ctx.writeAndFlush(CONTINUE).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
ctx.fireExceptionCaught(future.cause());
}
}
});
} }
if (!m.getDecoderResult().isSuccess()) { if (!m.getDecoderResult().isSuccess()) {

View File

@ -19,9 +19,12 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import static io.netty.handler.codec.http.HttpHeaders.Names.*;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
public class HttpServerCodecTest { public class HttpServerCodecTest {
/** /**
@ -45,7 +48,7 @@ public class HttpServerCodecTest {
decoderEmbedder.finish(); decoderEmbedder.finish();
HttpMessage httpMessage = (HttpMessage) decoderEmbedder.readInbound(); HttpMessage httpMessage = (HttpMessage) decoderEmbedder.readInbound();
Assert.assertNotNull(httpMessage); assertNotNull(httpMessage);
boolean empty = true; boolean empty = true;
int totalBytesPolled = 0; int totalBytesPolled = 0;
@ -56,11 +59,47 @@ public class HttpServerCodecTest {
} }
empty = false; empty = false;
totalBytesPolled += httpChunk.content().readableBytes(); totalBytesPolled += httpChunk.content().readableBytes();
Assert.assertFalse(httpChunk instanceof LastHttpContent); assertFalse(httpChunk instanceof LastHttpContent);
httpChunk.release(); httpChunk.release();
} }
Assert.assertFalse(empty); assertFalse(empty);
Assert.assertEquals(offeredContentLength, totalBytesPolled); assertEquals(offeredContentLength, totalBytesPolled);
}
@Test
public void test100Continue() throws Exception {
EmbeddedChannel ch = new EmbeddedChannel(new HttpServerCodec(), new HttpObjectAggregator(1024));
// Send the request headers.
ch.writeInbound(Unpooled.copiedBuffer(
"PUT /upload-large HTTP/1.1\r\n" +
"Expect: 100-continue\r\n" +
"Content-Length: 1\r\n\r\n", CharsetUtil.UTF_8));
// Ensure the aggregator generates nothing.
assertThat(ch.readInbound(), is(nullValue()));
// Ensure the aggregator writes a 100 Continue response.
ByteBuf continueResponse = (ByteBuf) ch.readOutbound();
assertThat(continueResponse.toString(CharsetUtil.UTF_8), is("HTTP/1.1 100 Continue\r\n\r\n"));
// But nothing more.
assertThat(ch.readOutbound(), is(nullValue()));
// Send the content of the request.
ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 42 }));
// Ensure the aggregator generates a full request.
FullHttpRequest req = (FullHttpRequest) ch.readInbound();
assertThat(req.headers().get(CONTENT_LENGTH), is("1"));
assertThat(req.content().readableBytes(), is(1));
assertThat(req.content().readByte(), is((byte) 42));
req.release();
// But nothing more.
assertThat(ch.readInbound(), is(nullValue()));
ch.finish();
} }
private static ByteBuf prepareDataChunk(int size) { private static ByteBuf prepareDataChunk(int size) {