diff --git a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java index 77043fc427..1a9cd1ca96 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -82,18 +82,25 @@ public class SniHandler extends ByteToMessageDecoder { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - if (!handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) { - int writerIndex = in.writerIndex(); - int readerIndex = in.readerIndex(); + if (!handshakeFailed) { + final int writerIndex = in.writerIndex(); try { - loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) { - int command = in.getUnsignedByte(readerIndex); + loop: + for (int i = 0; i < MAX_SSL_RECORDS; i++) { + final int readerIndex = in.readerIndex(); + final int readableBytes = writerIndex - readerIndex; + if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { + // Not enough data to determine the record type and length. + return; + } + + final int command = in.getUnsignedByte(readerIndex); // tls, but not handshake command switch (command) { case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: case SslUtils.SSL_CONTENT_TYPE_ALERT: - int len = SslUtils.getEncryptedPacketLength(in, readerIndex); + final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); // Not an SSL/TLS packet if (len == -1) { @@ -111,20 +118,21 @@ public class SniHandler extends ByteToMessageDecoder { return; } // increase readerIndex and try again. - readerIndex += len; + in.skipBytes(len); continue; case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: - int majorVersion = in.getUnsignedByte(readerIndex + 1); + final int majorVersion = in.getUnsignedByte(readerIndex + 1); // SSLv3 or TLS if (majorVersion == 3) { - int packetLength = in.getUnsignedShort(readerIndex + 3) - + SslUtils.SSL_RECORD_HEADER_LENGTH; + final int packetLength = in.getUnsignedShort(readerIndex + 3) + + SslUtils.SSL_RECORD_HEADER_LENGTH; - if (in.readableBytes() < packetLength) { - // client hello incomplete try again to decode once more data is ready. + if (readableBytes < packetLength) { + // client hello incomplete; try again to decode once more data is ready. return; } + // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 // // Decode the ssl client hello packet. @@ -144,38 +152,71 @@ public class SniHandler extends ByteToMessageDecoder { // }; // } ClientHello; // + + final int endOffset = readerIndex + packetLength; int offset = readerIndex + 43; - int sessionIdLength = in.getUnsignedByte(offset); + if (endOffset - offset < 6) { + break loop; + } + + final int sessionIdLength = in.getUnsignedByte(offset); offset += sessionIdLength + 1; - int cipherSuitesLength = in.getUnsignedShort(offset); + final int cipherSuitesLength = in.getUnsignedShort(offset); offset += cipherSuitesLength + 2; - int compressionMethodLength = in.getUnsignedByte(offset); + final int compressionMethodLength = in.getUnsignedByte(offset); offset += compressionMethodLength + 1; - int extensionsLength = in.getUnsignedShort(offset); + final int extensionsLength = in.getUnsignedShort(offset); offset += 2; - int extensionsLimit = offset + extensionsLength; + final int extensionsLimit = offset + extensionsLength; - while (offset < extensionsLimit) { - int extensionType = in.getUnsignedShort(offset); + if (extensionsLimit > endOffset) { + // Extensions should never exceed the record boundary. + break loop; + } + + for (;;) { + if (extensionsLimit - offset < 4) { + break loop; + } + + final int extensionType = in.getUnsignedShort(offset); offset += 2; - int extensionLength = in.getUnsignedShort(offset); + final int extensionLength = in.getUnsignedShort(offset); offset += 2; + if (extensionsLimit - offset < extensionLength) { + break loop; + } + // SNI // See https://tools.ietf.org/html/rfc6066#page-6 if (extensionType == 0) { - int serverNameType = in.getUnsignedByte(offset + 2); + offset += 2; + if (extensionsLimit - offset < 3) { + break loop; + } + + final int serverNameType = in.getUnsignedByte(offset); + offset++; + if (serverNameType == 0) { - int serverNameLength = in.getUnsignedShort(offset + 3); - String hostname = in.toString(offset + 5, serverNameLength, - CharsetUtil.UTF_8); + final int serverNameLength = in.getUnsignedShort(offset); + offset += 2; + + if (extensionsLimit - offset < serverNameLength) { + break loop; + } + + final String hostname = in.toString(offset, serverNameLength, + CharsetUtil.UTF_8); + select(ctx, IDN.toASCII(hostname, - IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US)); + IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US)); return; } else { // invalid enum value