Support conversion of HttpMessage and HttpContent to HTTP/2 Frames
Motivation: HttpToHttp2ConnectionHandler only converts FullHttpMessage to HTTP/2 Frames. This does not support other use cases such as adding a HttpContentCompressor to the pipeline, which writes HttpMessage and HttpContent. Additionally HttpToHttp2ConnectionHandler ignores converting and sending HTTP trailing headers, which is a bug as the HTTP/2 spec states that they should be sent. Modifications: Update HttpToHttp2ConnectionHandler to support converting HttpMessage and HttpContent to HTTP/2 Frames. Additionally, include an extra call to writeHeaders if the message includes trailing headers Result: One can now write HttpMessage and HttpContent (http chunking) down the pipeline and they will be converted to HTTP/2 Frames. If any trailing headers exist, they will be converted and sent as well.
This commit is contained in:
parent
1a5dac175e
commit
57d28dd421
@ -15,11 +15,16 @@
|
||||
|
||||
package io.netty.handler.codec.http2;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.FullHttpMessage;
|
||||
import io.netty.handler.codec.http.HttpContent;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpMessage;
|
||||
import io.netty.handler.codec.http.LastHttpContent;
|
||||
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
|
||||
/**
|
||||
* Translates HTTP/1.x object writes into HTTP/2 frames.
|
||||
@ -27,6 +32,9 @@ import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregato
|
||||
* See {@link InboundHttp2ToHttpAdapter} to get translation from HTTP/2 frames to HTTP/1.x objects.
|
||||
*/
|
||||
public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler {
|
||||
|
||||
private int currentStreamId;
|
||||
|
||||
public HttpToHttp2ConnectionHandler(boolean server, Http2FrameListener listener) {
|
||||
super(server, listener);
|
||||
}
|
||||
@ -57,45 +65,64 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler {
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles conversion of a {@link FullHttpMessage} to HTTP/2 frames.
|
||||
* Handles conversion of {@link HttpMessage} and {@link HttpContent} to HTTP/2 frames.
|
||||
*/
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
|
||||
if (msg instanceof FullHttpMessage) {
|
||||
FullHttpMessage httpMsg = (FullHttpMessage) msg;
|
||||
boolean hasData = httpMsg.content().isReadable();
|
||||
boolean httpMsgNeedRelease = true;
|
||||
SimpleChannelPromiseAggregator promiseAggregator = null;
|
||||
try {
|
||||
|
||||
if (!(msg instanceof HttpMessage || msg instanceof HttpContent)) {
|
||||
ctx.write(msg, promise);
|
||||
return;
|
||||
}
|
||||
|
||||
boolean release = true;
|
||||
SimpleChannelPromiseAggregator promiseAggregator =
|
||||
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
|
||||
try {
|
||||
Http2ConnectionEncoder encoder = encoder();
|
||||
boolean endStream = false;
|
||||
if (msg instanceof HttpMessage) {
|
||||
final HttpMessage httpMsg = (HttpMessage) msg;
|
||||
|
||||
// Provide the user the opportunity to specify the streamId
|
||||
int streamId = getStreamId(httpMsg.headers());
|
||||
currentStreamId = getStreamId(httpMsg.headers());
|
||||
|
||||
// Convert and write the headers.
|
||||
Http2Headers http2Headers = HttpUtil.toHttp2Headers(httpMsg);
|
||||
Http2ConnectionEncoder encoder = encoder();
|
||||
endStream = msg instanceof FullHttpMessage && !((FullHttpMessage) msg).content().isReadable();
|
||||
encoder.writeHeaders(ctx, currentStreamId, http2Headers, 0, endStream, promiseAggregator.newPromise());
|
||||
}
|
||||
|
||||
if (hasData) {
|
||||
promiseAggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
|
||||
encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, promiseAggregator.newPromise());
|
||||
httpMsgNeedRelease = false;
|
||||
encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, promiseAggregator.newPromise());
|
||||
promiseAggregator.doneAllocatingPromises();
|
||||
} else {
|
||||
encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise);
|
||||
if (!endStream && msg instanceof HttpContent) {
|
||||
boolean isLastContent = false;
|
||||
Http2Headers trailers = EmptyHttp2Headers.INSTANCE;
|
||||
if (msg instanceof LastHttpContent) {
|
||||
isLastContent = true;
|
||||
|
||||
// Convert any trailing headers.
|
||||
final LastHttpContent lastContent = (LastHttpContent) msg;
|
||||
trailers = HttpUtil.toHttp2Headers(lastContent.trailingHeaders());
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
if (promiseAggregator == null) {
|
||||
promise.tryFailure(t);
|
||||
} else {
|
||||
promiseAggregator.setFailure(t);
|
||||
}
|
||||
} finally {
|
||||
if (httpMsgNeedRelease) {
|
||||
httpMsg.release();
|
||||
|
||||
// Write the data
|
||||
final ByteBuf content = ((HttpContent) msg).content();
|
||||
endStream = isLastContent && trailers.isEmpty();
|
||||
release = false;
|
||||
encoder.writeData(ctx, currentStreamId, content, 0, endStream, promiseAggregator.newPromise());
|
||||
|
||||
if (!trailers.isEmpty()) {
|
||||
// Write trailing headers.
|
||||
encoder.writeHeaders(ctx, currentStreamId, trailers, 0, true, promiseAggregator.newPromise());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
|
||||
promiseAggregator.doneAllocatingPromises();
|
||||
} catch (Throwable t) {
|
||||
promiseAggregator.setFailure(t);
|
||||
} finally {
|
||||
if (release) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaderUtil;
|
||||
import io.netty.handler.codec.http.HttpHeaderValues;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpMessage;
|
||||
import io.netty.handler.codec.http.HttpMethod;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
@ -265,7 +266,7 @@ public final class HttpUtil {
|
||||
/**
|
||||
* Converts the given HTTP/1.x headers into HTTP/2 headers.
|
||||
*/
|
||||
public static Http2Headers toHttp2Headers(FullHttpMessage in) throws Exception {
|
||||
public static Http2Headers toHttp2Headers(HttpMessage in) throws Exception {
|
||||
final Http2Headers out = new DefaultHttp2Headers();
|
||||
HttpHeaders inHeaders = in.headers();
|
||||
if (in instanceof HttpRequest) {
|
||||
@ -304,6 +305,16 @@ public final class HttpUtil {
|
||||
}
|
||||
|
||||
// Add the HTTP headers which have not been consumed above
|
||||
return out.add(toHttp2Headers(inHeaders));
|
||||
}
|
||||
|
||||
public static Http2Headers toHttp2Headers(HttpHeaders inHeaders) throws Exception {
|
||||
if (inHeaders.isEmpty()) {
|
||||
return EmptyHttp2Headers.INSTANCE;
|
||||
}
|
||||
|
||||
final Http2Headers out = new DefaultHttp2Headers();
|
||||
|
||||
inHeaders.forEachEntry(new EntryVisitor() {
|
||||
@Override
|
||||
public boolean visit(Entry<CharSequence, CharSequence> entry) throws Exception {
|
||||
|
@ -45,9 +45,14 @@ import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpRequest;
|
||||
import io.netty.handler.codec.http.DefaultHttpContent;
|
||||
import io.netty.handler.codec.http.DefaultHttpRequest;
|
||||
import io.netty.handler.codec.http.DefaultLastHttpContent;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.LastHttpContent;
|
||||
import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown;
|
||||
import io.netty.util.NetUtil;
|
||||
import io.netty.util.concurrent.Future;
|
||||
@ -84,6 +89,7 @@ public class HttpToHttp2ConnectionHandlerTest {
|
||||
private Channel clientChannel;
|
||||
private CountDownLatch requestLatch;
|
||||
private CountDownLatch serverSettingsAckLatch;
|
||||
private CountDownLatch trailersLatch;
|
||||
private FrameCountDown serverFrameCountDown;
|
||||
|
||||
@Before
|
||||
@ -104,7 +110,7 @@ public class HttpToHttp2ConnectionHandlerTest {
|
||||
|
||||
@Test
|
||||
public void testJustHeadersRequest() throws Exception {
|
||||
bootstrapEnv(2, 1);
|
||||
bootstrapEnv(2, 1, 0);
|
||||
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/example");
|
||||
final HttpHeaders httpHeaders = request.headers();
|
||||
httpHeaders.setInt(HttpUtil.ExtensionHeaderNames.STREAM_ID.text(), 5);
|
||||
@ -146,7 +152,7 @@ public class HttpToHttp2ConnectionHandlerTest {
|
||||
}
|
||||
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
|
||||
any(ByteBuf.class), eq(0), eq(true));
|
||||
bootstrapEnv(3, 1);
|
||||
bootstrapEnv(3, 1, 0);
|
||||
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example",
|
||||
Unpooled.copiedBuffer(text, UTF_8));
|
||||
final HttpHeaders httpHeaders = request.headers();
|
||||
@ -175,9 +181,127 @@ public class HttpToHttp2ConnectionHandlerTest {
|
||||
assertEquals(text, receivedBuffers.get(0));
|
||||
}
|
||||
|
||||
private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount) throws Exception {
|
||||
@Test
|
||||
public void testRequestWithBodyAndTrailingHeaders() throws Exception {
|
||||
final String text = "foooooogoooo";
|
||||
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock in) throws Throwable {
|
||||
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
|
||||
return null;
|
||||
}
|
||||
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
|
||||
any(ByteBuf.class), eq(0), eq(false));
|
||||
bootstrapEnv(4, 1, 1);
|
||||
final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/example",
|
||||
Unpooled.copiedBuffer(text, UTF_8));
|
||||
final HttpHeaders httpHeaders = request.headers();
|
||||
httpHeaders.set(HttpHeaderNames.HOST, "http://your_user-name123@www.example.org:5555/example");
|
||||
httpHeaders.add("foo", "goo");
|
||||
httpHeaders.add("foo", "goo2");
|
||||
httpHeaders.add("foo2", "goo2");
|
||||
final Http2Headers http2Headers =
|
||||
new DefaultHttp2Headers().method(as("POST")).path(as("/example"))
|
||||
.authority(as("www.example.org:5555")).scheme(as("http"))
|
||||
.add(as("foo"), as("goo")).add(as("foo"), as("goo2"))
|
||||
.add(as("foo2"), as("goo2"));
|
||||
|
||||
request.trailingHeaders().add("trailing", "bar");
|
||||
|
||||
final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers().add(as("trailing"), as("bar"));
|
||||
|
||||
ChannelPromise writePromise = newPromise();
|
||||
ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise);
|
||||
|
||||
assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(writePromise.isSuccess());
|
||||
assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(writeFuture.isSuccess());
|
||||
awaitRequests();
|
||||
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0),
|
||||
anyShort(), anyBoolean(), eq(0), eq(false));
|
||||
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
|
||||
eq(false));
|
||||
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0),
|
||||
anyShort(), anyBoolean(), eq(0), eq(true));
|
||||
assertEquals(1, receivedBuffers.size());
|
||||
assertEquals(text, receivedBuffers.get(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChunkedRequestWithBodyAndTrailingHeaders() throws Exception {
|
||||
final String text = "foooooo";
|
||||
final String text2 = "goooo";
|
||||
final List<String> receivedBuffers = Collections.synchronizedList(new ArrayList<String>());
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock in) throws Throwable {
|
||||
receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8));
|
||||
return null;
|
||||
}
|
||||
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
|
||||
any(ByteBuf.class), eq(0), eq(false));
|
||||
bootstrapEnv(4, 1, 1);
|
||||
final HttpRequest request = new DefaultHttpRequest(HTTP_1_1, POST, "/example");
|
||||
final HttpHeaders httpHeaders = request.headers();
|
||||
httpHeaders.set(HttpHeaderNames.HOST, "http://your_user-name123@www.example.org:5555/example");
|
||||
httpHeaders.add(HttpHeaderNames.TRANSFER_ENCODING, "chunked");
|
||||
httpHeaders.add("foo", "goo");
|
||||
httpHeaders.add("foo", "goo2");
|
||||
httpHeaders.add("foo2", "goo2");
|
||||
final Http2Headers http2Headers =
|
||||
new DefaultHttp2Headers().method(as("POST")).path(as("/example"))
|
||||
.authority(as("www.example.org:5555")).scheme(as("http"))
|
||||
.add(as("foo"), as("goo")).add(as("foo"), as("goo2"))
|
||||
.add(as("foo2"), as("goo2"));
|
||||
|
||||
final DefaultHttpContent httpContent = new DefaultHttpContent(Unpooled.copiedBuffer(text, UTF_8));
|
||||
final LastHttpContent lastHttpContent = new DefaultLastHttpContent(Unpooled.copiedBuffer(text2, UTF_8));
|
||||
|
||||
lastHttpContent.trailingHeaders().add("trailing", "bar");
|
||||
|
||||
final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers().add(as("trailing"), as("bar"));
|
||||
|
||||
ChannelPromise writePromise = newPromise();
|
||||
ChannelFuture writeFuture = clientChannel.write(request, writePromise);
|
||||
ChannelPromise contentPromise = newPromise();
|
||||
ChannelFuture contentFuture = clientChannel.write(httpContent, contentPromise);
|
||||
ChannelPromise lastContentPromise = newPromise();
|
||||
ChannelFuture lastContentFuture = clientChannel.write(lastHttpContent, lastContentPromise);
|
||||
|
||||
clientChannel.flush();
|
||||
|
||||
assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(writePromise.isSuccess());
|
||||
assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(writeFuture.isSuccess());
|
||||
|
||||
assertTrue(contentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(contentPromise.isSuccess());
|
||||
assertTrue(contentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(contentFuture.isSuccess());
|
||||
|
||||
assertTrue(lastContentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(lastContentPromise.isSuccess());
|
||||
assertTrue(lastContentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS));
|
||||
assertTrue(lastContentFuture.isSuccess());
|
||||
|
||||
awaitRequests();
|
||||
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0),
|
||||
anyShort(), anyBoolean(), eq(0), eq(false));
|
||||
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
|
||||
eq(false));
|
||||
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0),
|
||||
anyShort(), anyBoolean(), eq(0), eq(true));
|
||||
assertEquals(1, receivedBuffers.size());
|
||||
assertEquals(text + text2, receivedBuffers.get(0));
|
||||
}
|
||||
|
||||
private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount, int trailersCount) throws Exception {
|
||||
requestLatch = new CountDownLatch(requestCountDown);
|
||||
serverSettingsAckLatch = new CountDownLatch(serverSettingsAckCount);
|
||||
trailersLatch = trailersCount == 0 ? null : new CountDownLatch(trailersCount);
|
||||
|
||||
sb = new ServerBootstrap();
|
||||
cb = new Bootstrap();
|
||||
@ -188,7 +312,8 @@ public class HttpToHttp2ConnectionHandlerTest {
|
||||
@Override
|
||||
protected void initChannel(Channel ch) throws Exception {
|
||||
ChannelPipeline p = ch.pipeline();
|
||||
serverFrameCountDown = new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch);
|
||||
serverFrameCountDown =
|
||||
new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch, null, trailersLatch);
|
||||
p.addLast(new HttpToHttp2ConnectionHandler(true, serverFrameCountDown));
|
||||
}
|
||||
});
|
||||
@ -213,6 +338,10 @@ public class HttpToHttp2ConnectionHandlerTest {
|
||||
|
||||
private void awaitRequests() throws Exception {
|
||||
assertTrue(requestLatch.await(WAIT_TIME_SECONDS, SECONDS));
|
||||
if (trailersLatch != null) {
|
||||
assertTrue(trailersLatch.await(WAIT_TIME_SECONDS, SECONDS));
|
||||
}
|
||||
assertTrue(serverSettingsAckLatch.await(WAIT_TIME_SECONDS, SECONDS));
|
||||
}
|
||||
|
||||
private ChannelHandlerContext ctx() {
|
||||
|
Loading…
Reference in New Issue
Block a user