diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java index ca3676617d..6bb45045dc 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java @@ -15,11 +15,16 @@ package io.netty.handler.codec.http2; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; +import io.netty.util.ReferenceCountUtil; /** * Translates HTTP/1.x object writes into HTTP/2 frames. @@ -27,6 +32,9 @@ import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregato * See {@link InboundHttp2ToHttpAdapter} to get translation from HTTP/2 frames to HTTP/1.x objects. */ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler { + + private int currentStreamId; + public HttpToHttp2ConnectionHandler(boolean server, Http2FrameListener listener) { super(server, listener); } @@ -57,45 +65,64 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler { } /** - * Handles conversion of a {@link FullHttpMessage} to HTTP/2 frames. + * Handles conversion of {@link HttpMessage} and {@link HttpContent} to HTTP/2 frames. */ @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { - if (msg instanceof FullHttpMessage) { - FullHttpMessage httpMsg = (FullHttpMessage) msg; - boolean hasData = httpMsg.content().isReadable(); - boolean httpMsgNeedRelease = true; - SimpleChannelPromiseAggregator promiseAggregator = null; - try { + + if (!(msg instanceof HttpMessage || msg instanceof HttpContent)) { + ctx.write(msg, promise); + return; + } + + boolean release = true; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + try { + Http2ConnectionEncoder encoder = encoder(); + boolean endStream = false; + if (msg instanceof HttpMessage) { + final HttpMessage httpMsg = (HttpMessage) msg; + // Provide the user the opportunity to specify the streamId - int streamId = getStreamId(httpMsg.headers()); + currentStreamId = getStreamId(httpMsg.headers()); // Convert and write the headers. Http2Headers http2Headers = HttpUtil.toHttp2Headers(httpMsg); - Http2ConnectionEncoder encoder = encoder(); + endStream = msg instanceof FullHttpMessage && !((FullHttpMessage) msg).content().isReadable(); + encoder.writeHeaders(ctx, currentStreamId, http2Headers, 0, endStream, promiseAggregator.newPromise()); + } - if (hasData) { - promiseAggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); - encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, promiseAggregator.newPromise()); - httpMsgNeedRelease = false; - encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, promiseAggregator.newPromise()); - promiseAggregator.doneAllocatingPromises(); - } else { - encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise); + if (!endStream && msg instanceof HttpContent) { + boolean isLastContent = false; + Http2Headers trailers = EmptyHttp2Headers.INSTANCE; + if (msg instanceof LastHttpContent) { + isLastContent = true; + + // Convert any trailing headers. + final LastHttpContent lastContent = (LastHttpContent) msg; + trailers = HttpUtil.toHttp2Headers(lastContent.trailingHeaders()); } - } catch (Throwable t) { - if (promiseAggregator == null) { - promise.tryFailure(t); - } else { - promiseAggregator.setFailure(t); - } - } finally { - if (httpMsgNeedRelease) { - httpMsg.release(); + + // Write the data + final ByteBuf content = ((HttpContent) msg).content(); + endStream = isLastContent && trailers.isEmpty(); + release = false; + encoder.writeData(ctx, currentStreamId, content, 0, endStream, promiseAggregator.newPromise()); + + if (!trailers.isEmpty()) { + // Write trailing headers. + encoder.writeHeaders(ctx, currentStreamId, trailers, 0, true, promiseAggregator.newPromise()); } } - } else { - ctx.write(msg, promise); + + promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + promiseAggregator.setFailure(t); + } finally { + if (release) { + ReferenceCountUtil.release(msg); + } } } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpUtil.java index b312c59bfc..da65fcc460 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpUtil.java @@ -29,6 +29,7 @@ import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderUtil; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMessage; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; @@ -265,7 +266,7 @@ public final class HttpUtil { /** * Converts the given HTTP/1.x headers into HTTP/2 headers. */ - public static Http2Headers toHttp2Headers(FullHttpMessage in) throws Exception { + public static Http2Headers toHttp2Headers(HttpMessage in) throws Exception { final Http2Headers out = new DefaultHttp2Headers(); HttpHeaders inHeaders = in.headers(); if (in instanceof HttpRequest) { @@ -304,6 +305,16 @@ public final class HttpUtil { } // Add the HTTP headers which have not been consumed above + return out.add(toHttp2Headers(inHeaders)); + } + + public static Http2Headers toHttp2Headers(HttpHeaders inHeaders) throws Exception { + if (inHeaders.isEmpty()) { + return EmptyHttp2Headers.INSTANCE; + } + + final Http2Headers out = new DefaultHttp2Headers(); + inHeaders.forEachEntry(new EntryVisitor() { @Override public boolean visit(Entry entry) throws Exception { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java index 04d5f7a743..9960de21e7 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java @@ -45,9 +45,14 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultLastHttpContent; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown; import io.netty.util.NetUtil; import io.netty.util.concurrent.Future; @@ -84,6 +89,7 @@ public class HttpToHttp2ConnectionHandlerTest { private Channel clientChannel; private CountDownLatch requestLatch; private CountDownLatch serverSettingsAckLatch; + private CountDownLatch trailersLatch; private FrameCountDown serverFrameCountDown; @Before @@ -104,7 +110,7 @@ public class HttpToHttp2ConnectionHandlerTest { @Test public void testJustHeadersRequest() throws Exception { - bootstrapEnv(2, 1); + bootstrapEnv(2, 1, 0); final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/example"); final HttpHeaders httpHeaders = request.headers(); httpHeaders.setInt(HttpUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); @@ -146,7 +152,7 @@ public class HttpToHttp2ConnectionHandlerTest { } }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0), eq(true)); - bootstrapEnv(3, 1); + bootstrapEnv(3, 1, 0); final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example", Unpooled.copiedBuffer(text, UTF_8)); final HttpHeaders httpHeaders = request.headers(); @@ -175,9 +181,127 @@ public class HttpToHttp2ConnectionHandlerTest { assertEquals(text, receivedBuffers.get(0)); } - private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount) throws Exception { + @Test + public void testRequestWithBodyAndTrailingHeaders() throws Exception { + final String text = "foooooogoooo"; + final List receivedBuffers = Collections.synchronizedList(new ArrayList()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8)); + return null; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), eq(false)); + bootstrapEnv(4, 1, 1); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example", + Unpooled.copiedBuffer(text, UTF_8)); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpHeaderNames.HOST, "http://your_user-name123@www.example.org:5555/example"); + httpHeaders.add("foo", "goo"); + httpHeaders.add("foo", "goo2"); + httpHeaders.add("foo2", "goo2"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(as("POST")).path(as("/example")) + .authority(as("www.example.org:5555")).scheme(as("http")) + .add(as("foo"), as("goo")).add(as("foo"), as("goo2")) + .add(as("foo2"), as("goo2")); + + request.trailingHeaders().add("trailing", "bar"); + + final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers().add(as("trailing"), as("bar")); + + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise); + + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isSuccess()); + assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writeFuture.isSuccess()); + awaitRequests(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0), + anyShort(), anyBoolean(), eq(0), eq(false)); + verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0), + eq(false)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0), + anyShort(), anyBoolean(), eq(0), eq(true)); + assertEquals(1, receivedBuffers.size()); + assertEquals(text, receivedBuffers.get(0)); + } + + @Test + public void testChunkedRequestWithBodyAndTrailingHeaders() throws Exception { + final String text = "foooooo"; + final String text2 = "goooo"; + final List receivedBuffers = Collections.synchronizedList(new ArrayList()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8)); + return null; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), eq(false)); + bootstrapEnv(4, 1, 1); + final HttpRequest request = new DefaultHttpRequest(HTTP_1_1, POST, "/example"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpHeaderNames.HOST, "http://your_user-name123@www.example.org:5555/example"); + httpHeaders.add(HttpHeaderNames.TRANSFER_ENCODING, "chunked"); + httpHeaders.add("foo", "goo"); + httpHeaders.add("foo", "goo2"); + httpHeaders.add("foo2", "goo2"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(as("POST")).path(as("/example")) + .authority(as("www.example.org:5555")).scheme(as("http")) + .add(as("foo"), as("goo")).add(as("foo"), as("goo2")) + .add(as("foo2"), as("goo2")); + + final DefaultHttpContent httpContent = new DefaultHttpContent(Unpooled.copiedBuffer(text, UTF_8)); + final LastHttpContent lastHttpContent = new DefaultLastHttpContent(Unpooled.copiedBuffer(text2, UTF_8)); + + lastHttpContent.trailingHeaders().add("trailing", "bar"); + + final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers().add(as("trailing"), as("bar")); + + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.write(request, writePromise); + ChannelPromise contentPromise = newPromise(); + ChannelFuture contentFuture = clientChannel.write(httpContent, contentPromise); + ChannelPromise lastContentPromise = newPromise(); + ChannelFuture lastContentFuture = clientChannel.write(lastHttpContent, lastContentPromise); + + clientChannel.flush(); + + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isSuccess()); + assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writeFuture.isSuccess()); + + assertTrue(contentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(contentPromise.isSuccess()); + assertTrue(contentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(contentFuture.isSuccess()); + + assertTrue(lastContentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(lastContentPromise.isSuccess()); + assertTrue(lastContentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(lastContentFuture.isSuccess()); + + awaitRequests(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0), + anyShort(), anyBoolean(), eq(0), eq(false)); + verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0), + eq(false)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0), + anyShort(), anyBoolean(), eq(0), eq(true)); + assertEquals(1, receivedBuffers.size()); + assertEquals(text + text2, receivedBuffers.get(0)); + } + + private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount, int trailersCount) throws Exception { requestLatch = new CountDownLatch(requestCountDown); serverSettingsAckLatch = new CountDownLatch(serverSettingsAckCount); + trailersLatch = trailersCount == 0 ? null : new CountDownLatch(trailersCount); sb = new ServerBootstrap(); cb = new Bootstrap(); @@ -188,7 +312,8 @@ public class HttpToHttp2ConnectionHandlerTest { @Override protected void initChannel(Channel ch) throws Exception { ChannelPipeline p = ch.pipeline(); - serverFrameCountDown = new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch); + serverFrameCountDown = + new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch, null, trailersLatch); p.addLast(new HttpToHttp2ConnectionHandler(true, serverFrameCountDown)); } }); @@ -213,6 +338,10 @@ public class HttpToHttp2ConnectionHandlerTest { private void awaitRequests() throws Exception { assertTrue(requestLatch.await(WAIT_TIME_SECONDS, SECONDS)); + if (trailersLatch != null) { + assertTrue(trailersLatch.await(WAIT_TIME_SECONDS, SECONDS)); + } + assertTrue(serverSettingsAckLatch.await(WAIT_TIME_SECONDS, SECONDS)); } private ChannelHandlerContext ctx() {