From 36f86305124e711712be5e587f783f904676e598 Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Sat, 9 Feb 2013 20:58:55 +0900 Subject: [PATCH] Make SnappyFramedDecoder more robust against corrupt frame --- .../compression/SnappyFramedDecoder.java | 173 ++++++++++-------- 1 file changed, 92 insertions(+), 81 deletions(-) diff --git a/codec/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java index e6cbe54033..444df5d5da 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java @@ -46,6 +46,7 @@ public class SnappyFramedDecoder extends ByteToByteDecoder { private int chunkLength; private ChunkType chunkType; private boolean started; + private boolean corrupted; /** * Creates a new snappy-framed decoder with validation of checksums @@ -74,93 +75,103 @@ public class SnappyFramedDecoder extends ByteToByteDecoder { return; } - while (in.isReadable()) { - if (chunkLength == 0) { - if (in.readableBytes() < 3) { - // We need to be at least able to read the chunk type identifier (one byte), - // and the length of the chunk (2 bytes) in order to proceed + if (corrupted) { + in.skipBytes(in.readableBytes()); + return; + } + + try { + while (in.isReadable()) { + if (chunkLength == 0) { + if (in.readableBytes() < 3) { + // We need to be at least able to read the chunk type identifier (one byte), + // and the length of the chunk (2 bytes) in order to proceed + return; + } + + byte type = in.readByte(); + chunkType = mapChunkType(type); + chunkLength = in.readUnsignedByte() | in.readUnsignedByte() << 8; + + // The spec mandates that reserved unskippable chunks must immediately + // return an error, as we must assume that we cannot decode the stream + // correctly + if (chunkType == ChunkType.RESERVED_UNSKIPPABLE) { + throw new CompressionException("Found reserved unskippable chunk type: " + type); + } + } + + if (chunkLength == 0 || in.readableBytes() < chunkLength) { + // Wait until the entire chunk is available, as it will prevent us from + // having to buffer the data here instead return; } - byte type = in.readByte(); - chunkType = mapChunkType(type); - chunkLength = in.readUnsignedByte() | in.readUnsignedByte() << 8; + int checksum; - // The spec mandates that reserved unskippable chunks must immediately - // return an error, as we must assume that we cannot decode the stream - // correctly - if (chunkType == ChunkType.RESERVED_UNSKIPPABLE) { - throw new CompressionException("Found reserved unskippable chunk type: " + type); + switch(chunkType) { + case STREAM_IDENTIFIER: + if (chunkLength != SNAPPY.length) { + throw new CompressionException("Unexpected length of stream identifier: " + chunkLength); + } + + byte[] identifier = new byte[chunkLength]; + in.readBytes(identifier); + + if (!Arrays.equals(identifier, SNAPPY)) { + throw new CompressionException("Unexpected stream identifier contents. Mismatched snappy " + + "protocol version?"); + } + + started = true; + + break; + case RESERVED_SKIPPABLE: + if (!started) { + throw new CompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER"); + } + in.skipBytes(chunkLength); + break; + case UNCOMPRESSED_DATA: + if (!started) { + throw new CompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER"); + } + checksum = in.readUnsignedByte() + | in.readUnsignedByte() << 8 + | in.readUnsignedByte() << 16 + | in.readUnsignedByte() << 24; + if (validateChecksums) { + ByteBuf data = in.readBytes(chunkLength); + validateChecksum(data, checksum); + out.writeBytes(data); + } else { + in.readBytes(out, chunkLength); + } + break; + case COMPRESSED_DATA: + if (!started) { + throw new CompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER"); + } + checksum = in.readUnsignedByte() + | in.readUnsignedByte() << 8 + | in.readUnsignedByte() << 16 + | in.readUnsignedByte() << 24; + if (validateChecksums) { + ByteBuf uncompressed = ctx.alloc().buffer(); + snappy.decode(in, uncompressed, chunkLength); + validateChecksum(uncompressed, checksum); + } else { + snappy.decode(in, out, chunkLength); + } + snappy.reset(); + break; } + + chunkLength = 0; } - - if (chunkLength == 0 || in.readableBytes() < chunkLength) { - // Wait until the entire chunk is available, as it will prevent us from - // having to buffer the data here instead - return; - } - - int checksum; - - switch(chunkType) { - case STREAM_IDENTIFIER: - if (chunkLength != SNAPPY.length) { - throw new CompressionException("Unexpected length of stream identifier: " + chunkLength); - } - - byte[] identifier = new byte[chunkLength]; - in.readBytes(identifier); - - if (!Arrays.equals(identifier, SNAPPY)) { - throw new CompressionException("Unexpected stream identifier contents. Mismatched snappy " + - "protocol version?"); - } - - started = true; - - break; - case RESERVED_SKIPPABLE: - if (!started) { - throw new CompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER"); - } - in.skipBytes(chunkLength); - break; - case UNCOMPRESSED_DATA: - if (!started) { - throw new CompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER"); - } - checksum = in.readUnsignedByte() - | (in.readUnsignedByte() << 8) - | (in.readUnsignedByte() << 16) - | (in.readUnsignedByte() << 24); - if (validateChecksums) { - ByteBuf data = in.readBytes(chunkLength); - validateChecksum(data, checksum); - out.writeBytes(data); - } else { - in.readBytes(out, chunkLength); - } - break; - case COMPRESSED_DATA: - if (!started) { - throw new CompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER"); - } - checksum = in.readUnsignedByte() - | (in.readUnsignedByte() << 8) - | (in.readUnsignedByte() << 16) - | (in.readUnsignedByte() << 24); - if (validateChecksums) { - ByteBuf uncompressed = ctx.alloc().buffer(); - snappy.decode(in, uncompressed, chunkLength); - validateChecksum(uncompressed, checksum); - } else { - snappy.decode(in, out, chunkLength); - } - snappy.reset(); - break; - } - - chunkLength = 0; + } catch (Exception e) { + corrupted = true; + throw e; } }