Fix a bug where SpdyHeaderBlockZlibDecoder fails to decompress

- Fixes #2077
This commit is contained in:
wgallagher 2014-01-09 11:29:09 -05:00 committed by Trustin Lee
parent 2bc252f22b
commit 4c35b593c1
2 changed files with 285 additions and 12 deletions

View File

@ -25,9 +25,11 @@ import org.jboss.netty.buffer.ChannelBuffers;
final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder { final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {
private static final int DEFAULT_BUFFER_CAPACITY = 4096;
private final Inflater decompressor = new Inflater(); private final Inflater decompressor = new Inflater();
private final ChannelBuffer decompressed = ChannelBuffers.buffer(4096); private ChannelBuffer decompressed;
public SpdyHeaderBlockZlibDecoder(SpdyVersion spdyVersion, int maxHeaderSize) { public SpdyHeaderBlockZlibDecoder(SpdyVersion spdyVersion, int maxHeaderSize) {
super(spdyVersion, maxHeaderSize); super(spdyVersion, maxHeaderSize);
@ -40,7 +42,7 @@ final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {
int numBytes; int numBytes;
do { do {
numBytes = decompress(frame); numBytes = decompress(frame);
} while (!decompressed.readable() && numBytes > 0); } while (numBytes > 0);
if (decompressor.getRemaining() != 0) { if (decompressor.getRemaining() != 0) {
throw new SpdyProtocolException("client sent extra data beyond headers"); throw new SpdyProtocolException("client sent extra data beyond headers");
@ -64,6 +66,7 @@ final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {
} }
private int decompress(SpdyHeadersFrame frame) throws Exception { private int decompress(SpdyHeadersFrame frame) throws Exception {
ensureBuffer();
byte[] out = decompressed.array(); byte[] out = decompressed.array();
int off = decompressed.arrayOffset() + decompressed.writerIndex(); int off = decompressed.arrayOffset() + decompressed.writerIndex();
try { try {
@ -84,15 +87,22 @@ final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {
} }
} }
private void ensureBuffer() {
if (decompressed == null) {
decompressed = ChannelBuffers.dynamicBuffer(DEFAULT_BUFFER_CAPACITY);
}
decompressed.ensureWritableBytes(1);
}
@Override @Override
void reset() { void reset() {
decompressed.clear(); decompressed = null;
super.reset(); super.reset();
} }
@Override @Override
public void end() { public void end() {
decompressed.clear(); decompressed = null;
decompressor.end(); decompressor.end();
super.end(); super.end();
} }

View File

@ -17,6 +17,7 @@ package org.jboss.netty.handler.codec.spdy;
import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.bootstrap.ServerBootstrap; import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory; import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelFuture;
@ -91,6 +92,266 @@ public class SpdyFrameDecoderTest {
} }
} }
@Test
public void testLargeHeaderNameOnSynStreamRequest() throws Exception {
testLargeHeaderNameOnSynStreamRequest(SpdyVersion.SPDY_3);
testLargeHeaderNameOnSynStreamRequest(SpdyVersion.SPDY_3_1);
}
private void testLargeHeaderNameOnSynStreamRequest(SpdyVersion spdyVersion) throws Exception {
int maxHeaderSize = 8192;
String expectedName = createString('h', 100);
String expectedValue = createString('v', 5000);
SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
SpdyHeaders headers = frame.headers();
headers.add(expectedName, expectedValue);
CaptureHandler captureHandler = new CaptureHandler();
ServerBootstrap sb = new ServerBootstrap(
newServerSocketChannelFactory(Executors.newCachedThreadPool()));
ClientBootstrap cb = new ClientBootstrap(
newClientSocketChannelFactory(Executors.newCachedThreadPool()));
sb.getPipeline().addLast("decoder", new SpdyFrameDecoder(spdyVersion, 10000, maxHeaderSize));
sb.getPipeline().addLast("sessionHandler", new SpdySessionHandler(spdyVersion, true));
sb.getPipeline().addLast("handler", captureHandler);
cb.getPipeline().addLast("encoder", new SpdyFrameEncoder(spdyVersion));
Channel sc = sb.bind(new InetSocketAddress(0));
int port = ((InetSocketAddress) sc.getLocalAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(TestUtil.getLocalHost(), port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
Channel cc = ccf.getChannel();
sendAndWaitForFrame(cc, frame, captureHandler);
assertNotNull("version " + spdyVersion.getVersion() + ", not null message",
captureHandler.message);
String message = "version " + spdyVersion.getVersion() + ", should be SpdyHeadersFrame, was " +
captureHandler.message.getClass();
assertTrue(message, captureHandler.message instanceof SpdyHeadersFrame);
SpdyHeadersFrame writtenFrame = (SpdyHeadersFrame) captureHandler.message;
assertFalse("should not be truncated", writtenFrame.isTruncated());
assertFalse("should not be invalid", writtenFrame.isInvalid());
String val = writtenFrame.headers().get(expectedName);
assertEquals(expectedValue, val);
sc.close().awaitUninterruptibly();
cb.shutdown();
sb.shutdown();
cb.releaseExternalResources();
sb.releaseExternalResources();
}
@Test
public void testZlibHeaders() throws Exception {
SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
SpdyHeaders headers = frame.headers();
headers.add(createString('a', 100), createString('b', 100));
SpdyHeadersFrame actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
actual = roundTrip(frame, 4096);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
actual = roundTrip(frame, 128);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('c', 100), createString('d', 5000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('e', 5000), createString('f', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('g', 100), createString('h', 5000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
headers.clear();
headers.add(createString('i', 100), createString('j', 5000));
actual = roundTrip(frame, 4096);
assertTrue("should be truncated", actual.isTruncated());
assertTrue("headers should be empty", actual.headers().isEmpty());
headers.clear();
headers.add(createString('k', 5000), createString('l', 100));
actual = roundTrip(frame, 4096);
assertTrue("should be truncated", actual.isTruncated());
assertTrue("headers should be empty", actual.headers().isEmpty());
headers.clear();
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('m', 100)).size());
headers.clear();
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('o', 1000)).size());
headers.clear();
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
headers.add(createString('q', 100), createString('r', 1000));
actual = roundTrip(frame, 4096);
assertTrue("should be truncated", actual.isTruncated());
assertEquals(0, actual.headers().names().size());
headers.clear();
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
headers.add(createString('s', 1000), createString('t', 100));
actual = roundTrip(frame, 4096);
assertFalse("should be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('s', 1000)).size());
}
@Test
public void testZlibReuseEncoderDecoder() throws Exception {
SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
SpdyHeaders headers = frame.headers();
SpdyHeaderBlockEncoder encoder = SpdyHeaderBlockEncoder.newInstance(SpdyVersion.SPDY_3_1, 6, 15, 8);
SpdyHeaderBlockDecoder decoder = SpdyHeaderBlockDecoder.newInstance(SpdyVersion.SPDY_3_1, 8192);
headers.add(createString('a', 100), createString('b', 100));
SpdyHeadersFrame actual = roundTrip(encoder, decoder, frame);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(frame.headers().isEmpty());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('e', 5000), createString('f', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('g', 100), createString('h', 5000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertTrue(equals(frame.headers(), actual.headers()));
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
headers.add(createString('m', 100), createString('n', 1000));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('m', 100)).size());
encoder.end();
decoder.end();
decoder.reset();
headers.clear();
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
headers.add(createString('o', 1000), createString('p', 100));
actual = roundTrip(frame, 8192);
assertFalse("should not be truncated", actual.isTruncated());
assertEquals(1, actual.headers().names().size());
assertEquals(5, actual.headers().getAll(createString('o', 1000)).size());
}
private SpdyHeadersFrame roundTrip(SpdyHeadersFrame frame, int maxHeaderSize) throws Exception {
SpdyHeaderBlockEncoder encoder = SpdyHeaderBlockEncoder.newInstance(SpdyVersion.SPDY_3_1, 6, 15, 8);
SpdyHeaderBlockDecoder decoder = SpdyHeaderBlockDecoder.newInstance(SpdyVersion.SPDY_3_1, maxHeaderSize);
return roundTrip(encoder, decoder, frame);
}
private SpdyHeadersFrame roundTrip(SpdyHeaderBlockEncoder encoder, SpdyHeaderBlockDecoder decoder,
SpdyHeadersFrame frame) throws Exception {
ChannelBuffer encoded = encoder.encode(frame);
SpdyHeadersFrame actual = new DefaultSpdySynStreamFrame(1, 0, (byte) 0);
decoder.decode(encoded, actual);
return actual;
}
private static boolean equals(SpdyHeaders h1, SpdyHeaders h2) {
if (!h1.names().equals(h2.names())) return false;
for (String name : h1.names()) {
if (!h1.getAll(name).equals(h2.getAll(name))) {
return false;
}
}
return true;
}
private static void sendAndWaitForFrame(Channel cc, SpdyFrame frame, CaptureHandler handler) { private static void sendAndWaitForFrame(Channel cc, SpdyFrame frame, CaptureHandler handler) {
cc.write(frame); cc.write(frame);
long theFuture = System.currentTimeMillis() + 3000; long theFuture = System.currentTimeMillis() + 3000;
@ -105,15 +366,17 @@ public class SpdyFrameDecoderTest {
private static void addHeader(SpdyHeadersFrame frame, int headerNameSize, int headerValueSize) { private static void addHeader(SpdyHeadersFrame frame, int headerNameSize, int headerValueSize) {
frame.headers().add("k", "v"); frame.headers().add("k", "v");
StringBuilder headerName = new StringBuilder(); String headerName = createString('h', headerNameSize);
for (int i = 0; i < headerNameSize; i++) { String headerValue = createString('h', headerValueSize);
headerName.append('h'); frame.headers().add(headerName, headerValue);
} }
StringBuilder headerValue = new StringBuilder();
for (int i = 0; i < headerValueSize; i++) { private static String createString(char c, int rep) {
headerValue.append('a'); StringBuilder sb = new StringBuilder();
for (int i = 0; i < rep; i++) {
sb.append(c);
} }
frame.headers().add(headerName.toString(), headerValue.toString()); return sb.toString();
} }
protected ChannelFactory newClientSocketChannelFactory(Executor executor) { protected ChannelFactory newClientSocketChannelFactory(Executor executor) {