Don't loop over TLS records for SNI (#7479)

Motivation:

The AbstractSniHandler previously was willing to tolerate up to three
non-handshake records before a ClientHello that contained an SNI
extension field. This is, so far as I can tell, completely
unnecessary: no TLS implementation will be sending alerts or change
cipher spec messages before ClientHello.

Given that it was not possible to determine why this loop is in
the code to begin with, it's probably just best to remove it.

Modifications:

Remove the for loop.

Result:

The AbstractSniHandler will more rapidly determine whether it should
pass the records on to the default SSL handler.

Co-authored-by: Norman Maurer <norman_maurer@apple.com>
This commit is contained in:
Cory Benfield 2019-07-01 05:22:55 -04:00 committed by Norman Maurer
parent 4596f9e139
commit 14154074f2

View File

@ -42,9 +42,6 @@ import java.util.Locale;
*/ */
public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler { public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
// Maximal number of ssl records to inspect before fallback to the default SslContext.
private static final int MAX_SSL_RECORDS = 4;
private static final InternalLogger logger = private static final InternalLogger logger =
InternalLoggerFactory.getInstance(AbstractSniHandler.class); InternalLoggerFactory.getInstance(AbstractSniHandler.class);
@ -55,83 +52,75 @@ public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder impleme
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!suppressRead && !handshakeFailed) { if (!suppressRead && !handshakeFailed) {
final int writerIndex = in.writerIndex();
try { try {
loop: final int readerIndex = in.readerIndex();
for (int i = 0; i < MAX_SSL_RECORDS; i++) { final int readableBytes = in.readableBytes();
final int readerIndex = in.readerIndex(); if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) {
final int readableBytes = writerIndex - readerIndex; // Not enough data to determine the record type and length.
if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { return;
// Not enough data to determine the record type and length. }
return;
}
final int command = in.getUnsignedByte(readerIndex); final int command = in.getUnsignedByte(readerIndex);
switch (command) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
// fall-through
case SslUtils.SSL_CONTENT_TYPE_ALERT:
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
// tls, but not handshake command // Not an SSL/TLS packet
switch (command) { if (len == SslUtils.NOT_ENCRYPTED) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: handshakeFailed = true;
case SslUtils.SSL_CONTENT_TYPE_ALERT: NotSslRecordException e = new NotSslRecordException(
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes());
ctx.fireUserEventTriggered(new SniCompletionEvent(e));
SslUtils.handleHandshakeFailure(ctx, e, true);
throw e;
}
if (len == SslUtils.NOT_ENOUGH_DATA) {
// Not enough data
return;
}
// SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and assume
// no SNI is present. Let's let the actual TLS implementation sort this out.
break;
// Not an SSL/TLS packet case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
if (len == SslUtils.NOT_ENCRYPTED) { final int majorVersion = in.getUnsignedByte(readerIndex + 1);
handshakeFailed = true; // SSLv3 or TLS
NotSslRecordException e = new NotSslRecordException( if (majorVersion == 3) {
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); final int packetLength = in.getUnsignedShort(readerIndex + 3) +
in.skipBytes(in.readableBytes()); SslUtils.SSL_RECORD_HEADER_LENGTH;
ctx.fireUserEventTriggered(new SniCompletionEvent(e));
SslUtils.handleHandshakeFailure(ctx, e, true); if (readableBytes < packetLength) {
throw e; // client hello incomplete; try again to decode once more data is ready.
}
if (len == SslUtils.NOT_ENOUGH_DATA ||
writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
// Not enough data
return; return;
} }
// increase readerIndex and try again.
in.skipBytes(len);
continue;
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
final int majorVersion = in.getUnsignedByte(readerIndex + 1);
// SSLv3 or TLS // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
if (majorVersion == 3) { //
final int packetLength = in.getUnsignedShort(readerIndex + 3) + // Decode the ssl client hello packet.
SslUtils.SSL_RECORD_HEADER_LENGTH; // We have to skip bytes until SessionID (which sum to 43 bytes).
//
// struct {
// ProtocolVersion client_version;
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//
if (readableBytes < packetLength) { final int endOffset = readerIndex + packetLength;
// client hello incomplete; try again to decode once more data is ready. int offset = readerIndex + 43;
return;
}
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
// We have to skip bytes until SessionID (which sum to 43 bytes).
//
// struct {
// ProtocolVersion client_version;
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//
final int endOffset = readerIndex + packetLength;
int offset = readerIndex + 43;
if (endOffset - offset < 6) {
break loop;
}
if (endOffset - offset >= 6) {
final int sessionIdLength = in.getUnsignedByte(offset); final int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1; offset += sessionIdLength + 1;
@ -145,68 +134,61 @@ public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder impleme
offset += 2; offset += 2;
final int extensionsLimit = offset + extensionsLength; final int extensionsLimit = offset + extensionsLength;
if (extensionsLimit > endOffset) { // Extensions should never exceed the record boundary.
// Extensions should never exceed the record boundary. if (extensionsLimit <= endOffset) {
break loop; while (extensionsLimit - offset >= 4) {
} final int extensionType = in.getUnsignedShort(offset);
for (;;) {
if (extensionsLimit - offset < 4) {
break loop;
}
final int extensionType = in.getUnsignedShort(offset);
offset += 2;
final int extensionLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < extensionLength) {
break loop;
}
// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
offset += 2; offset += 2;
if (extensionsLimit - offset < 3) {
break loop; final int extensionLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < extensionLength) {
break;
} }
final int serverNameType = in.getUnsignedByte(offset); // SNI
offset++; // See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
if (serverNameType == 0) {
final int serverNameLength = in.getUnsignedShort(offset);
offset += 2; offset += 2;
if (extensionsLimit - offset < 3) {
if (extensionsLimit - offset < serverNameLength) { break;
break loop;
} }
final String hostname = in.toString(offset, serverNameLength, final int serverNameType = in.getUnsignedByte(offset);
CharsetUtil.US_ASCII); offset++;
try { if (serverNameType == 0) {
select(ctx, hostname.toLowerCase(Locale.US)); final int serverNameLength = in.getUnsignedShort(offset);
} catch (Throwable t) { offset += 2;
PlatformDependent.throwException(t);
if (extensionsLimit - offset < serverNameLength) {
break;
}
final String hostname =
in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
try {
select(ctx, hostname.toLowerCase(Locale.US));
} catch (Throwable t) {
PlatformDependent.throwException(t);
}
return;
} else {
// invalid enum value
break;
} }
return;
} else {
// invalid enum value
break loop;
} }
}
offset += extensionLength; offset += extensionLength;
}
} }
} }
// Fall-through }
default: break;
//not tls, ssl or application data, do not try sni default:
break loop; //not tls, ssl or application data, do not try sni
} break;
} }
} catch (NotSslRecordException e) { } catch (NotSslRecordException e) {
// Just rethrow as in this case we also closed the channel and this is consistent with SslHandler. // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.