Make SnappyFramedDecoder more robust against corrupt frame

This commit is contained in:
Trustin Lee 2013-02-09 20:58:55 +09:00
parent 2ac7983471
commit 36f8630512

View File

@ -46,6 +46,7 @@ public class SnappyFramedDecoder extends ByteToByteDecoder {
private int chunkLength;
private ChunkType chunkType;
private boolean started;
private boolean corrupted;
/**
* Creates a new snappy-framed decoder with validation of checksums
@ -74,93 +75,103 @@ public class SnappyFramedDecoder extends ByteToByteDecoder {
return;
}
while (in.isReadable()) {
if (chunkLength == 0) {
if (in.readableBytes() < 3) {
// We need to be at least able to read the chunk type identifier (one byte),
// and the length of the chunk (2 bytes) in order to proceed
if (corrupted) {
in.skipBytes(in.readableBytes());
return;
}
try {
while (in.isReadable()) {
if (chunkLength == 0) {
if (in.readableBytes() < 3) {
// We need to be at least able to read the chunk type identifier (one byte),
// and the length of the chunk (2 bytes) in order to proceed
return;
}
byte type = in.readByte();
chunkType = mapChunkType(type);
chunkLength = in.readUnsignedByte() | in.readUnsignedByte() << 8;
// The spec mandates that reserved unskippable chunks must immediately
// return an error, as we must assume that we cannot decode the stream
// correctly
if (chunkType == ChunkType.RESERVED_UNSKIPPABLE) {
throw new CompressionException("Found reserved unskippable chunk type: " + type);
}
}
if (chunkLength == 0 || in.readableBytes() < chunkLength) {
// Wait until the entire chunk is available, as it will prevent us from
// having to buffer the data here instead
return;
}
byte type = in.readByte();
chunkType = mapChunkType(type);
chunkLength = in.readUnsignedByte() | in.readUnsignedByte() << 8;
int checksum;
// The spec mandates that reserved unskippable chunks must immediately
// return an error, as we must assume that we cannot decode the stream
// correctly
if (chunkType == ChunkType.RESERVED_UNSKIPPABLE) {
throw new CompressionException("Found reserved unskippable chunk type: " + type);
switch(chunkType) {
case STREAM_IDENTIFIER:
if (chunkLength != SNAPPY.length) {
throw new CompressionException("Unexpected length of stream identifier: " + chunkLength);
}
byte[] identifier = new byte[chunkLength];
in.readBytes(identifier);
if (!Arrays.equals(identifier, SNAPPY)) {
throw new CompressionException("Unexpected stream identifier contents. Mismatched snappy " +
"protocol version?");
}
started = true;
break;
case RESERVED_SKIPPABLE:
if (!started) {
throw new CompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
}
in.skipBytes(chunkLength);
break;
case UNCOMPRESSED_DATA:
if (!started) {
throw new CompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
}
checksum = in.readUnsignedByte()
| in.readUnsignedByte() << 8
| in.readUnsignedByte() << 16
| in.readUnsignedByte() << 24;
if (validateChecksums) {
ByteBuf data = in.readBytes(chunkLength);
validateChecksum(data, checksum);
out.writeBytes(data);
} else {
in.readBytes(out, chunkLength);
}
break;
case COMPRESSED_DATA:
if (!started) {
throw new CompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
}
checksum = in.readUnsignedByte()
| in.readUnsignedByte() << 8
| in.readUnsignedByte() << 16
| in.readUnsignedByte() << 24;
if (validateChecksums) {
ByteBuf uncompressed = ctx.alloc().buffer();
snappy.decode(in, uncompressed, chunkLength);
validateChecksum(uncompressed, checksum);
} else {
snappy.decode(in, out, chunkLength);
}
snappy.reset();
break;
}
chunkLength = 0;
}
if (chunkLength == 0 || in.readableBytes() < chunkLength) {
// Wait until the entire chunk is available, as it will prevent us from
// having to buffer the data here instead
return;
}
int checksum;
switch(chunkType) {
case STREAM_IDENTIFIER:
if (chunkLength != SNAPPY.length) {
throw new CompressionException("Unexpected length of stream identifier: " + chunkLength);
}
byte[] identifier = new byte[chunkLength];
in.readBytes(identifier);
if (!Arrays.equals(identifier, SNAPPY)) {
throw new CompressionException("Unexpected stream identifier contents. Mismatched snappy " +
"protocol version?");
}
started = true;
break;
case RESERVED_SKIPPABLE:
if (!started) {
throw new CompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
}
in.skipBytes(chunkLength);
break;
case UNCOMPRESSED_DATA:
if (!started) {
throw new CompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
}
checksum = in.readUnsignedByte()
| (in.readUnsignedByte() << 8)
| (in.readUnsignedByte() << 16)
| (in.readUnsignedByte() << 24);
if (validateChecksums) {
ByteBuf data = in.readBytes(chunkLength);
validateChecksum(data, checksum);
out.writeBytes(data);
} else {
in.readBytes(out, chunkLength);
}
break;
case COMPRESSED_DATA:
if (!started) {
throw new CompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
}
checksum = in.readUnsignedByte()
| (in.readUnsignedByte() << 8)
| (in.readUnsignedByte() << 16)
| (in.readUnsignedByte() << 24);
if (validateChecksums) {
ByteBuf uncompressed = ctx.alloc().buffer();
snappy.decode(in, uncompressed, chunkLength);
validateChecksum(uncompressed, checksum);
} else {
snappy.decode(in, out, chunkLength);
}
snappy.reset();
break;
}
chunkLength = 0;
} catch (Exception e) {
corrupted = true;
throw e;
}
}