HttpContentDecompressor should change decompressed requests to chunked encoding. Fixes issue #5428

`HttpContentDecoder` was removing `Content-Length` header but not adding a `Transfer-Encoding` header which goes against the HTTP spec.

Added `Transfer-Encoding` header with value `chunked` when `Content-Length` is removed.
Modified existing unit test to also check for this condition.

Compliance with HTTP spec.
This commit is contained in:
Nitesh Kant 2016-06-19 20:37:07 -07:00 committed by Norman Maurer
parent 9602535b7d
commit ee0897a1d9
2 changed files with 40 additions and 41 deletions

View File

@ -99,6 +99,7 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObj
// If buffering is not an issue, add HttpObjectAggregator down the chain, it will set the header.
// Otherwise, rely on LastHttpContent message.
headers.remove(HttpHeaderNames.CONTENT_LENGTH);
headers.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
// set new content encoding,
CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding);

View File

@ -299,9 +299,12 @@ public class HttpContentDecoderTest {
Object o = resp.peek();
assertThat(o, is(instanceOf(HttpResponse.class)));
HttpResponse r = (HttpResponse) o;
String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH);
Long value = v == null ? null : Long.parseLong(v);
assertTrue(value == null || value.longValue() == HELLO_WORLD.length());
assertFalse("Content-Length header not removed.", r.headers().contains(HttpHeaderNames.CONTENT_LENGTH));
String transferEncoding = r.headers().get(HttpHeaderNames.TRANSFER_ENCODING);
assertNotNull("Content-length as well as transfer-encoding not set.", transferEncoding);
assertEquals("Unexpected transfer-encoding value.", HttpHeaderValues.CHUNKED.toString(), transferEncoding);
assertHasInboundMessages(channel, true);
assertHasOutboundMessages(channel, false);
@ -354,24 +357,9 @@ public class HttpContentDecoderTest {
Queue<Object> req = channel.inboundMessages();
assertTrue(req.size() > 1);
int contentLength = 0;
for (Object o : req) {
if (o instanceof HttpContent) {
assertTrue(((HttpContent) o).refCnt() > 0);
ByteBuf b = ((HttpContent) o).content();
contentLength += b.readableBytes();
}
}
contentLength = calculateContentLength(req, contentLength);
int readCount = 0;
byte[] receivedContent = new byte[contentLength];
for (Object o : req) {
if (o instanceof HttpContent) {
ByteBuf b = ((HttpContent) o).content();
int readableBytes = b.readableBytes();
b.readBytes(receivedContent, readCount, readableBytes);
readCount += readableBytes;
}
}
byte[] receivedContent = readContent(req, contentLength);
assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII));
@ -396,24 +384,9 @@ public class HttpContentDecoderTest {
Queue<Object> resp = channel.inboundMessages();
assertTrue(resp.size() > 1);
int contentLength = 0;
for (Object o : resp) {
if (o instanceof HttpContent) {
assertTrue(((HttpContent) o).refCnt() > 0);
ByteBuf b = ((HttpContent) o).content();
contentLength += b.readableBytes();
}
}
contentLength = calculateContentLength(resp, contentLength);
int readCount = 0;
byte[] receivedContent = new byte[contentLength];
for (Object o : resp) {
if (o instanceof HttpContent) {
ByteBuf b = ((HttpContent) o).content();
int readableBytes = b.readableBytes();
b.readBytes(receivedContent, readCount, readableBytes);
readCount += readableBytes;
}
}
byte[] receivedContent = readContent(resp, contentLength);
assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII));
@ -422,7 +395,7 @@ public class HttpContentDecoderTest {
assertFalse(channel.finish());
}
private byte[] gzDecompress(byte[] input) {
private static byte[] gzDecompress(byte[] input) {
ZlibDecoder decoder = ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP);
EmbeddedChannel channel = new EmbeddedChannel(decoder);
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(input)));
@ -448,7 +421,32 @@ public class HttpContentDecoderTest {
return output;
}
private byte[] gzCompress(byte[] input) {
private static byte[] readContent(Queue<Object> req, int contentLength) {
byte[] receivedContent = new byte[contentLength];
int readCount = 0;
for (Object o : req) {
if (o instanceof HttpContent) {
ByteBuf b = ((HttpContent) o).content();
int readableBytes = b.readableBytes();
b.readBytes(receivedContent, readCount, readableBytes);
readCount += readableBytes;
}
}
return receivedContent;
}
private static int calculateContentLength(Queue<Object> req, int contentLength) {
for (Object o : req) {
if (o instanceof HttpContent) {
assertTrue(((HttpContent) o).refCnt() > 0);
ByteBuf b = ((HttpContent) o).content();
contentLength += b.readableBytes();
}
}
return contentLength;
}
private static byte[] gzCompress(byte[] input) {
ZlibEncoder encoder = ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP);
EmbeddedChannel channel = new EmbeddedChannel(encoder);
assertTrue(channel.writeOutbound(Unpooled.wrappedBuffer(input)));
@ -474,7 +472,7 @@ public class HttpContentDecoderTest {
return output;
}
private void assertHasInboundMessages(EmbeddedChannel channel, boolean hasMessages) {
private static void assertHasInboundMessages(EmbeddedChannel channel, boolean hasMessages) {
Object o;
if (hasMessages) {
while (true) {
@ -491,7 +489,7 @@ public class HttpContentDecoderTest {
}
}
private void assertHasOutboundMessages(EmbeddedChannel channel, boolean hasMessages) {
private static void assertHasOutboundMessages(EmbeddedChannel channel, boolean hasMessages) {
Object o;
if (hasMessages) {
while (true) {