diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java index cf4559da86..696e8c1de4 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java @@ -15,7 +15,9 @@ */ package io.netty.handler.codec.http; +import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.DefaultByteBufHolder; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -55,7 +57,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder { new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER); private final int maxContentLength; - private FullHttpMessage currentMessage; + private AggregatedFullHttpMessage currentMessage; private boolean tooLongFrameFound; private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS; @@ -112,7 +114,7 @@ public class HttpObjectAggregator extends MessageToMessageDecoder { @Override protected void decode(final ChannelHandlerContext ctx, HttpObject msg, List out) throws Exception { - FullHttpMessage currentMessage = this.currentMessage; + AggregatedFullHttpMessage currentMessage = this.currentMessage; if (msg instanceof HttpMessage) { tooLongFrameFound = false; @@ -144,19 +146,17 @@ public class HttpObjectAggregator extends MessageToMessageDecoder { } if (msg instanceof HttpRequest) { HttpRequest header = (HttpRequest) msg; - this.currentMessage = currentMessage = new DefaultFullHttpRequest(header.getProtocolVersion(), - header.getMethod(), header.getUri(), Unpooled.compositeBuffer(maxCumulationBufferComponents)); + this.currentMessage = currentMessage = new AggregatedFullHttpRequest( + header, ctx.alloc().compositeBuffer(maxCumulationBufferComponents), null); } else if (msg instanceof HttpResponse) { HttpResponse header = (HttpResponse) msg; - this.currentMessage = currentMessage = new DefaultFullHttpResponse( - header.getProtocolVersion(), header.getStatus(), - Unpooled.compositeBuffer(maxCumulationBufferComponents)); + this.currentMessage = currentMessage = new AggregatedFullHttpResponse( + header, + Unpooled.compositeBuffer(maxCumulationBufferComponents), null); } else { throw new Error(); } - currentMessage.headers().set(m.headers()); - // A streamed message - initialize the cumulative buffer, and wait for incoming chunks. removeTransferEncodingChunked(currentMessage); } else if (msg instanceof HttpContent) { @@ -207,7 +207,9 @@ public class HttpObjectAggregator extends MessageToMessageDecoder { // Merge trailing headers into the message. if (chunk instanceof LastHttpContent) { LastHttpContent trailer = (LastHttpContent) chunk; - currentMessage.headers().add(trailer.trailingHeaders()); + currentMessage.setTrailingHeaders(trailer.trailingHeaders()); + } else { + currentMessage.setTrailingHeaders(new DefaultHttpHeaders()); } // Set the 'Content-Length' header. @@ -257,19 +259,197 @@ public class HttpObjectAggregator extends MessageToMessageDecoder { FullHttpMessage fullMsg; if (msg instanceof HttpRequest) { - HttpRequest req = (HttpRequest) msg; - fullMsg = new DefaultFullHttpRequest( - req.getProtocolVersion(), req.getMethod(), req.getUri(), Unpooled.EMPTY_BUFFER, false); - fullMsg.setDecoderResult(req.getDecoderResult()); + fullMsg = new AggregatedFullHttpRequest( + (HttpRequest) msg, Unpooled.EMPTY_BUFFER, new DefaultHttpHeaders()); } else if (msg instanceof HttpResponse) { - HttpResponse res = (HttpResponse) msg; - fullMsg = new DefaultFullHttpResponse( - res.getProtocolVersion(), res.getStatus(), Unpooled.EMPTY_BUFFER, false); - fullMsg.setDecoderResult(res.getDecoderResult()); + fullMsg = new AggregatedFullHttpResponse( + (HttpResponse) msg, Unpooled.EMPTY_BUFFER, new DefaultHttpHeaders()); } else { throw new IllegalStateException(); } return fullMsg; } + + private abstract static class AggregatedFullHttpMessage extends DefaultByteBufHolder implements FullHttpMessage { + protected final HttpMessage message; + private HttpHeaders trailingHeaders; + + private AggregatedFullHttpMessage(HttpMessage message, ByteBuf content, HttpHeaders trailingHeaders) { + super(content); + this.message = message; + this.trailingHeaders = trailingHeaders; + } + @Override + public HttpHeaders trailingHeaders() { + return trailingHeaders; + } + + public void setTrailingHeaders(HttpHeaders trailingHeaders) { + this.trailingHeaders = trailingHeaders; + } + + @Override + public HttpVersion getProtocolVersion() { + return message.getProtocolVersion(); + } + + @Override + public FullHttpMessage setProtocolVersion(HttpVersion version) { + message.setProtocolVersion(version); + return this; + } + + @Override + public HttpHeaders headers() { + return message.headers(); + } + + @Override + public DecoderResult getDecoderResult() { + return message.getDecoderResult(); + } + + @Override + public void setDecoderResult(DecoderResult result) { + message.setDecoderResult(result); + } + + @Override + public FullHttpMessage retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public FullHttpMessage retain() { + super.retain(); + return this; + } + + @Override + public abstract FullHttpMessage copy(); + + @Override + public abstract FullHttpMessage duplicate(); + } + + private static final class AggregatedFullHttpRequest extends AggregatedFullHttpMessage implements FullHttpRequest { + + private AggregatedFullHttpRequest(HttpRequest request, ByteBuf content, HttpHeaders trailingHeaders) { + super(request, content, trailingHeaders); + } + + @Override + public FullHttpRequest copy() { + DefaultFullHttpRequest copy = new DefaultFullHttpRequest( + getProtocolVersion(), getMethod(), getUri(), content().copy()); + copy.headers().set(headers()); + copy.trailingHeaders().set(trailingHeaders()); + return copy; + } + + @Override + public FullHttpRequest duplicate() { + DefaultFullHttpRequest duplicate = new DefaultFullHttpRequest( + getProtocolVersion(), getMethod(), getUri(), content().duplicate()); + duplicate.headers().set(headers()); + duplicate.trailingHeaders().set(trailingHeaders()); + return duplicate; + } + + @Override + public FullHttpRequest retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public FullHttpRequest retain() { + super.retain(); + return this; + } + + @Override + public FullHttpRequest setMethod(HttpMethod method) { + ((HttpRequest) message).setMethod(method); + return this; + } + + @Override + public FullHttpRequest setUri(String uri) { + ((HttpRequest) message).setUri(uri); + return this; + } + + @Override + public HttpMethod getMethod() { + return ((HttpRequest) message).getMethod(); + } + + @Override + public String getUri() { + return ((HttpRequest) message).getUri(); + } + + @Override + public FullHttpRequest setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + } + + private static final class AggregatedFullHttpResponse extends AggregatedFullHttpMessage + implements FullHttpResponse { + private AggregatedFullHttpResponse(HttpResponse message, ByteBuf content, HttpHeaders trailingHeaders) { + super(message, content, trailingHeaders); + } + + @Override + public FullHttpResponse copy() { + DefaultFullHttpResponse copy = new DefaultFullHttpResponse( + getProtocolVersion(), getStatus(), content().copy()); + copy.headers().set(headers()); + copy.trailingHeaders().set(trailingHeaders()); + return copy; + } + + @Override + public FullHttpResponse duplicate() { + DefaultFullHttpResponse duplicate = new DefaultFullHttpResponse(getProtocolVersion(), getStatus(), + content().duplicate()); + duplicate.headers().set(headers()); + duplicate.trailingHeaders().set(trailingHeaders()); + return duplicate; + } + + @Override + public FullHttpResponse setStatus(HttpResponseStatus status) { + ((HttpResponse) message).setStatus(status); + return this; + } + + @Override + public HttpResponseStatus getStatus() { + return ((HttpResponse) message).getStatus(); + } + + @Override + public FullHttpResponse setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public FullHttpResponse retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public FullHttpResponse retain() { + super.retain(); + return this; + } + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java index 614d50d195..4e096c7392 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java @@ -51,7 +51,7 @@ public class HttpObjectAggregatorTest { // this should trigger a channelRead event so return true assertTrue(embedder.writeInbound(chunk3)); assertTrue(embedder.finish()); - DefaultFullHttpRequest aggratedMessage = (DefaultFullHttpRequest) embedder.readInbound(); + FullHttpRequest aggratedMessage = (FullHttpRequest) embedder.readInbound(); assertNotNull(aggratedMessage); assertEquals(chunk1.content().readableBytes() + chunk2.content().readableBytes(), @@ -93,13 +93,13 @@ public class HttpObjectAggregatorTest { // this should trigger a channelRead event so return true assertTrue(embedder.writeInbound(trailer)); assertTrue(embedder.finish()); - DefaultFullHttpRequest aggratedMessage = (DefaultFullHttpRequest) embedder.readInbound(); + FullHttpRequest aggratedMessage = (FullHttpRequest) embedder.readInbound(); assertNotNull(aggratedMessage); assertEquals(chunk1.content().readableBytes() + chunk2.content().readableBytes(), HttpHeaders.getContentLength(aggratedMessage)); assertEquals(aggratedMessage.headers().get("X-Test"), Boolean.TRUE.toString()); - assertEquals(aggratedMessage.headers().get("X-Trailer"), Boolean.TRUE.toString()); + assertEquals(aggratedMessage.trailingHeaders().get("X-Trailer"), Boolean.TRUE.toString()); checkContentBuffer(aggratedMessage); assertNull(embedder.readInbound()); }