Make Http2StreamFrameToHttpObjectCodec truly @Sharable (#8482)

Motivation:
The `Http2StreamFrameToHttpObjectCodec` is marked `@Sharable` but mutates
an internal `HttpScheme` field every time it is added to a pipeline.

Modifications:
Instead of storing the `HttpScheme` in the handler we store it as an
attribute on the parent channel.

Result:
Fixes #8480.
This commit is contained in:
Bryce Anderson 2018-11-09 10:23:53 -07:00 committed by Norman Maurer
parent e766469e87
commit a140e6dcad
2 changed files with 94 additions and 11 deletions

View File

@ -41,6 +41,8 @@ import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.internal.UnstableApi;
import java.util.List;
@ -57,16 +59,17 @@ import java.util.List;
@UnstableApi
@Sharable
public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec<Http2StreamFrame, HttpObject> {
private static final AttributeKey<HttpScheme> SCHEME_ATTR_KEY =
AttributeKey.valueOf(HttpScheme.class, "STREAMFRAMECODEC_SCHEME");
private final boolean isServer;
private final boolean validateHeaders;
private HttpScheme scheme;
public Http2StreamFrameToHttpObjectCodec(final boolean isServer,
final boolean validateHeaders) {
this.isServer = isServer;
this.validateHeaders = validateHeaders;
scheme = HttpScheme.HTTP;
}
public Http2StreamFrameToHttpObjectCodec(final boolean isServer) {
@ -154,7 +157,7 @@ public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec<Htt
final HttpResponse res = (HttpResponse) obj;
if (res.status().equals(HttpResponseStatus.CONTINUE)) {
if (res instanceof FullHttpResponse) {
final Http2Headers headers = toHttp2Headers(res);
final Http2Headers headers = toHttp2Headers(ctx, res);
out.add(new DefaultHttp2HeadersFrame(headers, false));
return;
} else {
@ -165,7 +168,7 @@ public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec<Htt
}
if (obj instanceof HttpMessage) {
Http2Headers headers = toHttp2Headers((HttpMessage) obj);
Http2Headers headers = toHttp2Headers(ctx, (HttpMessage) obj);
boolean noMoreFrames = false;
if (obj instanceof FullHttpMessage) {
FullHttpMessage full = (FullHttpMessage) obj;
@ -184,11 +187,11 @@ public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec<Htt
}
}
private Http2Headers toHttp2Headers(final HttpMessage msg) {
private Http2Headers toHttp2Headers(final ChannelHandlerContext ctx, final HttpMessage msg) {
if (msg instanceof HttpRequest) {
msg.headers().set(
HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(),
scheme.name());
connectionScheme(ctx));
}
return HttpConversionUtil.toHttp2Headers(msg, validateHeaders);
@ -213,17 +216,35 @@ public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec<Htt
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
// this handler is typically used on an Http2StreamChannel. at this
// this handler is typically used on an Http2StreamChannel. At this
// stage, ssl handshake should've been established. checking for the
// presence of SslHandler in the parent's channel pipeline to
// determine the HTTP scheme should suffice, even for the case where
// SniHandler is used.
scheme = isSsl(ctx) ? HttpScheme.HTTPS : HttpScheme.HTTP;
final Attribute<HttpScheme> schemeAttribute = connectionSchemeAttribute(ctx);
if (schemeAttribute.get() == null) {
final HttpScheme scheme = isSsl(ctx) ? HttpScheme.HTTPS : HttpScheme.HTTP;
schemeAttribute.set(scheme);
}
}
protected boolean isSsl(final ChannelHandlerContext ctx) {
final Channel ch = ctx.channel();
final Channel connChannel = (ch instanceof Http2StreamChannel) ? ch.parent() : ch;
final Channel connChannel = connectionChannel(ctx);
return null != connChannel.pipeline().get(SslHandler.class);
}
private static HttpScheme connectionScheme(ChannelHandlerContext ctx) {
final HttpScheme scheme = connectionSchemeAttribute(ctx).get();
return scheme == null ? HttpScheme.HTTP : scheme;
}
private static Attribute<HttpScheme> connectionSchemeAttribute(ChannelHandlerContext ctx) {
final Channel ch = connectionChannel(ctx);
return ch.attr(SCHEME_ATTR_KEY);
}
private static Channel connectionChannel(ChannelHandlerContext ctx) {
final Channel ch = ctx.channel();
return ch instanceof Http2StreamChannel ? ch.parent() : ch;
}
}

View File

@ -19,6 +19,7 @@ package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
@ -871,4 +872,65 @@ public class Http2StreamFrameToHttpObjectCodecTest {
frame.release();
}
}
@Test
public void testIsSharableBetweenChannels() throws Exception {
final Queue<Http2StreamFrame> frames = new ConcurrentLinkedQueue<Http2StreamFrame>();
final ChannelHandler sharedHandler = new Http2StreamFrameToHttpObjectCodec(false);
final SslContext ctx = SslContextBuilder.forClient().sslProvider(SslProvider.JDK).build();
EmbeddedChannel tlsCh = new EmbeddedChannel(ctx.newHandler(ByteBufAllocator.DEFAULT),
new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof Http2StreamFrame) {
frames.add((Http2StreamFrame) msg);
promise.setSuccess();
} else {
ctx.write(msg, promise);
}
}
}, sharedHandler);
EmbeddedChannel plaintextCh = new EmbeddedChannel(
new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof Http2StreamFrame) {
frames.add((Http2StreamFrame) msg);
promise.setSuccess();
} else {
ctx.write(msg, promise);
}
}
}, sharedHandler);
FullHttpRequest req = new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world");
assertTrue(tlsCh.writeOutbound(req));
assertTrue(tlsCh.finishAndReleaseAll());
Http2HeadersFrame headersFrame = (Http2HeadersFrame) frames.poll();
Http2Headers headers = headersFrame.headers();
assertThat(headers.scheme().toString(), is("https"));
assertThat(headers.method().toString(), is("GET"));
assertThat(headers.path().toString(), is("/hello/world"));
assertTrue(headersFrame.isEndStream());
assertNull(frames.poll());
// Run the plaintext channel
req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world");
assertFalse(plaintextCh.writeOutbound(req));
assertFalse(plaintextCh.finishAndReleaseAll());
headersFrame = (Http2HeadersFrame) frames.poll();
headers = headersFrame.headers();
assertThat(headers.scheme().toString(), is("http"));
assertThat(headers.method().toString(), is("GET"));
assertThat(headers.path().toString(), is("/hello/world"));
assertTrue(headersFrame.isEndStream());
assertNull(frames.poll());
}
}