From 29c471ec52a2d42fb395619cf20d76d924174b6e Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 29 Nov 2019 09:17:43 +0100 Subject: [PATCH] Correctly handle fragmented Handshake message when trying to detect SNI (#9806) Motivation: At the moment our AbstractSniHandler makes the assemption that Handshake messages are not fragmented. This is incorrect as it is completely valid to split these across multiple TLSPlaintext records. Thanks to @sskrobotov for bringing this to my attentation and to @Lukasa for the help. Modifications: - Adjust logic in AbstractSniHandler to handle fragmentation - Add unit tests Result: Correctly handle fragmented Handshake message in AbstractSniHandler (and so SniHandler). --- .../netty/handler/ssl/AbstractSniHandler.java | 336 +++++++++++------- .../io/netty/handler/ssl/SniHandlerTest.java | 80 ++++- 2 files changed, 281 insertions(+), 135 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java index bb5423daac..b06389ce38 100644 --- a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java @@ -24,7 +24,6 @@ import io.netty.handler.codec.DecoderException; import io.netty.util.CharsetUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; -import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -47,147 +46,119 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { private boolean handshakeFailed; private boolean suppressRead; private boolean readPending; + private ByteBuf handshakeBuffer; @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { if (!suppressRead && !handshakeFailed) { try { - final int readerIndex = in.readerIndex(); - final int readableBytes = in.readableBytes(); - if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { - // Not enough data to determine the record type and length. - return; - } + int readerIndex = in.readerIndex(); + int readableBytes = in.readableBytes(); + int handshakeLength = -1; - final int command = in.getUnsignedByte(readerIndex); - switch (command) { - case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: - // fall-through - case SslUtils.SSL_CONTENT_TYPE_ALERT: - final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); + // Check if we have enough data to determine the record type and length. + while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) { + final int contentType = in.getUnsignedByte(readerIndex); + switch (contentType) { + case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: + // fall-through + case SslUtils.SSL_CONTENT_TYPE_ALERT: + final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); - // Not an SSL/TLS packet - if (len == SslUtils.NOT_ENCRYPTED) { - handshakeFailed = true; - NotSslRecordException e = new NotSslRecordException( - "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); - in.skipBytes(in.readableBytes()); - ctx.fireUserEventTriggered(new SniCompletionEvent(e)); - SslUtils.handleHandshakeFailure(ctx, e, true); - throw e; - } - if (len == SslUtils.NOT_ENOUGH_DATA) { - // Not enough data - return; - } - // SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and assume - // no SNI is present. Let's let the actual TLS implementation sort this out. - break; - - case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: - final int majorVersion = in.getUnsignedByte(readerIndex + 1); - // SSLv3 or TLS - if (majorVersion == 3) { - final int packetLength = in.getUnsignedShort(readerIndex + 3) + - SslUtils.SSL_RECORD_HEADER_LENGTH; - - if (readableBytes < packetLength) { - // client hello incomplete; try again to decode once more data is ready. + // Not an SSL/TLS packet + if (len == SslUtils.NOT_ENCRYPTED) { + handshakeFailed = true; + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(in.readableBytes()); + ctx.fireUserEventTriggered(new SniCompletionEvent(e)); + SslUtils.handleHandshakeFailure(ctx, e, true); + throw e; + } + if (len == SslUtils.NOT_ENOUGH_DATA) { + // Not enough data return; } + // SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and + // assume no SNI is present. Let's let the actual TLS implementation sort this out. + // Just select the default SslContext + select(ctx, null); + return; + case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: + final int majorVersion = in.getUnsignedByte(readerIndex + 1); + // SSLv3 or TLS + if (majorVersion == 3) { + int packetLength = in.getUnsignedShort(readerIndex + 3) + + SslUtils.SSL_RECORD_HEADER_LENGTH; - // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 - // - // Decode the ssl client hello packet. - // We have to skip bytes until SessionID (which sum to 43 bytes). - // - // struct { - // ProtocolVersion client_version; - // Random random; - // SessionID session_id; - // CipherSuite cipher_suites<2..2^16-2>; - // CompressionMethod compression_methods<1..2^8-1>; - // select (extensions_present) { - // case false: - // struct {}; - // case true: - // Extension extensions<0..2^16-1>; - // }; - // } ClientHello; - // + if (readableBytes < packetLength) { + // client hello incomplete; try again to decode once more data is ready. + return; + } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) { + select(ctx, null); + return; + } - final int endOffset = readerIndex + packetLength; - int offset = readerIndex + 43; + final int endOffset = readerIndex + packetLength; - if (endOffset - offset >= 6) { - final int sessionIdLength = in.getUnsignedByte(offset); - offset += sessionIdLength + 1; + // Let's check if we already parsed the handshake length or not. + if (handshakeLength == -1) { + if (readerIndex + 4 > endOffset) { + // Need more data to read HandshakeType and handshakeLength (4 bytes) + return; + } - final int cipherSuitesLength = in.getUnsignedShort(offset); - offset += cipherSuitesLength + 2; + final int handshakeType = in.getUnsignedByte(readerIndex + + SslUtils.SSL_RECORD_HEADER_LENGTH); - final int compressionMethodLength = in.getUnsignedByte(offset); - offset += compressionMethodLength + 1; + // Check if this is a clientHello(1) + // See https://tools.ietf.org/html/rfc5246#section-7.4 + if (handshakeType != 1) { + select(ctx, null); + return; + } - final int extensionsLength = in.getUnsignedShort(offset); - offset += 2; - final int extensionsLimit = offset + extensionsLength; + // Read the length of the handshake as it may arrive in fragments + // See https://tools.ietf.org/html/rfc5246#section-7.4 + handshakeLength = in.getUnsignedMedium(readerIndex + + SslUtils.SSL_RECORD_HEADER_LENGTH + 1); - // Extensions should never exceed the record boundary. - if (extensionsLimit <= endOffset) { - while (extensionsLimit - offset >= 4) { - final int extensionType = in.getUnsignedShort(offset); - offset += 2; + // Consume handshakeType and handshakeLength (this sums up as 4 bytes) + readerIndex += 4; + packetLength -= 4; - final int extensionLength = in.getUnsignedShort(offset); - offset += 2; - - if (extensionsLimit - offset < extensionLength) { - break; + if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) { + // We have everything we need in one packet. + // Skip the record header + readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH; + select(ctx, extractSniHostname(in, readerIndex, readerIndex + handshakeLength)); + return; + } else { + if (handshakeBuffer == null) { + handshakeBuffer = ctx.alloc().buffer(handshakeLength); + } else { + // Clear the buffer so we can aggregate into it again. + handshakeBuffer.clear(); } - - // SNI - // See https://tools.ietf.org/html/rfc6066#page-6 - if (extensionType == 0) { - offset += 2; - if (extensionsLimit - offset < 3) { - break; - } - - final int serverNameType = in.getUnsignedByte(offset); - offset++; - - if (serverNameType == 0) { - final int serverNameLength = in.getUnsignedShort(offset); - offset += 2; - - if (extensionsLimit - offset < serverNameLength) { - break; - } - - final String hostname = - in.toString(offset, serverNameLength, CharsetUtil.US_ASCII); - try { - select(ctx, hostname.toLowerCase(Locale.US)); - } catch (Throwable t) { - PlatformDependent.throwException(t); - } - return; - } else { - // invalid enum value - break; - } - } - - offset += extensionLength; } } + + // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER + handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH, + packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH); + readerIndex += packetLength; + readableBytes -= packetLength; + if (handshakeLength <= handshakeBuffer.readableBytes()) { + select(ctx, extractSniHostname(handshakeBuffer, 0, handshakeLength)); + return; + } } - } - break; - default: - //not tls, ssl or application data, do not try sni - break; + break; + default: + // not tls, ssl or application data, do not try sni + select(ctx, null); + return; + } } } catch (NotSslRecordException e) { // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler. @@ -197,13 +168,105 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { if (logger.isDebugEnabled()) { logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); } + select(ctx, null); } - // Just select the default SslContext - select(ctx, null); + } + } + + private static String extractSniHostname(ByteBuf in, int offset, int endOffset) { + // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 + // + // Decode the ssl client hello packet. + // + // struct { + // ProtocolVersion client_version; + // Random random; + // SessionID session_id; + // CipherSuite cipher_suites<2..2^16-2>; + // CompressionMethod compression_methods<1..2^8-1>; + // select (extensions_present) { + // case false: + // struct {}; + // case true: + // Extension extensions<0..2^16-1>; + // }; + // } ClientHello; + // + + // We have to skip bytes until SessionID (which sum to 34 bytes in this case). + offset += 34; + + if (endOffset - offset >= 6) { + final int sessionIdLength = in.getUnsignedByte(offset); + offset += sessionIdLength + 1; + + final int cipherSuitesLength = in.getUnsignedShort(offset); + offset += cipherSuitesLength + 2; + + final int compressionMethodLength = in.getUnsignedByte(offset); + offset += compressionMethodLength + 1; + + final int extensionsLength = in.getUnsignedShort(offset); + offset += 2; + final int extensionsLimit = offset + extensionsLength; + + // Extensions should never exceed the record boundary. + if (extensionsLimit <= endOffset) { + while (extensionsLimit - offset >= 4) { + final int extensionType = in.getUnsignedShort(offset); + offset += 2; + + final int extensionLength = in.getUnsignedShort(offset); + offset += 2; + + if (extensionsLimit - offset < extensionLength) { + break; + } + + // SNI + // See https://tools.ietf.org/html/rfc6066#page-6 + if (extensionType == 0) { + offset += 2; + if (extensionsLimit - offset < 3) { + break; + } + + final int serverNameType = in.getUnsignedByte(offset); + offset++; + + if (serverNameType == 0) { + final int serverNameLength = in.getUnsignedShort(offset); + offset += 2; + + if (extensionsLimit - offset < serverNameLength) { + break; + } + + final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII); + return hostname.toLowerCase(Locale.US); + } else { + // invalid enum value + break; + } + } + + offset += extensionLength; + } + } + } + return null; + } + + private void releaseHandshakeBuffer() { + if (handshakeBuffer != null) { + handshakeBuffer.release(); + handshakeBuffer = null; } } private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception { + releaseHandshakeBuffer(); + Future future = lookup(ctx, hostname); if (future.isDone()) { fireSniCompletionEvent(ctx, hostname, future); @@ -211,18 +274,16 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { } else { suppressRead = true; future.addListener((FutureListener) future1 -> { + suppressRead = false; try { - suppressRead = false; - try { - fireSniCompletionEvent(ctx, hostname, future1); - onLookupComplete(ctx, hostname, future1); - } catch (DecoderException err) { - ctx.fireExceptionCaught(err); - } catch (Exception cause) { - ctx.fireExceptionCaught(new DecoderException(cause)); - } catch (Throwable cause) { - ctx.fireExceptionCaught(cause); - } + fireSniCompletionEvent(ctx, hostname, future1); + onLookupComplete(ctx, hostname, future1); + } catch (DecoderException err) { + ctx.fireExceptionCaught(err); + } catch (Exception cause) { + ctx.fireExceptionCaught(new DecoderException(cause)); + } catch (Throwable cause) { + ctx.fireExceptionCaught(cause); } finally { if (readPending) { readPending = false; @@ -233,6 +294,13 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { } } + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + releaseHandshakeBuffer(); + + super.handlerRemoved0(ctx); + } + private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future future) { Throwable cause = future.cause(); if (cause == null) { diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index c8d74c31b3..0a752e45d0 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -38,7 +38,9 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import io.netty.util.concurrent.Future; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -76,6 +78,7 @@ import io.netty.util.concurrent.Promise; import io.netty.util.internal.ResourcesUtil; import io.netty.util.internal.StringUtil; +import org.mockito.Mockito; @RunWith(Parameterized.class) public class SniHandlerTest { @@ -317,7 +320,7 @@ public class SniHandlerTest { try { // Push the handshake message. ch.writeInbound(Unpooled.wrappedBuffer(message)); - // TODO(scott): This should fail becasue the engine should reject zero length records during handshake. + // TODO(scott): This should fail because the engine should reject zero length records during handshake. // See https://github.com/netty/netty/issues/6348. // fail(); } catch (Exception e) { @@ -575,4 +578,79 @@ public class SniHandlerTest { ReferenceCountUtil.release(ctx); } } + + @Test + public void testNonFragmented() throws Exception { + testWithFragmentSize(Integer.MAX_VALUE); + } + @Test + public void testFragmented() throws Exception { + testWithFragmentSize(50); + } + + private void testWithFragmentSize(final int maxFragmentSize) throws Exception { + final String sni = "netty.io"; + SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext context = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(provider) + .build(); + try { + @SuppressWarnings("unchecked") final EmbeddedChannel server = new EmbeddedChannel( + new SniHandler(Mockito.mock(DomainNameMapping.class)) { + @Override + protected Future lookup(final ChannelHandlerContext ctx, final String hostname) { + assertEquals(sni, hostname); + return ctx.executor().newSucceededFuture(context); + } + }); + + final List buffers = clientHelloInMultipleFragments(provider, sni, maxFragmentSize); + for (ByteBuf buffer : buffers) { + server.writeInbound(buffer); + } + assertTrue(server.finishAndReleaseAll()); + } finally { + releaseAll(context); + cert.delete(); + } + } + + private static List clientHelloInMultipleFragments( + SslProvider provider, String hostname, int maxTlsPlaintextSize) throws SSLException { + final EmbeddedChannel client = new EmbeddedChannel(); + final SslContext ctx = SslContextBuilder.forClient() + .sslProvider(provider) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + try { + final SslHandler sslHandler = ctx.newHandler(client.alloc(), hostname, -1); + client.pipeline().addLast(sslHandler); + final ByteBuf clientHello = client.readOutbound(); + List buffers = split(clientHello, maxTlsPlaintextSize); + assertTrue(client.finishAndReleaseAll()); + return buffers; + } finally { + releaseAll(ctx); + } + } + + private static List split(ByteBuf clientHello, int maxSize) { + final int type = clientHello.readUnsignedByte(); + final int version = clientHello.readUnsignedShort(); + final int length = clientHello.readUnsignedShort(); + assertEquals(length, clientHello.readableBytes()); + + final List result = new ArrayList(); + while (clientHello.readableBytes() > 0) { + final int toRead = Math.min(maxSize, clientHello.readableBytes()); + final ByteBuf bb = clientHello.alloc().buffer(SslUtils.SSL_RECORD_HEADER_LENGTH + toRead); + bb.writeByte(type); + bb.writeShort(version); + bb.writeShort(toRead); + bb.writeBytes(clientHello, toRead); + result.add(bb); + } + clientHello.release(); + return result; + } }