Support conversion of HttpMessage and HttpContent to HTTP/2 Frames

Motivation:

HttpToHttp2ConnectionHandler only converts FullHttpMessage to HTTP/2 Frames. This does not support other use cases such as adding a HttpContentCompressor to the pipeline, which writes HttpMessage and HttpContent.

Additionally HttpToHttp2ConnectionHandler ignores converting and sending HTTP trailing headers, which is a bug as the HTTP/2 spec states that they should be sent.

Modifications:

Update HttpToHttp2ConnectionHandler to support converting HttpMessage and HttpContent to HTTP/2 Frames.
Additionally, include an extra call to writeHeaders if the message includes trailing headers

Result:

One can now write HttpMessage and HttpContent (http chunking) down the pipeline and they will be converted to HTTP/2 Frames.  If any trailing headers exist, they will be converted and sent as well.
This commit is contained in:
Brendt Lucas 2015-07-11 12:09:45 +01:00 committed by Norman Maurer
parent 22cea8a2cc
commit 3ec02bf746
3 changed files with 200 additions and 33 deletions

View File

@ -15,11 +15,16 @@
package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.util.ReferenceCountUtil;
/**
* Translates HTTP/1.x object writes into HTTP/2 frames.
@ -27,6 +32,9 @@ import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregato
* See {@link InboundHttp2ToHttpAdapter} to get translation from HTTP/2 frames to HTTP/1.x objects.
*/
public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler {
private int currentStreamId;
public HttpToHttp2ConnectionHandler(boolean server, Http2FrameListener listener) {
super(server, listener);
}
@ -57,45 +65,64 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler {
}
/**
* Handles conversion of a {@link FullHttpMessage} to HTTP/2 frames.
* Handles conversion of {@link HttpMessage} and {@link HttpContent} to HTTP/2 frames.
*/
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof FullHttpMessage) {
FullHttpMessage httpMsg = (FullHttpMessage) msg;
boolean hasData = httpMsg.content().isReadable();
boolean httpMsgNeedRelease = true;
SimpleChannelPromiseAggregator promiseAggregator = null;
if (!(msg instanceof HttpMessage || msg instanceof HttpContent)) {
ctx.write(msg, promise);
return;
}
boolean release = true;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try {
Http2ConnectionEncoder encoder = encoder();
boolean endStream = false;
if (msg instanceof HttpMessage) {
final HttpMessage httpMsg = (HttpMessage) msg;
// Provide the user the opportunity to specify the streamId
int streamId = getStreamId(httpMsg.headers());
currentStreamId = getStreamId(httpMsg.headers());
// Convert and write the headers.
Http2Headers http2Headers = HttpUtil.toHttp2Headers(httpMsg);
Http2ConnectionEncoder encoder = encoder();
endStream = msg instanceof FullHttpMessage && !((FullHttpMessage) msg).content().isReadable();
encoder.writeHeaders(ctx, currentStreamId, http2Headers, 0, endStream, promiseAggregator.newPromise());
}
if (!endStream && msg instanceof HttpContent) {
boolean isLastContent = false;
Http2Headers trailers = EmptyHttp2Headers.INSTANCE;
if (msg instanceof LastHttpContent) {
isLastContent = true;
// Convert any trailing headers.
final LastHttpContent lastContent = (LastHttpContent) msg;
trailers = HttpUtil.toHttp2Headers(lastContent.trailingHeaders());
}
// Write the data
final ByteBuf content = ((HttpContent) msg).content();
endStream = isLastContent && trailers.isEmpty();
release = false;
encoder.writeData(ctx, currentStreamId, content, 0, endStream, promiseAggregator.newPromise());
if (!trailers.isEmpty()) {
// Write trailing headers.
encoder.writeHeaders(ctx, currentStreamId, trailers, 0, true, promiseAggregator.newPromise());
}
}
if (hasData) {
promiseAggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, promiseAggregator.newPromise());
httpMsgNeedRelease = false;
encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, promiseAggregator.newPromise());
promiseAggregator.doneAllocatingPromises();
} else {
encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise);
}
} catch (Throwable t) {
if (promiseAggregator == null) {
promise.tryFailure(t);
} else {
promiseAggregator.setFailure(t);
}
} finally {
if (httpMsgNeedRelease) {
httpMsg.release();
}
}
} else {
ctx.write(msg, promise);
if (release) {
ReferenceCountUtil.release(msg);
}
}
}
}

View File

@ -29,6 +29,7 @@ import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderUtil;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
@ -265,7 +266,7 @@ public final class HttpUtil {
/**
* Converts the given HTTP/1.x headers into HTTP/2 headers.
*/
public static Http2Headers toHttp2Headers(FullHttpMessage in) throws Exception {
public static Http2Headers toHttp2Headers(HttpMessage in) throws Exception {
final Http2Headers out = new DefaultHttp2Headers();
HttpHeaders inHeaders = in.headers();
if (in instanceof HttpRequest) {
@ -304,6 +305,16 @@ public final class HttpUtil {
}
// Add the HTTP headers which have not been consumed above
return out.add(toHttp2Headers(inHeaders));
}
public static Http2Headers toHttp2Headers(HttpHeaders inHeaders) throws Exception {
if (inHeaders.isEmpty()) {
return EmptyHttp2Headers.INSTANCE;
}
final Http2Headers out = new DefaultHttp2Headers();
inHeaders.forEachEntry(new EntryVisitor() {
@Override
public boolean visit(Entry<CharSequence, CharSequence> entry) throws Exception {

View File

@ -45,9 +45,14 @@ import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown;
import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future;
@ -84,6 +89,7 @@ public class HttpToHttp2ConnectionHandlerTest {
private Channel clientChannel;
private CountDownLatch requestLatch;
private CountDownLatch serverSettingsAckLatch;
private CountDownLatch trailersLatch;
private FrameCountDown serverFrameCountDown;
@Before
@ -104,7 +110,7 @@ public class HttpToHttp2ConnectionHandlerTest {
@Test
public void testJustHeadersRequest() throws Exception {
bootstrapEnv(2, 1);
bootstrapEnv(2, 1, 0);
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/example");
final HttpHeaders httpHeaders = request.headers();
httpHeaders.setInt(HttpUtil.ExtensionHeaderNames.STREAM_ID.text(), 5);
@ -146,7 +152,7 @@ public class HttpToHttp2ConnectionHandlerTest {
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), eq(true));
bootstrapEnv(3, 1);
bootstrapEnv(3, 1, 0);
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example",
Unpooled.copiedBuffer(text, UTF_8));
final HttpHeaders httpHeaders = request.headers();
@ -175,9 +181,127 @@ public class HttpToHttp2ConnectionHandlerTest {
assertEquals(text, receivedBuffers.get(0));
}
private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount) throws Exception {
@Test
public void testRequestWithBodyAndTrailingHeaders() throws Exception {
final String text = "foooooogoooo";
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), eq(false));
bootstrapEnv(4, 1, 1);
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example",
Unpooled.copiedBuffer(text, UTF_8));
final HttpHeaders httpHeaders = request.headers();
httpHeaders.set(HttpHeaderNames.HOST, "http://your_user-name123@www.example.org:5555/example");
httpHeaders.add("foo", "goo");
httpHeaders.add("foo", "goo2");
httpHeaders.add("foo2", "goo2");
final Http2Headers http2Headers =
new DefaultHttp2Headers().method(as("POST")).path(as("/example"))
.authority(as("www.example.org:5555")).scheme(as("http"))
.add(as("foo"), as("goo")).add(as("foo"), as("goo2"))
.add(as("foo2"), as("goo2"));
request.trailingHeaders().add("trailing", "bar");
final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers().add(as("trailing"), as("bar"));
ChannelPromise writePromise = newPromise();
ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise);
assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(writePromise.isSuccess());
assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(writeFuture.isSuccess());
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0),
anyShort(), anyBoolean(), eq(0), eq(false));
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
eq(false));
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0),
anyShort(), anyBoolean(), eq(0), eq(true));
assertEquals(1, receivedBuffers.size());
assertEquals(text, receivedBuffers.get(0));
}
@Test
public void testChunkedRequestWithBodyAndTrailingHeaders() throws Exception {
final String text = "foooooo";
final String text2 = "goooo";
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) throws Throwable {
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), eq(false));
bootstrapEnv(4, 1, 1);
final HttpRequest request = new DefaultHttpRequest(HTTP_1_1, POST, "/example");
final HttpHeaders httpHeaders = request.headers();
httpHeaders.set(HttpHeaderNames.HOST, "http://your_user-name123@www.example.org:5555/example");
httpHeaders.add(HttpHeaderNames.TRANSFER_ENCODING, "chunked");
httpHeaders.add("foo", "goo");
httpHeaders.add("foo", "goo2");
httpHeaders.add("foo2", "goo2");
final Http2Headers http2Headers =
new DefaultHttp2Headers().method(as("POST")).path(as("/example"))
.authority(as("www.example.org:5555")).scheme(as("http"))
.add(as("foo"), as("goo")).add(as("foo"), as("goo2"))
.add(as("foo2"), as("goo2"));
final DefaultHttpContent httpContent = new DefaultHttpContent(Unpooled.copiedBuffer(text, UTF_8));
final LastHttpContent lastHttpContent = new DefaultLastHttpContent(Unpooled.copiedBuffer(text2, UTF_8));
lastHttpContent.trailingHeaders().add("trailing", "bar");
final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers().add(as("trailing"), as("bar"));
ChannelPromise writePromise = newPromise();
ChannelFuture writeFuture = clientChannel.write(request, writePromise);
ChannelPromise contentPromise = newPromise();
ChannelFuture contentFuture = clientChannel.write(httpContent, contentPromise);
ChannelPromise lastContentPromise = newPromise();
ChannelFuture lastContentFuture = clientChannel.write(lastHttpContent, lastContentPromise);
clientChannel.flush();
assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(writePromise.isSuccess());
assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(writeFuture.isSuccess());
assertTrue(contentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(contentPromise.isSuccess());
assertTrue(contentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(contentFuture.isSuccess());
assertTrue(lastContentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(lastContentPromise.isSuccess());
assertTrue(lastContentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
assertTrue(lastContentFuture.isSuccess());
awaitRequests();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0),
anyShort(), anyBoolean(), eq(0), eq(false));
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
eq(false));
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0),
anyShort(), anyBoolean(), eq(0), eq(true));
assertEquals(1, receivedBuffers.size());
assertEquals(text + text2, receivedBuffers.get(0));
}
private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount, int trailersCount) throws Exception {
requestLatch = new CountDownLatch(requestCountDown);
serverSettingsAckLatch = new CountDownLatch(serverSettingsAckCount);
trailersLatch = trailersCount == 0 ? null : new CountDownLatch(trailersCount);
sb = new ServerBootstrap();
cb = new Bootstrap();
@ -188,7 +312,8 @@ public class HttpToHttp2ConnectionHandlerTest {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverFrameCountDown = new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch);
serverFrameCountDown =
new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch, null, trailersLatch);
p.addLast(new HttpToHttp2ConnectionHandler(true, serverFrameCountDown));
}
});
@ -213,6 +338,10 @@ public class HttpToHttp2ConnectionHandlerTest {
private void awaitRequests() throws Exception {
assertTrue(requestLatch.await(WAIT_TIME_SECONDS, SECONDS));
if (trailersLatch != null) {
assertTrue(trailersLatch.await(WAIT_TIME_SECONDS, SECONDS));
}
assertTrue(serverSettingsAckLatch.await(WAIT_TIME_SECONDS, SECONDS));
}
private ChannelHandlerContext ctx() {