Correctly handle fragmented Handshake message when trying to detect SNI (#9806)

Motivation:

At the moment our AbstractSniHandler makes the assemption that Handshake messages are not fragmented. This is incorrect as it is completely valid to split these across multiple TLSPlaintext records.

Thanks to @sskrobotov for bringing this to my attentation and to @Lukasa for the help.

Modifications:

- Adjust logic in AbstractSniHandler to handle fragmentation
- Add unit tests

Result:

Correctly handle fragmented Handshake message in AbstractSniHandler (and so SniHandler).
This commit is contained in:
Norman Maurer 2019-11-29 09:17:43 +01:00
parent 585ed4d08f
commit 29c471ec52
2 changed files with 281 additions and 135 deletions

View File

@ -24,7 +24,6 @@ import io.netty.handler.codec.DecoderException;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.FutureListener;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -47,147 +46,119 @@ public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder {
private boolean handshakeFailed; private boolean handshakeFailed;
private boolean suppressRead; private boolean suppressRead;
private boolean readPending; private boolean readPending;
private ByteBuf handshakeBuffer;
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!suppressRead && !handshakeFailed) { if (!suppressRead && !handshakeFailed) {
try { try {
final int readerIndex = in.readerIndex(); int readerIndex = in.readerIndex();
final int readableBytes = in.readableBytes(); int readableBytes = in.readableBytes();
if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { int handshakeLength = -1;
// Not enough data to determine the record type and length.
return;
}
final int command = in.getUnsignedByte(readerIndex); // Check if we have enough data to determine the record type and length.
switch (command) { while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: final int contentType = in.getUnsignedByte(readerIndex);
// fall-through switch (contentType) {
case SslUtils.SSL_CONTENT_TYPE_ALERT: case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); // fall-through
case SslUtils.SSL_CONTENT_TYPE_ALERT:
final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
// Not an SSL/TLS packet // Not an SSL/TLS packet
if (len == SslUtils.NOT_ENCRYPTED) { if (len == SslUtils.NOT_ENCRYPTED) {
handshakeFailed = true; handshakeFailed = true;
NotSslRecordException e = new NotSslRecordException( NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes()); in.skipBytes(in.readableBytes());
ctx.fireUserEventTriggered(new SniCompletionEvent(e)); ctx.fireUserEventTriggered(new SniCompletionEvent(e));
SslUtils.handleHandshakeFailure(ctx, e, true); SslUtils.handleHandshakeFailure(ctx, e, true);
throw e; throw e;
} }
if (len == SslUtils.NOT_ENOUGH_DATA) { if (len == SslUtils.NOT_ENOUGH_DATA) {
// Not enough data // Not enough data
return;
}
// SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and assume
// no SNI is present. Let's let the actual TLS implementation sort this out.
break;
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
final int majorVersion = in.getUnsignedByte(readerIndex + 1);
// SSLv3 or TLS
if (majorVersion == 3) {
final int packetLength = in.getUnsignedShort(readerIndex + 3) +
SslUtils.SSL_RECORD_HEADER_LENGTH;
if (readableBytes < packetLength) {
// client hello incomplete; try again to decode once more data is ready.
return; return;
} }
// SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and
// assume no SNI is present. Let's let the actual TLS implementation sort this out.
// Just select the default SslContext
select(ctx, null);
return;
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
final int majorVersion = in.getUnsignedByte(readerIndex + 1);
// SSLv3 or TLS
if (majorVersion == 3) {
int packetLength = in.getUnsignedShort(readerIndex + 3) +
SslUtils.SSL_RECORD_HEADER_LENGTH;
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 if (readableBytes < packetLength) {
// // client hello incomplete; try again to decode once more data is ready.
// Decode the ssl client hello packet. return;
// We have to skip bytes until SessionID (which sum to 43 bytes). } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
// select(ctx, null);
// struct { return;
// ProtocolVersion client_version; }
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//
final int endOffset = readerIndex + packetLength; final int endOffset = readerIndex + packetLength;
int offset = readerIndex + 43;
if (endOffset - offset >= 6) { // Let's check if we already parsed the handshake length or not.
final int sessionIdLength = in.getUnsignedByte(offset); if (handshakeLength == -1) {
offset += sessionIdLength + 1; if (readerIndex + 4 > endOffset) {
// Need more data to read HandshakeType and handshakeLength (4 bytes)
return;
}
final int cipherSuitesLength = in.getUnsignedShort(offset); final int handshakeType = in.getUnsignedByte(readerIndex +
offset += cipherSuitesLength + 2; SslUtils.SSL_RECORD_HEADER_LENGTH);
final int compressionMethodLength = in.getUnsignedByte(offset); // Check if this is a clientHello(1)
offset += compressionMethodLength + 1; // See https://tools.ietf.org/html/rfc5246#section-7.4
if (handshakeType != 1) {
select(ctx, null);
return;
}
final int extensionsLength = in.getUnsignedShort(offset); // Read the length of the handshake as it may arrive in fragments
offset += 2; // See https://tools.ietf.org/html/rfc5246#section-7.4
final int extensionsLimit = offset + extensionsLength; handshakeLength = in.getUnsignedMedium(readerIndex +
SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
// Extensions should never exceed the record boundary. // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
if (extensionsLimit <= endOffset) { readerIndex += 4;
while (extensionsLimit - offset >= 4) { packetLength -= 4;
final int extensionType = in.getUnsignedShort(offset);
offset += 2;
final int extensionLength = in.getUnsignedShort(offset); if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
offset += 2; // We have everything we need in one packet.
// Skip the record header
if (extensionsLimit - offset < extensionLength) { readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
break; select(ctx, extractSniHostname(in, readerIndex, readerIndex + handshakeLength));
return;
} else {
if (handshakeBuffer == null) {
handshakeBuffer = ctx.alloc().buffer(handshakeLength);
} else {
// Clear the buffer so we can aggregate into it again.
handshakeBuffer.clear();
} }
// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
offset += 2;
if (extensionsLimit - offset < 3) {
break;
}
final int serverNameType = in.getUnsignedByte(offset);
offset++;
if (serverNameType == 0) {
final int serverNameLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < serverNameLength) {
break;
}
final String hostname =
in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
try {
select(ctx, hostname.toLowerCase(Locale.US));
} catch (Throwable t) {
PlatformDependent.throwException(t);
}
return;
} else {
// invalid enum value
break;
}
}
offset += extensionLength;
} }
} }
// Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
readerIndex += packetLength;
readableBytes -= packetLength;
if (handshakeLength <= handshakeBuffer.readableBytes()) {
select(ctx, extractSniHostname(handshakeBuffer, 0, handshakeLength));
return;
}
} }
} break;
break; default:
default: // not tls, ssl or application data, do not try sni
//not tls, ssl or application data, do not try sni select(ctx, null);
break; return;
}
} }
} catch (NotSslRecordException e) { } catch (NotSslRecordException e) {
// Just rethrow as in this case we also closed the channel and this is consistent with SslHandler. // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
@ -197,13 +168,105 @@ public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
} }
select(ctx, null);
} }
// Just select the default SslContext }
select(ctx, null); }
private static String extractSniHostname(ByteBuf in, int offset, int endOffset) {
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
//
// struct {
// ProtocolVersion client_version;
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//
// We have to skip bytes until SessionID (which sum to 34 bytes in this case).
offset += 34;
if (endOffset - offset >= 6) {
final int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;
final int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;
final int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;
final int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
final int extensionsLimit = offset + extensionsLength;
// Extensions should never exceed the record boundary.
if (extensionsLimit <= endOffset) {
while (extensionsLimit - offset >= 4) {
final int extensionType = in.getUnsignedShort(offset);
offset += 2;
final int extensionLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < extensionLength) {
break;
}
// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
offset += 2;
if (extensionsLimit - offset < 3) {
break;
}
final int serverNameType = in.getUnsignedByte(offset);
offset++;
if (serverNameType == 0) {
final int serverNameLength = in.getUnsignedShort(offset);
offset += 2;
if (extensionsLimit - offset < serverNameLength) {
break;
}
final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
return hostname.toLowerCase(Locale.US);
} else {
// invalid enum value
break;
}
}
offset += extensionLength;
}
}
}
return null;
}
private void releaseHandshakeBuffer() {
if (handshakeBuffer != null) {
handshakeBuffer.release();
handshakeBuffer = null;
} }
} }
private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception { private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception {
releaseHandshakeBuffer();
Future<T> future = lookup(ctx, hostname); Future<T> future = lookup(ctx, hostname);
if (future.isDone()) { if (future.isDone()) {
fireSniCompletionEvent(ctx, hostname, future); fireSniCompletionEvent(ctx, hostname, future);
@ -211,18 +274,16 @@ public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder {
} else { } else {
suppressRead = true; suppressRead = true;
future.addListener((FutureListener<T>) future1 -> { future.addListener((FutureListener<T>) future1 -> {
suppressRead = false;
try { try {
suppressRead = false; fireSniCompletionEvent(ctx, hostname, future1);
try { onLookupComplete(ctx, hostname, future1);
fireSniCompletionEvent(ctx, hostname, future1); } catch (DecoderException err) {
onLookupComplete(ctx, hostname, future1); ctx.fireExceptionCaught(err);
} catch (DecoderException err) { } catch (Exception cause) {
ctx.fireExceptionCaught(err); ctx.fireExceptionCaught(new DecoderException(cause));
} catch (Exception cause) { } catch (Throwable cause) {
ctx.fireExceptionCaught(new DecoderException(cause)); ctx.fireExceptionCaught(cause);
} catch (Throwable cause) {
ctx.fireExceptionCaught(cause);
}
} finally { } finally {
if (readPending) { if (readPending) {
readPending = false; readPending = false;
@ -233,6 +294,13 @@ public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder {
} }
} }
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
releaseHandshakeBuffer();
super.handlerRemoved0(ctx);
}
private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<T> future) { private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<T> future) {
Throwable cause = future.cause(); Throwable cause = future.cause();
if (cause == null) { if (cause == null) {

View File

@ -38,7 +38,9 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import io.netty.util.concurrent.Future;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
@ -76,6 +78,7 @@ import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ResourcesUtil; import io.netty.util.internal.ResourcesUtil;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import org.mockito.Mockito;
@RunWith(Parameterized.class) @RunWith(Parameterized.class)
public class SniHandlerTest { public class SniHandlerTest {
@ -317,7 +320,7 @@ public class SniHandlerTest {
try { try {
// Push the handshake message. // Push the handshake message.
ch.writeInbound(Unpooled.wrappedBuffer(message)); ch.writeInbound(Unpooled.wrappedBuffer(message));
// TODO(scott): This should fail becasue the engine should reject zero length records during handshake. // TODO(scott): This should fail because the engine should reject zero length records during handshake.
// See https://github.com/netty/netty/issues/6348. // See https://github.com/netty/netty/issues/6348.
// fail(); // fail();
} catch (Exception e) { } catch (Exception e) {
@ -575,4 +578,79 @@ public class SniHandlerTest {
ReferenceCountUtil.release(ctx); ReferenceCountUtil.release(ctx);
} }
} }
@Test
public void testNonFragmented() throws Exception {
testWithFragmentSize(Integer.MAX_VALUE);
}
@Test
public void testFragmented() throws Exception {
testWithFragmentSize(50);
}
private void testWithFragmentSize(final int maxFragmentSize) throws Exception {
final String sni = "netty.io";
SelfSignedCertificate cert = new SelfSignedCertificate();
final SslContext context = SslContextBuilder.forServer(cert.key(), cert.cert())
.sslProvider(provider)
.build();
try {
@SuppressWarnings("unchecked") final EmbeddedChannel server = new EmbeddedChannel(
new SniHandler(Mockito.mock(DomainNameMapping.class)) {
@Override
protected Future<SslContext> lookup(final ChannelHandlerContext ctx, final String hostname) {
assertEquals(sni, hostname);
return ctx.executor().newSucceededFuture(context);
}
});
final List<ByteBuf> buffers = clientHelloInMultipleFragments(provider, sni, maxFragmentSize);
for (ByteBuf buffer : buffers) {
server.writeInbound(buffer);
}
assertTrue(server.finishAndReleaseAll());
} finally {
releaseAll(context);
cert.delete();
}
}
private static List<ByteBuf> clientHelloInMultipleFragments(
SslProvider provider, String hostname, int maxTlsPlaintextSize) throws SSLException {
final EmbeddedChannel client = new EmbeddedChannel();
final SslContext ctx = SslContextBuilder.forClient()
.sslProvider(provider)
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.build();
try {
final SslHandler sslHandler = ctx.newHandler(client.alloc(), hostname, -1);
client.pipeline().addLast(sslHandler);
final ByteBuf clientHello = client.readOutbound();
List<ByteBuf> buffers = split(clientHello, maxTlsPlaintextSize);
assertTrue(client.finishAndReleaseAll());
return buffers;
} finally {
releaseAll(ctx);
}
}
private static List<ByteBuf> split(ByteBuf clientHello, int maxSize) {
final int type = clientHello.readUnsignedByte();
final int version = clientHello.readUnsignedShort();
final int length = clientHello.readUnsignedShort();
assertEquals(length, clientHello.readableBytes());
final List<ByteBuf> result = new ArrayList<ByteBuf>();
while (clientHello.readableBytes() > 0) {
final int toRead = Math.min(maxSize, clientHello.readableBytes());
final ByteBuf bb = clientHello.alloc().buffer(SslUtils.SSL_RECORD_HEADER_LENGTH + toRead);
bb.writeByte(type);
bb.writeShort(version);
bb.writeShort(toRead);
bb.writeBytes(clientHello, toRead);
result.add(bb);
}
clientHello.release();
return result;
}
} }