Ensure the same ByteBufAllocator is used in the EmbeddedChannel when compress / decompress. Related to [#5294]

Motivation:

The user may specify to use a different allocator then the default. In this case we need to ensure it is shared when creating the EmbeddedChannel inside of a ChannelHandler

Modifications:

Use the config of the "original" Channel in the EmbeddedChannel and so share the same allocator etc.

Result:

Same type of buffers are used.
This commit is contained in:
Norman Maurer 2016-05-22 20:02:38 +02:00
parent 4c186c4c41
commit 844976a0a2
6 changed files with 75 additions and 22 deletions

View File

@ -15,6 +15,7 @@
*/ */
package io.netty.handler.codec.http; package io.netty.handler.codec.http;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.compression.ZlibCodecFactory; import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibWrapper; import io.netty.handler.codec.compression.ZlibWrapper;
@ -32,6 +33,7 @@ public class HttpContentCompressor extends HttpContentEncoder {
private final int compressionLevel; private final int compressionLevel;
private final int windowBits; private final int windowBits;
private final int memLevel; private final int memLevel;
private ChannelHandlerContext ctx;
/** /**
* Creates a new handler with the default compression level (<tt>6</tt>), * Creates a new handler with the default compression level (<tt>6</tt>),
@ -92,6 +94,11 @@ public class HttpContentCompressor extends HttpContentEncoder {
this.memLevel = memLevel; this.memLevel = memLevel;
} }
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
}
@Override @Override
protected Result beginEncode(HttpResponse headers, String acceptEncoding) throws Exception { protected Result beginEncode(HttpResponse headers, String acceptEncoding) throws Exception {
String contentEncoding = headers.headers().get(HttpHeaderNames.CONTENT_ENCODING); String contentEncoding = headers.headers().get(HttpHeaderNames.CONTENT_ENCODING);
@ -119,7 +126,8 @@ public class HttpContentCompressor extends HttpContentEncoder {
return new Result( return new Result(
targetContentEncoding, targetContentEncoding,
new EmbeddedChannel(ZlibCodecFactory.newZlibEncoder( new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibEncoder(
wrapper, compressionLevel, windowBits, memLevel))); wrapper, compressionLevel, windowBits, memLevel)));
} }

View File

@ -47,6 +47,7 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObj
static final String IDENTITY = HttpHeaderValues.IDENTITY.toString(); static final String IDENTITY = HttpHeaderValues.IDENTITY.toString();
protected ChannelHandlerContext ctx;
private EmbeddedChannel decoder; private EmbeddedChannel decoder;
private boolean continueResponse; private boolean continueResponse;
@ -199,6 +200,12 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObj
super.channelInactive(ctx); super.channelInactive(ctx);
} }
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
super.handlerAdded(ctx);
}
private void cleanup() { private void cleanup() {
if (decoder != null) { if (decoder != null) {
// Clean-up the previous decoder if not cleaned up correctly. // Clean-up the previous decoder if not cleaned up correctly.

View File

@ -19,6 +19,7 @@ import static io.netty.handler.codec.http.HttpHeaderValues.DEFLATE;
import static io.netty.handler.codec.http.HttpHeaderValues.GZIP; import static io.netty.handler.codec.http.HttpHeaderValues.GZIP;
import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE; import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE;
import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP; import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.compression.ZlibCodecFactory; import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibWrapper; import io.netty.handler.codec.compression.ZlibWrapper;
@ -53,13 +54,15 @@ public class HttpContentDecompressor extends HttpContentDecoder {
protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception {
if (GZIP.contentEqualsIgnoreCase(contentEncoding) || if (GZIP.contentEqualsIgnoreCase(contentEncoding) ||
X_GZIP.contentEqualsIgnoreCase(contentEncoding)) { X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
return new EmbeddedChannel(ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP)); return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
} }
if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) ||
X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) { X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE; final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE;
// To be strict, 'deflate' means ZLIB, but some servers were not implemented correctly. // To be strict, 'deflate' means ZLIB, but some servers were not implemented correctly.
return new EmbeddedChannel(ZlibCodecFactory.newZlibDecoder(wrapper)); return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(wrapper));
} }
// 'identity' or unsupported // 'identity' or unsupported

View File

@ -143,7 +143,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
boolean endStream, ChannelPromise promise) { boolean endStream, ChannelPromise promise) {
try { try {
// Determine if compression is required and sanitize the headers. // Determine if compression is required and sanitize the headers.
EmbeddedChannel compressor = newCompressor(headers, endStream); EmbeddedChannel compressor = newCompressor(ctx, headers, endStream);
// Write the headers and create the stream object. // Write the headers and create the stream object.
ChannelFuture future = super.writeHeaders(ctx, streamId, headers, padding, endStream, promise); ChannelFuture future = super.writeHeaders(ctx, streamId, headers, padding, endStream, promise);
@ -164,7 +164,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
final boolean endOfStream, final ChannelPromise promise) { final boolean endOfStream, final ChannelPromise promise) {
try { try {
// Determine if compression is required and sanitize the headers. // Determine if compression is required and sanitize the headers.
EmbeddedChannel compressor = newCompressor(headers, endOfStream); EmbeddedChannel compressor = newCompressor(ctx, headers, endOfStream);
// Write the headers and create the stream object. // Write the headers and create the stream object.
ChannelFuture future = super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, ChannelFuture future = super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive,
@ -184,17 +184,19 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
* Returns a new {@link EmbeddedChannel} that encodes the HTTP2 message content encoded in the specified * Returns a new {@link EmbeddedChannel} that encodes the HTTP2 message content encoded in the specified
* {@code contentEncoding}. * {@code contentEncoding}.
* *
* @param ctx the context.
* @param contentEncoding the value of the {@code content-encoding} header * @param contentEncoding the value of the {@code content-encoding} header
* @return a new {@link ByteToMessageDecoder} if the specified encoding is supported. {@code null} otherwise * @return a new {@link ByteToMessageDecoder} if the specified encoding is supported. {@code null} otherwise
* (alternatively, you can throw a {@link Http2Exception} to block unknown encoding). * (alternatively, you can throw a {@link Http2Exception} to block unknown encoding).
* @throws Http2Exception If the specified encoding is not not supported and warrants an exception * @throws Http2Exception If the specified encoding is not not supported and warrants an exception
*/ */
protected EmbeddedChannel newContentCompressor(CharSequence contentEncoding) throws Http2Exception { protected EmbeddedChannel newContentCompressor(ChannelHandlerContext ctx, CharSequence contentEncoding)
throws Http2Exception {
if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) { if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
return newCompressionChannel(ZlibWrapper.GZIP); return newCompressionChannel(ctx, ZlibWrapper.GZIP);
} }
if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) { if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
return newCompressionChannel(ZlibWrapper.ZLIB); return newCompressionChannel(ctx, ZlibWrapper.ZLIB);
} }
// 'identity' or unsupported // 'identity' or unsupported
return null; return null;
@ -214,10 +216,12 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
/** /**
* Generate a new instance of an {@link EmbeddedChannel} capable of compressing data * Generate a new instance of an {@link EmbeddedChannel} capable of compressing data
* @param ctx the context.
* @param wrapper Defines what type of encoder should be used * @param wrapper Defines what type of encoder should be used
*/ */
private EmbeddedChannel newCompressionChannel(ZlibWrapper wrapper) { private EmbeddedChannel newCompressionChannel(final ChannelHandlerContext ctx, ZlibWrapper wrapper) {
return new EmbeddedChannel(ZlibCodecFactory.newZlibEncoder(wrapper, compressionLevel, windowBits, return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibEncoder(wrapper, compressionLevel, windowBits,
memLevel)); memLevel));
} }
@ -225,12 +229,14 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
* Checks if a new compressor object is needed for the stream identified by {@code streamId}. This method will * Checks if a new compressor object is needed for the stream identified by {@code streamId}. This method will
* modify the {@code content-encoding} header contained in {@code headers}. * modify the {@code content-encoding} header contained in {@code headers}.
* *
* @param ctx the context.
* @param headers Object representing headers which are to be written * @param headers Object representing headers which are to be written
* @param endOfStream Indicates if the stream has ended * @param endOfStream Indicates if the stream has ended
* @return The channel used to compress data. * @return The channel used to compress data.
* @throws Http2Exception if any problems occur during initialization. * @throws Http2Exception if any problems occur during initialization.
*/ */
private EmbeddedChannel newCompressor(Http2Headers headers, boolean endOfStream) throws Http2Exception { private EmbeddedChannel newCompressor(ChannelHandlerContext ctx, Http2Headers headers, boolean endOfStream)
throws Http2Exception {
if (endOfStream) { if (endOfStream) {
return null; return null;
} }
@ -239,7 +245,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
if (encoding == null) { if (encoding == null) {
encoding = IDENTITY; encoding = IDENTITY;
} }
final EmbeddedChannel compressor = newContentCompressor(encoding); final EmbeddedChannel compressor = newContentCompressor(ctx, encoding);
if (compressor != null) { if (compressor != null) {
CharSequence targetContentEncoding = getTargetContentEncoding(encoding); CharSequence targetContentEncoding = getTargetContentEncoding(encoding);
if (IDENTITY.contentEqualsIgnoreCase(targetContentEncoding)) { if (IDENTITY.contentEqualsIgnoreCase(targetContentEncoding)) {

View File

@ -142,14 +142,14 @@ public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecor
@Override @Override
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
boolean endStream) throws Http2Exception { boolean endStream) throws Http2Exception {
initDecompressor(streamId, headers, endStream); initDecompressor(ctx, streamId, headers, endStream);
listener.onHeadersRead(ctx, streamId, headers, padding, endStream); listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
} }
@Override @Override
public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception { short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception {
initDecompressor(streamId, headers, endStream); initDecompressor(ctx, streamId, headers, endStream);
listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream); listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream);
} }
@ -162,14 +162,17 @@ public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecor
* (alternatively, you can throw a {@link Http2Exception} to block unknown encoding). * (alternatively, you can throw a {@link Http2Exception} to block unknown encoding).
* @throws Http2Exception If the specified encoding is not not supported and warrants an exception * @throws Http2Exception If the specified encoding is not not supported and warrants an exception
*/ */
protected EmbeddedChannel newContentDecompressor(CharSequence contentEncoding) throws Http2Exception { protected EmbeddedChannel newContentDecompressor(final ChannelHandlerContext ctx, CharSequence contentEncoding)
throws Http2Exception {
if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) { if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
return new EmbeddedChannel(ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP)); return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
} }
if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) { if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE; final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE;
// To be strict, 'deflate' means ZLIB, but some servers were not implemented correctly. // To be strict, 'deflate' means ZLIB, but some servers were not implemented correctly.
return new EmbeddedChannel(ZlibCodecFactory.newZlibDecoder(wrapper)); return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(wrapper));
} }
// 'identity' or unsupported // 'identity' or unsupported
return null; return null;
@ -192,12 +195,14 @@ public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecor
* Checks if a new decompressor object is needed for the stream identified by {@code streamId}. * Checks if a new decompressor object is needed for the stream identified by {@code streamId}.
* This method will modify the {@code content-encoding} header contained in {@code headers}. * This method will modify the {@code content-encoding} header contained in {@code headers}.
* *
* @param ctx The context
* @param streamId The identifier for the headers inside {@code headers} * @param streamId The identifier for the headers inside {@code headers}
* @param headers Object representing headers which have been read * @param headers Object representing headers which have been read
* @param endOfStream Indicates if the stream has ended * @param endOfStream Indicates if the stream has ended
* @throws Http2Exception If the {@code content-encoding} is not supported * @throws Http2Exception If the {@code content-encoding} is not supported
*/ */
private void initDecompressor(int streamId, Http2Headers headers, boolean endOfStream) throws Http2Exception { private void initDecompressor(ChannelHandlerContext ctx, int streamId, Http2Headers headers, boolean endOfStream)
throws Http2Exception {
final Http2Stream stream = connection.stream(streamId); final Http2Stream stream = connection.stream(streamId);
if (stream == null) { if (stream == null) {
return; return;
@ -210,7 +215,7 @@ public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecor
if (contentEncoding == null) { if (contentEncoding == null) {
contentEncoding = IDENTITY; contentEncoding = IDENTITY;
} }
final EmbeddedChannel channel = newContentDecompressor(contentEncoding); final EmbeddedChannel channel = newContentDecompressor(ctx, contentEncoding);
if (channel != null) { if (channel != null) {
decompressor = new Http2Decompressor(channel); decompressor = new Http2Decompressor(channel);
stream.setProperty(propertyKey, decompressor); stream.setProperty(propertyKey, decompressor);

View File

@ -132,11 +132,35 @@ public class EmbeddedChannel extends AbstractChannel {
*/ */
public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelHandler... handlers) { public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelHandler... handlers) {
super(null, channelId); super(null, channelId);
metadata = metadata(hasDisconnect);
ObjectUtil.checkNotNull(handlers, "handlers");
metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
config = new DefaultChannelConfig(this); config = new DefaultChannelConfig(this);
setup(handlers);
}
/**
* Create a new instance with the channel ID set to the given ID and the pipeline
* initialized with the specified handlers.
*
* @param channelId the {@link ChannelId} that will be used to identify this channel
* @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()}
* to {@link #close()}, {@link false} otherwise.
* @param config the {@link ChannelConfig} which will be returned by {@link #config()}.
* @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/
public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelConfig config,
final ChannelHandler... handlers) {
super(null, channelId);
metadata = metadata(hasDisconnect);
this.config = config;
setup(handlers);
}
private static ChannelMetadata metadata(boolean hasDisconnect) {
return hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
}
private void setup(final ChannelHandler... handlers) {
ObjectUtil.checkNotNull(handlers, "handlers");
ChannelPipeline p = pipeline(); ChannelPipeline p = pipeline();
p.addLast(new ChannelInitializer<Channel>() { p.addLast(new ChannelInitializer<Channel>() {
@Override @Override