Do not throw IndexOutOfBoundsException on an invalid SSL record

Motivation:

When an SSL record contains an invalid extension data, SniHandler
currently throws an IndexOutOfBoundsException, which is not optimal.

Modifications:

- Do strict index range checks

Result:

No more unnecessary instantiation of exceptions and their stack traces
This commit is contained in:
Trustin Lee 2016-01-29 14:02:21 +09:00 committed by Norman Maurer
parent 3616d9e814
commit c3e5604f59

View File

@ -108,19 +108,25 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!suppressRead && !handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
int writerIndex = in.writerIndex();
int readerIndex = in.readerIndex();
if (!suppressRead && !handshakeFailed) {
final int writerIndex = in.writerIndex();
try {
loop:
for (int i = 0; i < MAX_SSL_RECORDS; i++) {
int command = in.getUnsignedByte(readerIndex);
final int readerIndex = in.readerIndex();
final int readableBytes = writerIndex - readerIndex;
if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) {
// Not enough data to determine the record type and length.
return;
}
final int command = in.getUnsignedByte(readerIndex);
// tls, but not handshake command
switch (command) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslUtils.SSL_CONTENT_TYPE_ALERT:
int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
// Not an SSL/TLS packet
if (len == -1) {
@ -138,20 +144,21 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
return;
}
// increase readerIndex and try again.
readerIndex += len;
in.skipBytes(len);
continue;
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
int majorVersion = in.getUnsignedByte(readerIndex + 1);
final int majorVersion = in.getUnsignedByte(readerIndex + 1);
// SSLv3 or TLS
if (majorVersion == 3) {
int packetLength = in.getUnsignedShort(readerIndex + 3)
+ SslUtils.SSL_RECORD_HEADER_LENGTH;
final int packetLength = in.getUnsignedShort(readerIndex + 3) +
SslUtils.SSL_RECORD_HEADER_LENGTH;
if (in.readableBytes() < packetLength) {
// client hello incomplete try again to decode once more data is ready.
if (readableBytes < packetLength) {
// client hello incomplete; try again to decode once more data is ready.
return;
}
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
@ -171,39 +178,71 @@ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundH
// };
// } ClientHello;
//
final int endOffset = readerIndex + packetLength;
int offset = readerIndex + 43;
int sessionIdLength = in.getUnsignedByte(offset);
if (endOffset - offset < 6) {
break loop;
}
final int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;
int cipherSuitesLength = in.getUnsignedShort(offset);
final int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;
int compressionMethodLength = in.getUnsignedByte(offset);
final int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;
int extensionsLength = in.getUnsignedShort(offset);
final int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;
final int extensionsLimit = offset + extensionsLength;
while (offset < extensionsLimit) {
int extensionType = in.getUnsignedShort(offset);
if (extensionsLimit > endOffset) {
// Extensions should never exceed the record boundary.
break loop;
}
for (;;) {
if (extensionsLimit - offset < 4) {
break loop;
}
final int extensionType = in.getUnsignedShort(offset);
offset += 2;
int extensionLength = in.getUnsignedShort(offset);
final int extensionLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < extensionLength) {
break loop;
}
// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
int serverNameType = in.getUnsignedByte(offset + 2);
offset += 2;
if (extensionsLimit - offset < 3) {
break loop;
}
final int serverNameType = in.getUnsignedByte(offset);
offset++;
if (serverNameType == 0) {
int serverNameLength = in.getUnsignedShort(offset + 3);
String hostname = in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
final int serverNameLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < serverNameLength) {
break loop;
}
final String hostname = in.toString(offset, serverNameLength,
CharsetUtil.UTF_8);
select(ctx, IDN.toASCII(hostname,
IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
return;
} else {
// invalid enum value