Fix SnappyFramedEncoder/Decoder / Fix Snappy preamble encoding / Add test for #1002

- The new test still fails due to a bug in Snappy.encode/decode()
This commit is contained in:
Trustin Lee 2013-02-09 23:39:33 +09:00
parent 36f8630512
commit 319b7fa69a
4 changed files with 148 additions and 97 deletions

View File

@ -41,12 +41,13 @@ public class Snappy {
public void encode(ByteBuf in, ByteBuf out, int length) { public void encode(ByteBuf in, ByteBuf out, int length) {
// Write the preamble length to the output buffer // Write the preamble length to the output buffer
int bytesToEncode = 1 + bitsToEncode(length - 1) / 7; for (int i = 0;; i ++) {
for (int i = 0; i < bytesToEncode; i++) { int b = length >>> i * 7;
if (i == bytesToEncode - 1) { if ((b & 0xFFFFFF80) != 0) {
out.writeByte(length >> i * 7); out.writeByte(b & 0x7f | 0x80);
} else { } else {
out.writeByte(0x80 | length >> i * 7); out.writeByte(b);
break;
} }
} }
@ -301,10 +302,10 @@ public class Snappy {
private static int readPreamble(ByteBuf in) { private static int readPreamble(ByteBuf in) {
int length = 0; int length = 0;
int byteIndex = 0; int byteIndex = 0;
while (in.readableBytes() > 0) { while (in.isReadable()) {
int current = in.readByte() & 0x0ff; int current = in.readUnsignedByte();
length += current << byteIndex++ * 7; length |= (current & 0x7f) << byteIndex++ * 7;
if ((current & 0x80) != 0x80) { if ((current & 0x80) == 0) {
return length; return length;
} }
@ -334,18 +335,18 @@ public class Snappy {
break; break;
case 61: case 61:
length = in.readUnsignedByte() length = in.readUnsignedByte()
| (in.readUnsignedByte() << 8); | in.readUnsignedByte() << 8;
break; break;
case 62: case 62:
length = in.readUnsignedByte() length = in.readUnsignedByte()
| (in.readUnsignedByte() << 8) | in.readUnsignedByte() << 8
| (in.readUnsignedByte() << 16); | in.readUnsignedByte() << 16;
break; break;
case 64: case 64:
length = in.readUnsignedByte() length = in.readUnsignedByte()
| (in.readUnsignedByte() << 8) | in.readUnsignedByte() << 8
| (in.readUnsignedByte() << 16) | in.readUnsignedByte() << 16
| (in.readUnsignedByte() << 24); | in.readUnsignedByte() << 24;
break; break;
default: default:
length = tag >> 2 & 0x3F; length = tag >> 2 & 0x3F;

View File

@ -43,8 +43,6 @@ public class SnappyFramedDecoder extends ByteToByteDecoder {
private final Snappy snappy = new Snappy(); private final Snappy snappy = new Snappy();
private final boolean validateChecksums; private final boolean validateChecksums;
private int chunkLength;
private ChunkType chunkType;
private boolean started; private boolean started;
private boolean corrupted; private boolean corrupted;
@ -71,52 +69,38 @@ public class SnappyFramedDecoder extends ByteToByteDecoder {
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { protected void decode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception {
if (!in.isReadable()) {
return;
}
if (corrupted) { if (corrupted) {
in.skipBytes(in.readableBytes()); in.skipBytes(in.readableBytes());
return; return;
} }
try { try {
while (in.isReadable()) { int idx = in.readerIndex();
if (chunkLength == 0) { final int inSize = in.writerIndex() - idx;
if (in.readableBytes() < 3) { if (inSize < 4) {
// We need to be at least able to read the chunk type identifier (one byte), // We need to be at least able to read the chunk type identifier (one byte),
// and the length of the chunk (2 bytes) in order to proceed // and the length of the chunk (3 bytes) in order to proceed
return; return;
} }
byte type = in.readByte(); final int chunkTypeVal = in.getUnsignedByte(idx);
chunkType = mapChunkType(type); final ChunkType chunkType = mapChunkType((byte) chunkTypeVal);
chunkLength = in.readUnsignedByte() | in.readUnsignedByte() << 8; final int chunkLength = in.getUnsignedByte(idx + 1)
| in.getUnsignedByte(idx + 2) << 8
| in.getUnsignedByte(idx + 3) << 16;
// The spec mandates that reserved unskippable chunks must immediately switch (chunkType) {
// return an error, as we must assume that we cannot decode the stream
// correctly
if (chunkType == ChunkType.RESERVED_UNSKIPPABLE) {
throw new CompressionException("Found reserved unskippable chunk type: " + type);
}
}
if (chunkLength == 0 || in.readableBytes() < chunkLength) {
// Wait until the entire chunk is available, as it will prevent us from
// having to buffer the data here instead
return;
}
int checksum;
switch(chunkType) {
case STREAM_IDENTIFIER: case STREAM_IDENTIFIER:
if (chunkLength != SNAPPY.length) { if (chunkLength != SNAPPY.length) {
throw new CompressionException("Unexpected length of stream identifier: " + chunkLength); throw new CompressionException("Unexpected length of stream identifier: " + chunkLength);
} }
if (inSize < 4 + SNAPPY.length) {
break;
}
byte[] identifier = new byte[chunkLength]; byte[] identifier = new byte[chunkLength];
in.readBytes(identifier); in.skipBytes(4).readBytes(identifier);
if (!Arrays.equals(identifier, SNAPPY)) { if (!Arrays.equals(identifier, SNAPPY)) {
throw new CompressionException("Unexpected stream identifier contents. Mismatched snappy " + throw new CompressionException("Unexpected stream identifier contents. Mismatched snappy " +
@ -124,35 +108,62 @@ public class SnappyFramedDecoder extends ByteToByteDecoder {
} }
started = true; started = true;
break; break;
case RESERVED_SKIPPABLE: case RESERVED_SKIPPABLE:
if (!started) { if (!started) {
throw new CompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER"); throw new CompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
} }
in.skipBytes(chunkLength);
if (inSize < 4 + chunkLength) {
// TODO: Don't keep skippable bytes
return;
}
in.skipBytes(4 + chunkLength);
break; break;
case RESERVED_UNSKIPPABLE:
// The spec mandates that reserved unskippable chunks must immediately
// return an error, as we must assume that we cannot decode the stream
// correctly
throw new CompressionException(
"Found reserved unskippable chunk type: 0x" + Integer.toHexString(chunkTypeVal));
case UNCOMPRESSED_DATA: case UNCOMPRESSED_DATA:
if (!started) { if (!started) {
throw new CompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER"); throw new CompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
} }
checksum = in.readUnsignedByte() if (chunkLength > 65536 + 4) {
throw new CompressionException("Received UNCOMPRESSED_DATA larger than 65540 bytes");
}
if (inSize < 4 + chunkLength) {
return;
}
in.skipBytes(4);
if (validateChecksums) {
int checksum = in.readUnsignedByte()
| in.readUnsignedByte() << 8 | in.readUnsignedByte() << 8
| in.readUnsignedByte() << 16 | in.readUnsignedByte() << 16
| in.readUnsignedByte() << 24; | in.readUnsignedByte() << 24;
if (validateChecksums) { ByteBuf data = in.readSlice(chunkLength - 4);
ByteBuf data = in.readBytes(chunkLength);
validateChecksum(data, checksum); validateChecksum(data, checksum);
out.writeBytes(data); out.writeBytes(data);
} else { } else {
in.readBytes(out, chunkLength); in.skipBytes(4);
in.readBytes(out, chunkLength - 4);
} }
break; break;
case COMPRESSED_DATA: case COMPRESSED_DATA:
if (!started) { if (!started) {
throw new CompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER"); throw new CompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
} }
checksum = in.readUnsignedByte()
if (inSize < 4 + chunkLength) {
return;
}
in.skipBytes(4);
int checksum = in.readUnsignedByte()
| in.readUnsignedByte() << 8 | in.readUnsignedByte() << 8
| in.readUnsignedByte() << 16 | in.readUnsignedByte() << 16
| in.readUnsignedByte() << 24; | in.readUnsignedByte() << 24;
@ -166,9 +177,6 @@ public class SnappyFramedDecoder extends ByteToByteDecoder {
snappy.reset(); snappy.reset();
break; break;
} }
chunkLength = 0;
}
} catch (Exception e) { } catch (Exception e) {
corrupted = true; corrupted = true;
throw e; throw e;

View File

@ -15,12 +15,12 @@
*/ */
package io.netty.handler.codec.compression; package io.netty.handler.codec.compression;
import static io.netty.handler.codec.compression.SnappyChecksumUtil.calculateChecksum;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToByteEncoder; import io.netty.handler.codec.ByteToByteEncoder;
import static io.netty.handler.codec.compression.SnappyChecksumUtil.*;
/** /**
* Compresses a {@link ByteBuf} using the Snappy framing format. * Compresses a {@link ByteBuf} using the Snappy framing format.
* *
@ -39,7 +39,7 @@ public class SnappyFramedEncoder extends ByteToByteEncoder {
* type 0xff, a length field of 0x6, and 'sNaPpY' in ASCII. * type 0xff, a length field of 0x6, and 'sNaPpY' in ASCII.
*/ */
private static final byte[] STREAM_START = { private static final byte[] STREAM_START = {
-0x80, 0x06, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59 -0x80, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59
}; };
private final Snappy snappy = new Snappy(); private final Snappy snappy = new Snappy();
@ -56,29 +56,43 @@ public class SnappyFramedEncoder extends ByteToByteEncoder {
out.writeBytes(STREAM_START); out.writeBytes(STREAM_START);
} }
final int chunkLength = in.readableBytes(); int dataLength = in.readableBytes();
if (chunkLength > MIN_COMPRESSIBLE_LENGTH) { if (dataLength > MIN_COMPRESSIBLE_LENGTH) {
// If we have lots of available data, break it up into smaller chunks for (;;) {
int numberOfChunks = 1 + chunkLength / Short.MAX_VALUE; final int lengthIdx = out.writerIndex() + 1;
for (int i = 0; i < numberOfChunks; i++) { out.writeInt(0);
int subChunkLength = Math.min(Short.MAX_VALUE, chunkLength); if (dataLength > 65536) {
out.writeByte(0); ByteBuf slice = in.readSlice(65536);
writeChunkLength(out, subChunkLength);
ByteBuf slice = in.slice(in.readerIndex(), subChunkLength);
calculateAndWriteChecksum(slice, out); calculateAndWriteChecksum(slice, out);
snappy.encode(slice, out, 65536);
snappy.encode(slice, out, subChunkLength); setChunkLength(out, lengthIdx);
dataLength -= 65536;
in.readerIndex(slice.readerIndex()); } else {
ByteBuf slice = in.readSlice(dataLength);
calculateAndWriteChecksum(slice, out);
snappy.encode(slice, out, dataLength);
setChunkLength(out, lengthIdx);
break;
}
} }
} else { } else {
out.writeByte(1); out.writeByte(1);
writeChunkLength(out, chunkLength); writeChunkLength(out, dataLength + 4);
calculateAndWriteChecksum(in, out); calculateAndWriteChecksum(in, out);
out.writeBytes(in, chunkLength); out.writeBytes(in, dataLength);
} }
} }
private static void setChunkLength(ByteBuf out, int lengthIdx) {
int chunkLength = out.writerIndex() - lengthIdx - 3;
if (chunkLength >>> 24 != 0) {
throw new CompressionException("compressed data too large: " + chunkLength);
}
out.setByte(lengthIdx, chunkLength & 0xff);
out.setByte(lengthIdx + 1, chunkLength >>> 8 & 0xff);
out.setByte(lengthIdx + 2, chunkLength >>> 16 & 0xff);
}
/** /**
* Writes the 2-byte chunk length to the output buffer. * Writes the 2-byte chunk length to the output buffer.
* *
@ -86,8 +100,8 @@ public class SnappyFramedEncoder extends ByteToByteEncoder {
* @param chunkLength The length to write * @param chunkLength The length to write
*/ */
private static void writeChunkLength(ByteBuf out, int chunkLength) { private static void writeChunkLength(ByteBuf out, int chunkLength) {
out.writeByte(chunkLength & 0x0ff); out.writeByte(chunkLength & 0xff);
out.writeByte(chunkLength >> 8 & 0x0ff); out.writeByte(chunkLength >>> 8 & 0xff);
} }
/** /**

View File

@ -19,10 +19,11 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedByteChannel; import io.netty.channel.embedded.EmbeddedByteChannel;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import org.junit.Ignore;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*;
public class SnappyIntegrationTest { public class SnappyIntegrationTest {
private final EmbeddedByteChannel channel = new EmbeddedByteChannel(new SnappyFramedEncoder(), private final EmbeddedByteChannel channel = new EmbeddedByteChannel(new SnappyFramedEncoder(),
new SnappyFramedDecoder()); new SnappyFramedDecoder());
@ -30,15 +31,42 @@ public class SnappyIntegrationTest {
@Test @Test
public void testEncoderDecoderIdentity() throws Exception { public void testEncoderDecoderIdentity() throws Exception {
ByteBuf in = Unpooled.copiedBuffer( ByteBuf in = Unpooled.copiedBuffer(
"Netty has been designed carefully with the experiences " + "Netty has been designed carefully with the experiences earned from the implementation of a lot of " +
"earned from the implementation of a lot of protocols " + "protocols such as FTP, SMTP, HTTP, and various binary and text-based legacy protocols",
"such as FTP, SMTP, HTTP, and various binary and " + CharsetUtil.US_ASCII);
"text-based legacy protocols", CharsetUtil.US_ASCII
);
channel.writeOutbound(in);
channel.writeOutbound(in.copy());
channel.writeInbound(channel.readOutbound()); channel.writeInbound(channel.readOutbound());
assertEquals(in, channel.readInbound());
}
Assert.assertEquals(in, channel.readInbound()); @Test
@Ignore // FIXME: Make it pass.
public void testEncoderDecoderIdentity2() throws Exception {
// Data from https://github.com/netty/netty/issues/1002
ByteBuf in = Unpooled.wrappedBuffer(new byte[] {
11, 0, 0, 0, 0, 0, 16, 65, 96, 119, -22, 79, -43, 76, -75, -93,
11, 104, 96, -99, 126, -98, 27, -36, 40, 117, -65, -3, -57, -83, -58, 7,
114, -14, 68, -122, 124, 88, 118, 54, 45, -26, 117, 13, -45, -9, 60, -73,
-53, -44, 53, 68, -77, -71, 109, 43, -38, 59, 100, -12, -87, 44, -106, 123,
-107, 38, 13, -117, -23, -49, 29, 21, 26, 66, 1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 66, 0, -104, -49,
16, -120, 22, 8, -52, -54, -102, -52, -119, -124, -92, -71, 101, -120, -52, -48,
45, -26, -24, 26, 41, -13, 36, 64, -47, 15, -124, -7, -16, 91, 96, 0,
-93, -42, 101, 20, -74, 39, -124, 35, 43, -49, -21, -92, -20, -41, 79, 41,
110, -105, 42, -96, 90, -9, -100, -22, -62, 91, 2, 35, 113, 117, -71, 66,
1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1
});
channel.writeOutbound(in.copy());
channel.writeInbound(channel.readOutbound());
assertEquals(in, channel.readInbound());
} }
} }