Fix a bug where SpdyHeaderBlockZlibDecoder fails to decompress

- Forward-port 4c35b593c1, originally written by @wgallagher
- Fixes #2077
This commit is contained in:
Trustin Lee 2014-01-13 22:40:34 +09:00
parent 999b51b026
commit e1ab46f56a
7 changed files with 342 additions and 44 deletions

View File

@ -85,7 +85,7 @@ public class SpdyFrameEncoder extends MessageToByteEncoder<SpdyFrame> {
} else if (msg instanceof SpdySynStreamFrame) {
SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg;
ByteBuf data = headerBlockEncoder.encode(ctx, spdySynStreamFrame);
ByteBuf data = headerBlockEncoder.encode(spdySynStreamFrame);
try {
byte flags = spdySynStreamFrame.isLast() ? SPDY_FLAG_FIN : 0;
if (spdySynStreamFrame.isUnidirectional()) {
@ -109,7 +109,7 @@ public class SpdyFrameEncoder extends MessageToByteEncoder<SpdyFrame> {
} else if (msg instanceof SpdySynReplyFrame) {
SpdySynReplyFrame spdySynReplyFrame = (SpdySynReplyFrame) msg;
ByteBuf data = headerBlockEncoder.encode(ctx, spdySynReplyFrame);
ByteBuf data = headerBlockEncoder.encode(spdySynReplyFrame);
try {
byte flags = spdySynReplyFrame.isLast() ? SPDY_FLAG_FIN : 0;
int headerBlockLength = data.readableBytes();
@ -184,7 +184,7 @@ public class SpdyFrameEncoder extends MessageToByteEncoder<SpdyFrame> {
} else if (msg instanceof SpdyHeadersFrame) {
SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg;
ByteBuf data = headerBlockEncoder.encode(ctx, spdyHeadersFrame);
ByteBuf data = headerBlockEncoder.encode(spdyHeadersFrame);
try {
byte flags = spdyHeadersFrame.isLast() ? SPDY_FLAG_FIN : 0;
int headerBlockLength = data.readableBytes();

View File

@ -16,7 +16,6 @@
package io.netty.handler.codec.spdy;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.internal.PlatformDependent;
abstract class SpdyHeaderBlockEncoder {
@ -33,6 +32,6 @@ abstract class SpdyHeaderBlockEncoder {
}
}
abstract ByteBuf encode(ChannelHandlerContext ctx, SpdyHeadersFrame frame) throws Exception;
abstract ByteBuf encode(SpdyHeadersFrame frame) throws Exception;
abstract void end();
}

View File

@ -19,7 +19,6 @@ import com.jcraft.jzlib.Deflater;
import com.jcraft.jzlib.JZlib;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.compression.CompressionException;
import static io.netty.handler.codec.spdy.SpdyCodecUtil.*;
@ -30,7 +29,7 @@ class SpdyHeaderBlockJZlibEncoder extends SpdyHeaderBlockRawEncoder {
private boolean finished;
public SpdyHeaderBlockJZlibEncoder(
SpdyHeaderBlockJZlibEncoder(
SpdyVersion version, int compressionLevel, int windowBits, int memLevel) {
super(version);
if (compressionLevel < 0 || compressionLevel > 9) {
@ -94,7 +93,7 @@ class SpdyHeaderBlockJZlibEncoder extends SpdyHeaderBlockRawEncoder {
}
@Override
public ByteBuf encode(ChannelHandlerContext ctx, SpdyHeadersFrame frame) throws Exception {
public ByteBuf encode(SpdyHeadersFrame frame) throws Exception {
if (frame == null) {
throw new IllegalArgumentException("frame");
}
@ -103,12 +102,12 @@ class SpdyHeaderBlockJZlibEncoder extends SpdyHeaderBlockRawEncoder {
return Unpooled.EMPTY_BUFFER;
}
ByteBuf decompressed = super.encode(ctx, frame);
ByteBuf decompressed = super.encode(frame);
if (decompressed.readableBytes() == 0) {
return Unpooled.EMPTY_BUFFER;
}
ByteBuf compressed = ctx.alloc().buffer();
ByteBuf compressed = decompressed.alloc().buffer();
setInput(decompressed);
encode(compressed);
return compressed;

View File

@ -17,7 +17,6 @@ package io.netty.handler.codec.spdy;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import java.util.Set;
@ -51,7 +50,7 @@ public class SpdyHeaderBlockRawEncoder extends SpdyHeaderBlockEncoder {
}
@Override
public ByteBuf encode(ChannelHandlerContext ctx, SpdyHeadersFrame frame) throws Exception {
public ByteBuf encode(SpdyHeadersFrame frame) throws Exception {
Set<String> names = frame.headers().names();
int numHeaders = names.size();
if (numHeaders == 0) {

View File

@ -23,57 +23,79 @@ import java.util.zip.Inflater;
import static io.netty.handler.codec.spdy.SpdyCodecUtil.*;
class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {
final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {
private static final int DEFAULT_BUFFER_CAPACITY = 4096;
private final byte[] out = new byte[8192];
private final Inflater decompressor = new Inflater();
private ByteBuf decompressed;
public SpdyHeaderBlockZlibDecoder(SpdyVersion version, int maxHeaderSize) {
super(version, maxHeaderSize);
SpdyHeaderBlockZlibDecoder(SpdyVersion spdyVersion, int maxHeaderSize) {
super(spdyVersion, maxHeaderSize);
}
@Override
void decode(ByteBuf encoded, SpdyHeadersFrame frame) throws Exception {
setInput(encoded);
int len = setInput(encoded);
int numBytes;
do {
numBytes = decompress(frame);
} while (!decompressed.isReadable() && numBytes > 0);
} while (numBytes > 0);
if (decompressor.getRemaining() != 0) {
throw new SpdyProtocolException("client sent extra data beyond headers");
}
encoded.skipBytes(len);
}
private void setInput(ByteBuf compressed) {
byte[] in = new byte[compressed.readableBytes()];
compressed.readBytes(in);
decompressor.setInput(in);
private int setInput(ByteBuf compressed) {
int len = compressed.readableBytes();
if (compressed.hasArray()) {
decompressor.setInput(compressed.array(), compressed.arrayOffset() + compressed.readerIndex(), len);
} else {
byte[] in = new byte[len];
compressed.getBytes(compressed.readerIndex(), in);
decompressor.setInput(in, 0, in.length);
}
return len;
}
private int decompress(SpdyHeadersFrame frame) throws Exception {
if (decompressed == null) {
decompressed = Unpooled.buffer(8192);
}
ensureBuffer();
byte[] out = decompressed.array();
int off = decompressed.arrayOffset() + decompressed.writerIndex();
try {
int numBytes = decompressor.inflate(out);
int numBytes = decompressor.inflate(out, off, decompressed.writableBytes());
if (numBytes == 0 && decompressor.needsDictionary()) {
decompressor.setDictionary(SPDY_DICT);
numBytes = decompressor.inflate(out);
numBytes = decompressor.inflate(out, off, decompressed.writableBytes());
}
if (frame != null) {
decompressed.writeBytes(out, 0, numBytes);
decompressed.writerIndex(decompressed.writerIndex() + numBytes);
super.decode(decompressed, frame);
decompressed.discardReadBytes();
}
return numBytes;
} catch (DataFormatException e) {
throw new SpdyProtocolException(
"Received invalid header block", e);
throw new SpdyProtocolException("Received invalid header block", e);
}
}
private void ensureBuffer() {
if (decompressed == null) {
decompressed = Unpooled.buffer(DEFAULT_BUFFER_CAPACITY);
}
decompressed.ensureWritable(1);
}
@Override
public void reset() {
void reset() {
decompressed = null;
super.reset();
}

View File

@ -17,7 +17,6 @@ package io.netty.handler.codec.spdy;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import java.util.zip.Deflater;
@ -30,7 +29,7 @@ class SpdyHeaderBlockZlibEncoder extends SpdyHeaderBlockRawEncoder {
private boolean finished;
public SpdyHeaderBlockZlibEncoder(SpdyVersion version, int compressionLevel) {
SpdyHeaderBlockZlibEncoder(SpdyVersion version, int compressionLevel) {
super(version);
if (compressionLevel < 0 || compressionLevel > 9) {
throw new IllegalArgumentException(
@ -55,7 +54,7 @@ class SpdyHeaderBlockZlibEncoder extends SpdyHeaderBlockRawEncoder {
}
@Override
public ByteBuf encode(ChannelHandlerContext ctx, SpdyHeadersFrame frame) throws Exception {
public ByteBuf encode(SpdyHeadersFrame frame) throws Exception {
if (frame == null) {
throw new IllegalArgumentException("frame");
}
@ -64,12 +63,12 @@ class SpdyHeaderBlockZlibEncoder extends SpdyHeaderBlockRawEncoder {
return Unpooled.EMPTY_BUFFER;
}
ByteBuf decompressed = super.encode(ctx, frame);
ByteBuf decompressed = super.encode(frame);
if (decompressed.readableBytes() == 0) {
return Unpooled.EMPTY_BUFFER;
}
ByteBuf compressed = ctx.alloc().buffer();
ByteBuf compressed = decompressed.alloc().buffer();
setInput(decompressed);
encode(compressed);
return compressed;

View File

@ -17,10 +17,13 @@ package io.netty.handler.codec.spdy;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
@ -103,6 +106,281 @@ public class SpdyFrameDecoderTest {
}
}
@Test
public void testLargeHeaderNameOnSynStreamRequest() throws Exception {
testLargeHeaderNameOnSynStreamRequest(SpdyVersion.SPDY_3);
testLargeHeaderNameOnSynStreamRequest(SpdyVersion.SPDY_3_1);
}
private static void testLargeHeaderNameOnSynStreamRequest(final SpdyVersion spdyVersion) throws Exception {
final int maxHeaderSize = 8192;
String expectedName = createString('h', 100);
String expectedValue = createString('v', 5000);
SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
SpdyHeaders headers = frame.headers();
headers.add(expectedName, expectedValue);
final CaptureHandler captureHandler = new CaptureHandler();
final ServerBootstrap sb = new ServerBootstrap();
Bootstrap cb = new Bootstrap();
sb.group(new NioEventLoopGroup(1));
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast("decoder", new SpdyFrameDecoder(spdyVersion, 10000, maxHeaderSize));
p.addLast("sessionHandler", new SpdySessionHandler(spdyVersion, true));
p.addLast("handler", captureHandler);
}
});
cb.group(new NioEventLoopGroup(1));
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast("encoder", new SpdyFrameEncoder(spdyVersion));
}
});
Channel sc = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) sc.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
Channel cc = ccf.channel();
sendAndWaitForFrame(cc, frame, captureHandler);
assertNotNull("version " + spdyVersion.getVersion() + ", not null message",
captureHandler.message);
String message = "version " + spdyVersion.getVersion() + ", should be SpdyHeadersFrame, was " +
captureHandler.message.getClass();
assertTrue(message, captureHandler.message instanceof SpdyHeadersFrame);
SpdyHeadersFrame writtenFrame = (SpdyHeadersFrame) captureHandler.message;
assertFalse("should not be truncated", writtenFrame.isTruncated());
assertFalse("should not be invalid", writtenFrame.isInvalid());
String val = writtenFrame.headers().get(expectedName);
assertEquals(expectedValue, val);
sc.close().sync();
sb.group().shutdownGracefully();
cb.group().shutdownGracefully();
}
@Test
public void testZlibHeaders() throws Exception {
SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
SpdyHeaders headers = frame.headers();
headers.add(createString('a', 100), createString('b', 100));
SpdyHeadersFrame actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
actual = roundTrip(frame, 4096);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
actual = roundTrip(frame, 128);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('c', 100), createString('d', 5000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('e', 5000), createString('f', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('g', 100), createString('h', 5000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('i', 100), createString('j', 5000));
actual = roundTrip(frame, 4096);
assertTrue("should be truncated", actual.isTruncated());
assertTrue("headers should be empty", actual.headers().isEmpty());
headers.clear();
headers.add(createString('k', 5000), createString('l', 100));
actual = roundTrip(frame, 4096);
assertTrue("should be truncated", actual.isTruncated());
assertTrue("headers should be empty", actual.headers().isEmpty());
headers.clear();
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('m', 100)).size());
headers.clear();
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('o', 1000)).size());
headers.clear();
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
actual = roundTrip(frame, 4096);
assertTrue("should be truncated", actual.isTruncated());
assertEquals(0, actual.headers().names().size());
headers.clear();
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
actual = roundTrip(frame, 4096);
assertFalse("should be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('s', 1000)).size());
}
@Test
public void testZlibReuseEncoderDecoder() throws Exception {
SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
SpdyHeaders headers = frame.headers();
SpdyHeaderBlockEncoder encoder = SpdyHeaderBlockEncoder.newInstance(SpdyVersion.SPDY_3_1, 6, 15, 8);
SpdyHeaderBlockDecoder decoder = SpdyHeaderBlockDecoder.newInstance(SpdyVersion.SPDY_3_1, 8192);
headers.add(createString('a', 100), createString('b', 100));
SpdyHeadersFrame actual = roundTrip(encoder, decoder, frame);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('e', 5000), createString('f', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('g', 100), createString('h', 5000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('m', 100)).size());
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('o', 1000)).size());
}
private static SpdyHeadersFrame roundTrip(SpdyHeadersFrame frame, int maxHeaderSize) throws Exception {
SpdyHeaderBlockEncoder encoder = SpdyHeaderBlockEncoder.newInstance(SpdyVersion.SPDY_3_1, 6, 15, 8);
SpdyHeaderBlockDecoder decoder = SpdyHeaderBlockDecoder.newInstance(SpdyVersion.SPDY_3_1, maxHeaderSize);
return roundTrip(encoder, decoder, frame);
}
private static SpdyHeadersFrame roundTrip(SpdyHeaderBlockEncoder encoder, SpdyHeaderBlockDecoder decoder,
SpdyHeadersFrame frame) throws Exception {
ByteBuf encoded = encoder.encode(frame);
SpdyHeadersFrame actual = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
decoder.decode(encoded, actual);
return actual;
}
private static boolean equals(SpdyHeaders h1, SpdyHeaders h2) {
if (!h1.names().equals(h2.names())) {
return false;
}
for (String name : h1.names()) {
if (!h1.getAll(name).equals(h2.getAll(name))) {
return false;
}
}
return true;
}
private static void sendAndWaitForFrame(Channel cc, SpdyFrame frame, CaptureHandler handler) {
cc.writeAndFlush(frame);
long theFuture = System.currentTimeMillis() + 3000;
@ -117,15 +395,17 @@ public class SpdyFrameDecoderTest {
private static void addHeader(SpdyHeadersFrame frame, int headerNameSize, int headerValueSize) {
frame.headers().add("k", "v");
StringBuilder headerName = new StringBuilder();
for (int i = 0; i < headerNameSize; i++) {
headerName.append('h');
String headerName = createString('h', headerNameSize);
String headerValue = createString('h', headerValueSize);
frame.headers().add(headerName, headerValue);
}
private static String createString(char c, int rep) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < rep; i++) {
sb.append(c);
}
StringBuilder headerValue = new StringBuilder();
for (int i = 0; i < headerValueSize; i++) {
headerValue.append('a');
}
frame.headers().add(headerName.toString(), headerValue.toString());
return sb.toString();
}
private static class CaptureHandler extends ChannelHandlerAdapter {