From ed10513238c271a2c9b372b702dec4943b64f0d5 Mon Sep 17 00:00:00 2001 From: igariev Date: Wed, 14 Jan 2015 23:27:37 -0800 Subject: [PATCH] Fixed several issues with HttpContentDecoder Motivation: HttpContentDecoder had the following issues: - For chunked content, the decoder set invalid "Content-Length" header with length of the first decoded chunk. - Decoding of FullHttpRequests put both the original conent and decoded content into output. As result, using HttpObjectAggregator before the decoder lead to errors. - Requests with "Expect: 100-continue" header were not acknowleged: the decoder didn't pass the header message down the handler's chain until content is received. If client expected "100 Continue" response, deadlock happened. Modification: - Invalid "Content-Length" header is removed; handlers down the chain can either rely on LastHttpContent message or ask HttpObjectAggregator to add the header. - FullHttpRequest is split into HttpRequest and HttpContent (decoded) parts. - Header (HttpRequest) part of request is sent down the chain as soon as it's received. Result: The issues are fixed, unittest is added. --- .../codec/http/HttpContentDecoder.java | 127 +++-- .../codec/http/HttpContentDecoderTest.java | 510 ++++++++++++++++++ 2 files changed, 571 insertions(+), 66 deletions(-) create mode 100644 codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java index dce58ac15d..3d60b6b530 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java @@ -18,6 +18,7 @@ package io.netty.handler.codec.http; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.util.ReferenceCountUtil; @@ -47,8 +48,6 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder req = channel.inboundMessages(); + assertTrue(req.size() >= 1); + Object o = req.peek(); + assertThat(o, is(instanceOf(HttpRequest.class))); + HttpRequest r = (HttpRequest) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? null : Long.parseLong(v); + assertTrue(value == null || value.longValue() == HELLO_WORLD.length()); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testRequestContentLength2() { + // case 2: if HttpObjectAggregator is down the chain, then correct Content-Length header must be set + + // force content to be in more than one chunk (5 bytes/chunk) + HttpRequestDecoder decoder = new HttpRequestDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + String headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpRequest.class))); + FullHttpRequest r = (FullHttpRequest) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? null : Long.parseLong(v); + + r.release(); + assertNotNull(value); + assertEquals(HELLO_WORLD.length(), value.longValue()); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testResponseContentLength1() { + // case 1: test that ContentDecompressor either sets the correct Content-Length header + // or removes it completely (handlers down the chain must rely on LastHttpContent object) + + // force content to be in more than one chunk (5 bytes/chunk) + HttpResponseDecoder decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Queue resp = channel.inboundMessages(); + assertTrue(resp.size() >= 1); + Object o = resp.peek(); + assertThat(o, is(instanceOf(HttpResponse.class))); + HttpResponse r = (HttpResponse) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? null : Long.parseLong(v); + assertTrue(value == null || value.longValue() == HELLO_WORLD.length()); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testResponseContentLength2() { + // case 2: if HttpObjectAggregator is down the chain, then correct Content-Length header must be set + + // force content to be in more than one chunk (5 bytes/chunk) + HttpResponseDecoder decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse r = (FullHttpResponse) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? null : Long.parseLong(v); + assertNotNull(value); + assertEquals(HELLO_WORLD.length(), value.longValue()); + r.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testFullHttpRequest() { + // test that ContentDecoder can be used after the ObjectAggregator + HttpRequestDecoder decoder = new HttpRequestDecoder(4096, 4096, 5); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator, decompressor); + String headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(), GZ_HELLO_WORLD))); + + Queue req = channel.inboundMessages(); + assertTrue(req.size() > 1); + int contentLength = 0; + for (Object o : req) { + if (o instanceof HttpContent) { + assertTrue(((HttpContent) o).refCnt() > 0); + ByteBuf b = ((HttpContent) o).content(); + contentLength += b.readableBytes(); + } + } + + int readCount = 0; + byte[] receivedContent = new byte[contentLength]; + for (Object o : req) { + if (o instanceof HttpContent) { + ByteBuf b = ((HttpContent) o).content(); + int readableBytes = b.readableBytes(); + b.readBytes(receivedContent, readCount, readableBytes); + readCount += readableBytes; + } + } + + assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII)); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testFullHttpResponse() { + // test that ContentDecoder can be used after the ObjectAggregator + HttpResponseDecoder decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator, decompressor); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(), GZ_HELLO_WORLD))); + + Queue resp = channel.inboundMessages(); + assertTrue(resp.size() > 1); + int contentLength = 0; + for (Object o : resp) { + if (o instanceof HttpContent) { + assertTrue(((HttpContent) o).refCnt() > 0); + ByteBuf b = ((HttpContent) o).content(); + contentLength += b.readableBytes(); + } + } + + int readCount = 0; + byte[] receivedContent = new byte[contentLength]; + for (Object o : resp) { + if (o instanceof HttpContent) { + ByteBuf b = ((HttpContent) o).content(); + int readableBytes = b.readableBytes(); + b.readBytes(receivedContent, readCount, readableBytes); + readCount += readableBytes; + } + } + + assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII)); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + private byte[] gzDecompress(byte[] input) { + ZlibDecoder decoder = ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(input))); + assertTrue(channel.finish()); // close the channel to indicate end-of-data + + int outputSize = 0; + ByteBuf o; + List inbound = new ArrayList(); + while ((o = channel.readInbound()) != null) { + inbound.add(o); + outputSize += o.readableBytes(); + } + + byte[] output = new byte[outputSize]; + int readCount = 0; + for (ByteBuf b : inbound) { + int readableBytes = b.readableBytes(); + b.readBytes(output, readCount, readableBytes); + b.release(); + readCount += readableBytes; + } + assertTrue(channel.inboundMessages().isEmpty() && channel.outboundMessages().isEmpty()); + return output; + } + + private byte[] gzCompress(byte[] input) { + ZlibEncoder encoder = ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP); + EmbeddedChannel channel = new EmbeddedChannel(encoder); + assertTrue(channel.writeOutbound(Unpooled.wrappedBuffer(input))); + assertTrue(channel.finish()); // close the channel to indicate end-of-data + + int outputSize = 0; + ByteBuf o; + List outbound = new ArrayList(); + while ((o = channel.readOutbound()) != null) { + outbound.add(o); + outputSize += o.readableBytes(); + } + + byte[] output = new byte[outputSize]; + int readCount = 0; + for (ByteBuf b : outbound) { + int readableBytes = b.readableBytes(); + b.readBytes(output, readCount, readableBytes); + b.release(); + readCount += readableBytes; + } + assertTrue(channel.inboundMessages().isEmpty() && channel.outboundMessages().isEmpty()); + return output; + } + + private void assertHasInboundMessages(EmbeddedChannel channel, boolean hasMessages) { + Object o; + if (hasMessages) { + while (true) { + o = channel.readInbound(); + assertNotNull(o); + ReferenceCountUtil.release(o); + if (o instanceof LastHttpContent) { + break; + } + } + } else { + o = channel.readInbound(); + assertNull(o); + } + } + + private void assertHasOutboundMessages(EmbeddedChannel channel, boolean hasMessages) { + Object o; + if (hasMessages) { + while (true) { + o = channel.readOutbound(); + assertNotNull(o); + ReferenceCountUtil.release(o); + if (o instanceof LastHttpContent) { + break; + } + } + } else { + o = channel.readOutbound(); + assertNull(o); + } + } +}