diff --git a/src/main/java/org/jboss/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java b/src/main/java/org/jboss/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java index 5da002f8c3..0b5449e3c9 100644 --- a/src/main/java/org/jboss/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java +++ b/src/main/java/org/jboss/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java @@ -25,9 +25,11 @@ import org.jboss.netty.buffer.ChannelBuffers; final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder { + private static final int DEFAULT_BUFFER_CAPACITY = 4096; + private final Inflater decompressor = new Inflater(); - private final ChannelBuffer decompressed = ChannelBuffers.buffer(4096); + private ChannelBuffer decompressed; public SpdyHeaderBlockZlibDecoder(SpdyVersion spdyVersion, int maxHeaderSize) { super(spdyVersion, maxHeaderSize); @@ -40,7 +42,7 @@ final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder { int numBytes; do { numBytes = decompress(frame); - } while (!decompressed.readable() && numBytes > 0); + } while (numBytes > 0); if (decompressor.getRemaining() != 0) { throw new SpdyProtocolException("client sent extra data beyond headers"); @@ -64,6 +66,7 @@ final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder { } private int decompress(SpdyHeadersFrame frame) throws Exception { + ensureBuffer(); byte[] out = decompressed.array(); int off = decompressed.arrayOffset() + decompressed.writerIndex(); try { @@ -84,15 +87,22 @@ final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder { } } + private void ensureBuffer() { + if (decompressed == null) { + decompressed = ChannelBuffers.dynamicBuffer(DEFAULT_BUFFER_CAPACITY); + } + decompressed.ensureWritableBytes(1); + } + @Override void reset() { - decompressed.clear(); + decompressed = null; super.reset(); } @Override public void end() { - decompressed.clear(); + decompressed = null; decompressor.end(); super.end(); } diff --git a/src/test/java/org/jboss/netty/handler/codec/spdy/SpdyFrameDecoderTest.java b/src/test/java/org/jboss/netty/handler/codec/spdy/SpdyFrameDecoderTest.java index bf31e51706..b6b6803929 100644 --- a/src/test/java/org/jboss/netty/handler/codec/spdy/SpdyFrameDecoderTest.java +++ b/src/test/java/org/jboss/netty/handler/codec/spdy/SpdyFrameDecoderTest.java @@ -17,6 +17,7 @@ package org.jboss.netty.handler.codec.spdy; import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.bootstrap.ServerBootstrap; +import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelFactory; import org.jboss.netty.channel.ChannelFuture; @@ -91,6 +92,266 @@ public class SpdyFrameDecoderTest { } } + @Test + public void testLargeHeaderNameOnSynStreamRequest() throws Exception { + testLargeHeaderNameOnSynStreamRequest(SpdyVersion.SPDY_3); + testLargeHeaderNameOnSynStreamRequest(SpdyVersion.SPDY_3_1); + } + + private void testLargeHeaderNameOnSynStreamRequest(SpdyVersion spdyVersion) throws Exception { + int maxHeaderSize = 8192; + + String expectedName = createString('h', 100); + String expectedValue = createString('v', 5000); + + SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0); + SpdyHeaders headers = frame.headers(); + headers.add(expectedName, expectedValue); + + CaptureHandler captureHandler = new CaptureHandler(); + ServerBootstrap sb = new ServerBootstrap( + newServerSocketChannelFactory(Executors.newCachedThreadPool())); + ClientBootstrap cb = new ClientBootstrap( + newClientSocketChannelFactory(Executors.newCachedThreadPool())); + + sb.getPipeline().addLast("decoder", new SpdyFrameDecoder(spdyVersion, 10000, maxHeaderSize)); + sb.getPipeline().addLast("sessionHandler", new SpdySessionHandler(spdyVersion, true)); + sb.getPipeline().addLast("handler", captureHandler); + + cb.getPipeline().addLast("encoder", new SpdyFrameEncoder(spdyVersion)); + + Channel sc = sb.bind(new InetSocketAddress(0)); + int port = ((InetSocketAddress) sc.getLocalAddress()).getPort(); + + ChannelFuture ccf = cb.connect(new InetSocketAddress(TestUtil.getLocalHost(), port)); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + Channel cc = ccf.getChannel(); + + sendAndWaitForFrame(cc, frame, captureHandler); + + assertNotNull("version " + spdyVersion.getVersion() + ", not null message", + captureHandler.message); + String message = "version " + spdyVersion.getVersion() + ", should be SpdyHeadersFrame, was " + + captureHandler.message.getClass(); + assertTrue(message, captureHandler.message instanceof SpdyHeadersFrame); + SpdyHeadersFrame writtenFrame = (SpdyHeadersFrame) captureHandler.message; + + assertFalse("should not be truncated", writtenFrame.isTruncated()); + assertFalse("should not be invalid", writtenFrame.isInvalid()); + + String val = writtenFrame.headers().get(expectedName); + assertEquals(expectedValue, val); + + sc.close().awaitUninterruptibly(); + cb.shutdown(); + sb.shutdown(); + cb.releaseExternalResources(); + sb.releaseExternalResources(); + } + + @Test + public void testZlibHeaders() throws Exception { + + SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0); + SpdyHeaders headers = frame.headers(); + + headers.add(createString('a', 100), createString('b', 100)); + SpdyHeadersFrame actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(frame.headers().isEmpty()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + actual = roundTrip(frame, 4096); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(frame.headers().isEmpty()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + actual = roundTrip(frame, 128); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(frame.headers().isEmpty()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + headers.add(createString('c', 100), createString('d', 5000)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + headers.add(createString('e', 5000), createString('f', 100)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + headers.add(createString('g', 100), createString('h', 5000)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + headers.clear(); + headers.add(createString('i', 100), createString('j', 5000)); + actual = roundTrip(frame, 4096); + assertTrue("should be truncated", actual.isTruncated()); + assertTrue("headers should be empty", actual.headers().isEmpty()); + + headers.clear(); + headers.add(createString('k', 5000), createString('l', 100)); + actual = roundTrip(frame, 4096); + assertTrue("should be truncated", actual.isTruncated()); + assertTrue("headers should be empty", actual.headers().isEmpty()); + + headers.clear(); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertEquals(1, actual.headers().names().size()); + assertEquals(5, actual.headers().getAll(createString('m', 100)).size()); + + headers.clear(); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertEquals(1, actual.headers().names().size()); + assertEquals(5, actual.headers().getAll(createString('o', 1000)).size()); + + headers.clear(); + headers.add(createString('q', 100), createString('r', 1000)); + headers.add(createString('q', 100), createString('r', 1000)); + headers.add(createString('q', 100), createString('r', 1000)); + headers.add(createString('q', 100), createString('r', 1000)); + headers.add(createString('q', 100), createString('r', 1000)); + actual = roundTrip(frame, 4096); + assertTrue("should be truncated", actual.isTruncated()); + assertEquals(0, actual.headers().names().size()); + + headers.clear(); + headers.add(createString('s', 1000), createString('t', 100)); + headers.add(createString('s', 1000), createString('t', 100)); + headers.add(createString('s', 1000), createString('t', 100)); + headers.add(createString('s', 1000), createString('t', 100)); + headers.add(createString('s', 1000), createString('t', 100)); + actual = roundTrip(frame, 4096); + assertFalse("should be truncated", actual.isTruncated()); + assertEquals(1, actual.headers().names().size()); + assertEquals(5, actual.headers().getAll(createString('s', 1000)).size()); + } + + @Test + public void testZlibReuseEncoderDecoder() throws Exception { + SpdyHeadersFrame frame = new DefaultSpdySynStreamFrame(1, 0, (byte) 0); + SpdyHeaders headers = frame.headers(); + + SpdyHeaderBlockEncoder encoder = SpdyHeaderBlockEncoder.newInstance(SpdyVersion.SPDY_3_1, 6, 15, 8); + SpdyHeaderBlockDecoder decoder = SpdyHeaderBlockDecoder.newInstance(SpdyVersion.SPDY_3_1, 8192); + + headers.add(createString('a', 100), createString('b', 100)); + SpdyHeadersFrame actual = roundTrip(encoder, decoder, frame); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + encoder.end(); + decoder.end(); + decoder.reset(); + + headers.clear(); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(frame.headers().isEmpty()); + assertTrue(equals(frame.headers(), actual.headers())); + + encoder.end(); + decoder.end(); + decoder.reset(); + + headers.clear(); + headers.add(createString('e', 5000), createString('f', 100)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + encoder.end(); + decoder.end(); + decoder.reset(); + + headers.clear(); + headers.add(createString('g', 100), createString('h', 5000)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertTrue(equals(frame.headers(), actual.headers())); + + encoder.end(); + decoder.end(); + decoder.reset(); + + headers.clear(); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + headers.add(createString('m', 100), createString('n', 1000)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertEquals(1, actual.headers().names().size()); + assertEquals(5, actual.headers().getAll(createString('m', 100)).size()); + + encoder.end(); + decoder.end(); + decoder.reset(); + + headers.clear(); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + headers.add(createString('o', 1000), createString('p', 100)); + actual = roundTrip(frame, 8192); + assertFalse("should not be truncated", actual.isTruncated()); + assertEquals(1, actual.headers().names().size()); + assertEquals(5, actual.headers().getAll(createString('o', 1000)).size()); + } + + private SpdyHeadersFrame roundTrip(SpdyHeadersFrame frame, int maxHeaderSize) throws Exception { + SpdyHeaderBlockEncoder encoder = SpdyHeaderBlockEncoder.newInstance(SpdyVersion.SPDY_3_1, 6, 15, 8); + SpdyHeaderBlockDecoder decoder = SpdyHeaderBlockDecoder.newInstance(SpdyVersion.SPDY_3_1, maxHeaderSize); + return roundTrip(encoder, decoder, frame); + } + + private SpdyHeadersFrame roundTrip(SpdyHeaderBlockEncoder encoder, SpdyHeaderBlockDecoder decoder, + SpdyHeadersFrame frame) throws Exception { + ChannelBuffer encoded = encoder.encode(frame); + + SpdyHeadersFrame actual = new DefaultSpdySynStreamFrame(1, 0, (byte) 0); + + decoder.decode(encoded, actual); + return actual; + } + + private static boolean equals(SpdyHeaders h1, SpdyHeaders h2) { + if (!h1.names().equals(h2.names())) return false; + for (String name : h1.names()) { + if (!h1.getAll(name).equals(h2.getAll(name))) { + return false; + } + } + return true; + } + private static void sendAndWaitForFrame(Channel cc, SpdyFrame frame, CaptureHandler handler) { cc.write(frame); long theFuture = System.currentTimeMillis() + 3000; @@ -105,15 +366,17 @@ public class SpdyFrameDecoderTest { private static void addHeader(SpdyHeadersFrame frame, int headerNameSize, int headerValueSize) { frame.headers().add("k", "v"); - StringBuilder headerName = new StringBuilder(); - for (int i = 0; i < headerNameSize; i++) { - headerName.append('h'); + String headerName = createString('h', headerNameSize); + String headerValue = createString('h', headerValueSize); + frame.headers().add(headerName, headerValue); + } + + private static String createString(char c, int rep) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < rep; i++) { + sb.append(c); } - StringBuilder headerValue = new StringBuilder(); - for (int i = 0; i < headerValueSize; i++) { - headerValue.append('a'); - } - frame.headers().add(headerName.toString(), headerValue.toString()); + return sb.toString(); } protected ChannelFactory newClientSocketChannelFactory(Executor executor) {