Correctly handle non handshake commands when using SniHandler

Motivation:

As we can only handle handshake commands to parse SNI we should try to skip alert and change cipher spec commands a few times before we fallback to use a default SslContext.

Modifications:

- Use default SslContext if no application data command was received
- Use default SslContext if after 4 commands we not received a handshake command
- Simplify code
- Eliminate multiple volatile fields
- Rename SslConstants to SslUtils
- Share code between SslHandler and SniHandler by moving stuff to SslUtils

Result:

Correct handling of non handshake commands and cleaner code.
This commit is contained in:
Norman Maurer 2016-01-08 11:47:51 +01:00
parent 951bacb0ca
commit 27118ec53a
4 changed files with 273 additions and 224 deletions

View File

@ -37,14 +37,19 @@ import java.util.Locale;
*/ */
public class SniHandler extends ByteToMessageDecoder { public class SniHandler extends ByteToMessageDecoder {
// Maximal number of ssl records to inspect before fallback to the default SslContext.
private static final int MAX_SSL_RECORDS = 4;
private static final InternalLogger logger = private static final InternalLogger logger =
InternalLoggerFactory.getInstance(SniHandler.class); InternalLoggerFactory.getInstance(SniHandler.class);
private static final Selection EMPTY_SELECTION = new Selection(null, null);
private final DomainNameMapping<SslContext> mapping; private final DomainNameMapping<SslContext> mapping;
private boolean handshaken; private boolean handshakeFailed;
private volatile String hostname;
private volatile SslContext selectedContext; private volatile Selection selection = EMPTY_SELECTION;
/** /**
* Create a SNI detection handler with configured {@link SslContext} * Create a SNI detection handler with configured {@link SslContext}
@ -59,127 +64,159 @@ public class SniHandler extends ByteToMessageDecoder {
} }
this.mapping = (DomainNameMapping<SslContext>) mapping; this.mapping = (DomainNameMapping<SslContext>) mapping;
handshaken = false;
} }
/** /**
* @return the selected hostname * @return the selected hostname
*/ */
public String hostname() { public String hostname() {
return hostname; return selection.hostname;
} }
/** /**
* @return the selected sslcontext * @return the selected sslcontext
*/ */
public SslContext sslContext() { public SslContext sslContext() {
return selectedContext; return selection.context;
} }
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!handshaken && in.readableBytes() >= 5) { if (!handshakeFailed && in.readableBytes() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
String hostname = sniHostNameFromHandshakeInfo(in); int writerIndex = in.writerIndex();
if (hostname != null) { int readerIndex = in.readerIndex();
hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US); try {
loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) {
int command = in.getUnsignedByte(readerIndex);
// tls, but not handshake command
switch (command) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslUtils.SSL_CONTENT_TYPE_ALERT:
int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
// Not an SSL/TLS packet
if (len == -1) {
handshakeFailed = true;
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(in.readableBytes());
ctx.fireExceptionCaught(e);
SslUtils.notifyHandshakeFailure(ctx, e);
return;
}
if (writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
// Not enough data
return;
}
// increase readerIndex and try again.
readerIndex += len;
continue;
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
int majorVersion = in.getUnsignedByte(readerIndex + 1);
// SSLv3 or TLS
if (majorVersion == 3) {
int packetLength = in.getUnsignedShort(readerIndex + 3)
+ SslUtils.SSL_RECORD_HEADER_LENGTH;
if (in.readableBytes() < packetLength) {
// client hello incomplete try again to decode once more data is ready.
return;
}
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
// We have to skip bytes until SessionID (which sum to 43 bytes).
//
// struct {
// ProtocolVersion client_version;
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//
int offset = readerIndex + 43;
int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;
int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;
int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;
int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;
while (offset < extensionsLimit) {
int extensionType = in.getUnsignedShort(offset);
offset += 2;
int extensionLength = in.getUnsignedShort(offset);
offset += 2;
// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0) {
int serverNameType = in.getUnsignedByte(offset + 2);
if (serverNameType == 0) {
int serverNameLength = in.getUnsignedShort(offset + 3);
String hostname = in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
select(ctx, IDN.toASCII(hostname,
IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
return;
} else {
// invalid enum value
break loop;
}
}
offset += extensionLength;
}
}
// Fall-through
default:
//not tls, ssl or application data, do not try sni
break loop;
}
}
} catch (Throwable e) {
// unexpected encoding, ignore sni and use default
if (logger.isDebugEnabled()) {
logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
}
} }
this.hostname = hostname; // Just select the default SslContext
select(ctx, null);
// the mapping will return default context when this.hostname is null
selectedContext = mapping.map(hostname);
}
if (handshaken) {
SslHandler sslHandler = selectedContext.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
} }
} }
private String sniHostNameFromHandshakeInfo(ByteBuf in) { private void select(ChannelHandlerContext ctx, String hostname) {
int readerIndex = in.readerIndex(); SslContext selectedContext = mapping.map(hostname);
try { selection = new Selection(selectedContext, hostname);
int command = in.getUnsignedByte(readerIndex); SslHandler sslHandler = selectedContext.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
}
// tls, but not handshake command private static final class Selection {
switch (command) { final SslContext context;
case SslConstants.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: final String hostname;
case SslConstants.SSL_CONTENT_TYPE_ALERT:
case SslConstants.SSL_CONTENT_TYPE_APPLICATION_DATA:
return null;
case SslConstants.SSL_CONTENT_TYPE_HANDSHAKE:
break;
default:
//not tls or sslv3, do not try sni
handshaken = true;
return null;
}
int majorVersion = in.getUnsignedByte(readerIndex + 1); Selection(SslContext context, String hostname) {
this.context = context;
// SSLv3 or TLS this.hostname = hostname;
if (majorVersion == 3) {
int packetLength = in.getUnsignedShort(readerIndex + 3) + 5;
if (in.readableBytes() >= packetLength) {
// decode the ssl client hello packet
// we have to skip some var-length fields
int offset = readerIndex + 43;
int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;
int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;
int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;
int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;
while (offset < extensionsLimit) {
int extensionType = in.getUnsignedShort(offset);
offset += 2;
int extensionLength = in.getUnsignedShort(offset);
offset += 2;
// SNI
if (extensionType == 0) {
handshaken = true;
int serverNameType = in.getUnsignedByte(offset + 2);
if (serverNameType == 0) {
int serverNameLength = in.getUnsignedShort(offset + 3);
return in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
} else {
// invalid enum value
return null;
}
}
offset += extensionLength;
}
handshaken = true;
return null;
} else {
// client hello incomplete
return null;
}
} else {
handshaken = true;
return null;
}
} catch (Throwable e) {
// unexpected encoding, ignore sni and use default
if (logger.isDebugEnabled()) {
logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
}
handshaken = true;
return null;
} }
} }
} }

View File

@ -1,45 +0,0 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
/**
* Constants for SSL packets.
*/
final class SslConstants {
/**
* change cipher spec
*/
public static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
/**
* alert
*/
public static final int SSL_CONTENT_TYPE_ALERT = 21;
/**
* handshake
*/
public static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
/**
* application data
*/
public static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
private SslConstants() {
}
}

View File

@ -66,6 +66,8 @@ import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static io.netty.handler.ssl.SslUtils.getEncryptedPacketLength;
/** /**
* Adds <a href="http://en.wikipedia.org/wiki/Transport_Layer_Security">SSL * Adds <a href="http://en.wikipedia.org/wiki/Transport_Layer_Security">SSL
* &middot; TLS</a> and StartTLS support to a {@link Channel}. Please refer * &middot; TLS</a> and StartTLS support to a {@link Channel}. Please refer
@ -813,84 +815,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read.
*/ */
public static boolean isEncrypted(ByteBuf buffer) { public static boolean isEncrypted(ByteBuf buffer) {
if (buffer.readableBytes() < 5) { if (buffer.readableBytes() < SslUtils.SSL_RECORD_HEADER_LENGTH) {
throw new IllegalArgumentException("buffer must have at least 5 readable bytes"); throw new IllegalArgumentException(
"buffer must have at least " + SslUtils.SSL_RECORD_HEADER_LENGTH + " readable bytes");
} }
return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1; return getEncryptedPacketLength(buffer, buffer.readerIndex()) != -1;
} }
/**
* Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
* the readerIndex of the given {@link ByteBuf}.
*
* @param buffer
* The {@link ByteBuf} to read from. Be aware that it must have at least 5 bytes to read,
* otherwise it will throw an {@link IllegalArgumentException}.
* @return length
* The length of the encrypted packet that is included in the buffer. This will
* return {@code -1} if the given {@link ByteBuf} is not encrypted at all.
* @throws IllegalArgumentException
* Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read.
*/
private static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
int packetLength = 0;
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (buffer.getUnsignedByte(offset)) {
case 20: // change_cipher_spec
case 21: // alert
case 22: // handshake
case 23: // application_data
tls = true;
break;
default:
// SSLv2 or bad data
tls = false;
}
if (tls) {
// SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.getUnsignedByte(offset + 1);
if (majorVersion == 3) {
// SSLv3 or TLS
packetLength = buffer.getUnsignedShort(offset + 3) + 5;
if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
} else {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
}
if (!tls) {
// SSLv2 or bad data - Check the version
boolean sslv2 = true;
int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
if (headerLength == 2) {
packetLength = (buffer.getShort(offset) & 0x7FFF) + 2;
} else {
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
sslv2 = false;
}
} else {
sslv2 = false;
}
if (!sslv2) {
return -1;
}
}
return packetLength;
}
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
final int startOffset = in.readerIndex(); final int startOffset = in.readerIndex();
@ -913,7 +844,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { while (totalLength < OpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) {
final int readableBytes = endOffset - offset; final int readableBytes = endOffset - offset;
if (readableBytes < 5) { if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) {
break; break;
} }
@ -1299,8 +1230,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private void notifyHandshakeFailure(Throwable cause) { private void notifyHandshakeFailure(Throwable cause) {
if (handshakePromise.tryFailure(cause)) { if (handshakePromise.tryFailure(cause)) {
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); SslUtils.notifyHandshakeFailure(ctx, cause);
ctx.close();
} }
} }

View File

@ -0,0 +1,127 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
/**
* Constants for SSL packets.
*/
final class SslUtils {
/**
* change cipher spec
*/
public static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
/**
* alert
*/
public static final int SSL_CONTENT_TYPE_ALERT = 21;
/**
* handshake
*/
public static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
/**
* application data
*/
public static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
/**
* the length of the ssl record header (in bytes)
*/
public static final int SSL_RECORD_HEADER_LENGTH = 5;
/**
* Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
* the readerIndex of the given {@link ByteBuf}.
*
* @param buffer
* The {@link ByteBuf} to read from. Be aware that it must have at least
* {@link #SSL_RECORD_HEADER_LENGTH} bytes to read,
* otherwise it will throw an {@link IllegalArgumentException}.
* @return length
* The length of the encrypted packet that is included in the buffer. This will
* return {@code -1} if the given {@link ByteBuf} is not encrypted at all.
* @throws IllegalArgumentException
* Is thrown if the given {@link ByteBuf} has not at least {@link #SSL_RECORD_HEADER_LENGTH}
* bytes to read.
*/
static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
int packetLength = 0;
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (buffer.getUnsignedByte(offset)) {
case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SSL_CONTENT_TYPE_ALERT:
case SSL_CONTENT_TYPE_HANDSHAKE:
case SSL_CONTENT_TYPE_APPLICATION_DATA:
tls = true;
break;
default:
// SSLv2 or bad data
tls = false;
}
if (tls) {
// SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.getUnsignedByte(offset + 1);
if (majorVersion == 3) {
// SSLv3 or TLS
packetLength = buffer.getUnsignedShort(offset + 3) + SSL_RECORD_HEADER_LENGTH;
if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
} else {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
}
if (!tls) {
// SSLv2 or bad data - Check the version
int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
if (headerLength == 2) {
packetLength = (buffer.getShort(offset) & 0x7FFF) + 2;
} else {
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
return -1;
}
} else {
return -1;
}
}
return packetLength;
}
static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) {
ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
ctx.close();
}
private SslUtils() {
}
}