diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java index 7dc0ac5a51..503c93339f 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java @@ -23,7 +23,10 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.DecoderResultProvider; import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.AsciiString; import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; + import org.junit.Test; import org.mockito.Mockito; @@ -40,6 +43,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assert.assertSame; public class HttpObjectAggregatorTest { @@ -517,4 +521,123 @@ public class HttpObjectAggregatorTest { aggregatedRep.release(); replacedRep.release(); } + + @Test + public void testSelectiveRequestAggregation() { + HttpObjectAggregator myPostAggregator = new HttpObjectAggregator(1024 * 1024) { + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + if (msg instanceof HttpRequest) { + HttpRequest request = (HttpRequest) msg; + HttpMethod method = request.method(); + + if (method.equals(HttpMethod.POST)) { + return true; + } + } + + return false; + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(myPostAggregator); + + try { + // Aggregate: POST + HttpRequest request1 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + request1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + assertTrue(channel.writeInbound(request1, content1, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting an aggregated response out + Object msg1 = channel.readInbound(); + try { + assertTrue(msg1 instanceof FullHttpRequest); + } finally { + ReferenceCountUtil.release(msg1); + } + + // Don't aggregate: non-POST + HttpRequest request2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/"); + HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + request2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + try { + assertTrue(channel.writeInbound(request2, content2, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting the same response objects out + assertSame(request2, channel.readInbound()); + assertSame(content2, channel.readInbound()); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound()); + } finally { + ReferenceCountUtil.release(request2); + ReferenceCountUtil.release(content2); + } + + assertFalse(channel.finish()); + } finally { + channel.close(); + } + } + + @Test + public void testSelectiveResponseAggregation() { + HttpObjectAggregator myTextAggregator = new HttpObjectAggregator(1024 * 1024) { + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + if (msg instanceof HttpResponse) { + HttpResponse response = (HttpResponse) msg; + HttpHeaders headers = response.headers(); + + String contentType = headers.get(HttpHeaderNames.CONTENT_TYPE); + if (AsciiString.contentEqualsIgnoreCase(contentType, HttpHeaderValues.TEXT_PLAIN)) { + return true; + } + } + + return false; + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(myTextAggregator); + + try { + // Aggregate: text/plain + HttpResponse response1 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + response1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + assertTrue(channel.writeInbound(response1, content1, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting an aggregated response out + Object msg1 = channel.readInbound(); + try { + assertTrue(msg1 instanceof FullHttpResponse); + } finally { + ReferenceCountUtil.release(msg1); + } + + // Don't aggregate: application/json + HttpResponse response2 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("{key: 'value'}", CharsetUtil.UTF_8)); + response2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + + try { + assertTrue(channel.writeInbound(response2, content2, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting the same response objects out + assertSame(response2, channel.readInbound()); + assertSame(content2, channel.readInbound()); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound()); + } finally { + ReferenceCountUtil.release(response2); + ReferenceCountUtil.release(content2); + } + + assertFalse(channel.finish()); + } finally { + channel.close(); + } + } } diff --git a/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java b/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java index 9f145f0b6e..046dc1a4d0 100644 --- a/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java +++ b/codec/src/main/java/io/netty/handler/codec/MessageAggregator.java @@ -62,6 +62,8 @@ public abstract class MessageAggregator out) throws Exception { + assert aggregating; + if (isStartMessage(msg)) { handlingOversizedMessage = false; if (currentMessage != null) { @@ -242,7 +259,7 @@ public abstract class MessageAggregator