From 8304069a30ebc0b2f464895ae5e1098119232ec3 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 5 Feb 2020 14:41:51 +0100 Subject: [PATCH] =?UTF-8?q?Add=20SslClientHelloHandler=20which=20allows=20?= =?UTF-8?q?to=20do=20something=20based=20on=20the=20S=E2=80=A6=20(#9827)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Motivation: Sometimes it is useful to do something depending on the Ssl ClientHello (like for example select a SslContext to use). At the moment we only allow to hook into the SNI extension but this is not enough. Modifications: Add SslClientHelloHandler which allows to hook into ClientHello messages. This class is now also the super class of AbstractSniHandler Result: More flexible processing of SSL handshakes --- .../netty/handler/ssl/AbstractSniHandler.java | 254 ++--------------- .../handler/ssl/SslClientHelloHandler.java | 264 ++++++++++++++++++ pom.xml | 2 +- 3 files changed, 283 insertions(+), 237 deletions(-) create mode 100644 handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java diff --git a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java index e9325144a6..46f3a35e51 100644 --- a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java @@ -16,18 +16,11 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.codec.DecoderException; + import io.netty.util.CharsetUtil; import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.FutureListener; -import io.netty.util.internal.logging.InternalLogger; -import io.netty.util.internal.logging.InternalLoggerFactory; -import java.net.SocketAddress; import java.util.Locale; /** @@ -37,143 +30,9 @@ import java.util.Locale; * The client will send host name in the handshake data so server could decide * which certificate to choose for the host name.

*/ -public abstract class AbstractSniHandler extends ByteToMessageDecoder { +public abstract class AbstractSniHandler extends SslClientHelloHandler { - private static final InternalLogger logger = - InternalLoggerFactory.getInstance(AbstractSniHandler.class); - - private boolean handshakeFailed; - private boolean suppressRead; - private boolean readPending; - private ByteBuf handshakeBuffer; - - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { - if (!suppressRead && !handshakeFailed) { - try { - int readerIndex = in.readerIndex(); - int readableBytes = in.readableBytes(); - int handshakeLength = -1; - - // Check if we have enough data to determine the record type and length. - while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) { - final int contentType = in.getUnsignedByte(readerIndex); - switch (contentType) { - case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: - // fall-through - case SslUtils.SSL_CONTENT_TYPE_ALERT: - final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); - - // Not an SSL/TLS packet - if (len == SslUtils.NOT_ENCRYPTED) { - handshakeFailed = true; - NotSslRecordException e = new NotSslRecordException( - "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); - in.skipBytes(in.readableBytes()); - ctx.fireUserEventTriggered(new SniCompletionEvent(e)); - SslUtils.handleHandshakeFailure(ctx, e, true); - throw e; - } - if (len == SslUtils.NOT_ENOUGH_DATA) { - // Not enough data - return; - } - // SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and - // assume no SNI is present. Let's let the actual TLS implementation sort this out. - // Just select the default SslContext - select(ctx, null); - return; - case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: - final int majorVersion = in.getUnsignedByte(readerIndex + 1); - // SSLv3 or TLS - if (majorVersion == 3) { - int packetLength = in.getUnsignedShort(readerIndex + 3) + - SslUtils.SSL_RECORD_HEADER_LENGTH; - - if (readableBytes < packetLength) { - // client hello incomplete; try again to decode once more data is ready. - return; - } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) { - select(ctx, null); - return; - } - - final int endOffset = readerIndex + packetLength; - - // Let's check if we already parsed the handshake length or not. - if (handshakeLength == -1) { - if (readerIndex + 4 > endOffset) { - // Need more data to read HandshakeType and handshakeLength (4 bytes) - return; - } - - final int handshakeType = in.getUnsignedByte(readerIndex + - SslUtils.SSL_RECORD_HEADER_LENGTH); - - // Check if this is a clientHello(1) - // See https://tools.ietf.org/html/rfc5246#section-7.4 - if (handshakeType != 1) { - select(ctx, null); - return; - } - - // Read the length of the handshake as it may arrive in fragments - // See https://tools.ietf.org/html/rfc5246#section-7.4 - handshakeLength = in.getUnsignedMedium(readerIndex + - SslUtils.SSL_RECORD_HEADER_LENGTH + 1); - - // Consume handshakeType and handshakeLength (this sums up as 4 bytes) - readerIndex += 4; - packetLength -= 4; - - if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) { - // We have everything we need in one packet. - // Skip the record header - readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH; - select(ctx, extractSniHostname(in, readerIndex, readerIndex + handshakeLength)); - return; - } else { - if (handshakeBuffer == null) { - handshakeBuffer = ctx.alloc().buffer(handshakeLength); - } else { - // Clear the buffer so we can aggregate into it again. - handshakeBuffer.clear(); - } - } - } - - // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER - handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH, - packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH); - readerIndex += packetLength; - readableBytes -= packetLength; - if (handshakeLength <= handshakeBuffer.readableBytes()) { - select(ctx, extractSniHostname(handshakeBuffer, 0, handshakeLength)); - return; - } - break; - } - // fall-through - default: - // not tls, ssl or application data, do not try sni - select(ctx, null); - return; - } - } - } catch (NotSslRecordException e) { - // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler. - throw e; - } catch (Exception e) { - // unexpected encoding, ignore sni and use default - if (logger.isDebugEnabled()) { - logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); - } - select(ctx, null); - } - } - } - - private static String extractSniHostname(ByteBuf in, int offset, int endOffset) { + private static String extractSniHostname(ByteBuf in) { // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 // // Decode the ssl client hello packet. @@ -194,6 +53,8 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { // // We have to skip bytes until SessionID (which sum to 34 bytes in this case). + int offset = in.readerIndex(); + int endOffset = in.writerIndex(); offset += 34; if (endOffset - offset >= 6) { @@ -257,57 +118,19 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { return null; } - private void releaseHandshakeBuffer() { - if (handshakeBuffer != null) { - handshakeBuffer.release(); - handshakeBuffer = null; - } - } + private String hostname; - private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception { - releaseHandshakeBuffer(); + @Override + protected Future lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception { + hostname = clientHello == null ? null : extractSniHostname(clientHello); - Future future = lookup(ctx, hostname); - if (future.isDone()) { - fireSniCompletionEvent(ctx, hostname, future); - onLookupComplete(ctx, hostname, future); - } else { - suppressRead = true; - future.addListener((FutureListener) future1 -> { - suppressRead = false; - try { - fireSniCompletionEvent(ctx, hostname, future1); - onLookupComplete(ctx, hostname, future1); - } catch (DecoderException err) { - ctx.fireExceptionCaught(err); - } catch (Exception cause) { - ctx.fireExceptionCaught(new DecoderException(cause)); - } catch (Throwable cause) { - ctx.fireExceptionCaught(cause); - } finally { - if (readPending) { - readPending = false; - ctx.read(); - } - } - }); - } + return lookup(ctx, hostname); } @Override - protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - releaseHandshakeBuffer(); - - super.handlerRemoved0(ctx); - } - - private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future future) { - Throwable cause = future.cause(); - if (cause == null) { - ctx.fireUserEventTriggered(new SniCompletionEvent(hostname)); - } else { - ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause)); - } + protected void onLookupComplete(ChannelHandlerContext ctx, Future future) throws Exception { + fireSniCompletionEvent(ctx, hostname, future); + onLookupComplete(ctx, hostname, future); } /** @@ -326,53 +149,12 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder { protected abstract void onLookupComplete(ChannelHandlerContext ctx, String hostname, Future future) throws Exception; - @Override - public void read(ChannelHandlerContext ctx) throws Exception { - if (suppressRead) { - readPending = true; + private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future future) { + Throwable cause = future.cause(); + if (cause == null) { + ctx.fireUserEventTriggered(new SniCompletionEvent(hostname)); } else { - ctx.read(); + ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause)); } } - - @Override - public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { - ctx.bind(localAddress, promise); - } - - @Override - public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, - ChannelPromise promise) throws Exception { - ctx.connect(remoteAddress, localAddress, promise); - } - - @Override - public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - ctx.disconnect(promise); - } - - @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - ctx.close(promise); - } - - @Override - public void register(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - ctx.register(promise); - } - - @Override - public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - ctx.deregister(promise); - } - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - ctx.write(msg, promise); - } - - @Override - public void flush(ChannelHandlerContext ctx) throws Exception { - ctx.flush(); - } } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java new file mode 100644 index 0000000000..7ff65edb05 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java @@ -0,0 +1,264 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received. + */ +public abstract class SslClientHelloHandler extends ByteToMessageDecoder { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(SslClientHelloHandler.class); + + private boolean handshakeFailed; + private boolean suppressRead; + private boolean readPending; + private ByteBuf handshakeBuffer; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { + if (!suppressRead && !handshakeFailed) { + try { + int readerIndex = in.readerIndex(); + int readableBytes = in.readableBytes(); + int handshakeLength = -1; + + // Check if we have enough data to determine the record type and length. + while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) { + final int contentType = in.getUnsignedByte(readerIndex); + switch (contentType) { + case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: + // fall-through + case SslUtils.SSL_CONTENT_TYPE_ALERT: + final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); + + // Not an SSL/TLS packet + if (len == SslUtils.NOT_ENCRYPTED) { + handshakeFailed = true; + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(in.readableBytes()); + ctx.fireUserEventTriggered(new SniCompletionEvent(e)); + SslUtils.handleHandshakeFailure(ctx, e, true); + throw e; + } + if (len == SslUtils.NOT_ENOUGH_DATA) { + // Not enough data + return; + } + // No ClientHello + select(ctx, null); + return; + case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: + final int majorVersion = in.getUnsignedByte(readerIndex + 1); + // SSLv3 or TLS + if (majorVersion == 3) { + int packetLength = in.getUnsignedShort(readerIndex + 3) + + SslUtils.SSL_RECORD_HEADER_LENGTH; + + if (readableBytes < packetLength) { + // client hello incomplete; try again to decode once more data is ready. + return; + } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) { + select(ctx, null); + return; + } + + final int endOffset = readerIndex + packetLength; + + // Let's check if we already parsed the handshake length or not. + if (handshakeLength == -1) { + if (readerIndex + 4 > endOffset) { + // Need more data to read HandshakeType and handshakeLength (4 bytes) + return; + } + + final int handshakeType = in.getUnsignedByte(readerIndex + + SslUtils.SSL_RECORD_HEADER_LENGTH); + + // Check if this is a clientHello(1) + // See https://tools.ietf.org/html/rfc5246#section-7.4 + if (handshakeType != 1) { + select(ctx, null); + return; + } + + // Read the length of the handshake as it may arrive in fragments + // See https://tools.ietf.org/html/rfc5246#section-7.4 + handshakeLength = in.getUnsignedMedium(readerIndex + + SslUtils.SSL_RECORD_HEADER_LENGTH + 1); + + // Consume handshakeType and handshakeLength (this sums up as 4 bytes) + readerIndex += 4; + packetLength -= 4; + + if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) { + // We have everything we need in one packet. + // Skip the record header + readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH; + select(ctx, in.retainedSlice(readerIndex, handshakeLength)); + return; + } else { + if (handshakeBuffer == null) { + handshakeBuffer = ctx.alloc().buffer(handshakeLength); + } else { + // Clear the buffer so we can aggregate into it again. + handshakeBuffer.clear(); + } + } + } + + // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER + handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH, + packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH); + readerIndex += packetLength; + readableBytes -= packetLength; + if (handshakeLength <= handshakeBuffer.readableBytes()) { + ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength); + handshakeBuffer = null; + + select(ctx, clientHello); + return; + } + break; + } + // fall-through + default: + // not tls, ssl or application data + select(ctx, null); + return; + } + } + } catch (NotSslRecordException e) { + // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler. + throw e; + } catch (Exception e) { + // unexpected encoding, ignore sni and use default + if (logger.isDebugEnabled()) { + logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); + } + select(ctx, null); + } + } + } + + private void releaseHandshakeBuffer() { + releaseIfNotNull(handshakeBuffer); + handshakeBuffer = null; + } + + private static void releaseIfNotNull(ByteBuf buffer) { + if (buffer != null) { + buffer.release(); + } + } + + private void select(final ChannelHandlerContext ctx, final ByteBuf clientHello) { + final Future future; + try { + future = lookup(ctx, clientHello); + if (future.isDone()) { + releaseIfNotNull(clientHello); + onLookupComplete(ctx, future); + } else { + suppressRead = true; + future.addListener((FutureListener) f -> { + releaseIfNotNull(clientHello); + try { + suppressRead = false; + try { + onLookupComplete(ctx, f); + } catch (DecoderException err) { + ctx.fireExceptionCaught(err); + } catch (Exception cause) { + ctx.fireExceptionCaught(new DecoderException(cause)); + } catch (Throwable cause) { + ctx.fireExceptionCaught(cause); + } + } finally { + if (readPending) { + readPending = false; + ctx.read(); + } + } + }); + } + } catch (Throwable cause) { + releaseIfNotNull(clientHello); + PlatformDependent.throwException(cause); + } + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + releaseHandshakeBuffer(); + + super.handlerRemoved0(ctx); + } + + /** + * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will + * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion. + * + * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 + * + *
+     * struct {
+     *    ProtocolVersion client_version;
+     *    Random random;
+     *    SessionID session_id;
+     *    CipherSuite cipher_suites<2..2^16-2>;
+     *    CompressionMethod compression_methods<1..2^8-1>;
+     *    select (extensions_present) {
+     *        case false:
+     *            struct {};
+     *        case true:
+     *            Extension extensions<0..2^16-1>;
+     *    };
+     * } ClientHello;
+     * 
+ * + * @see #onLookupComplete(ChannelHandlerContext, Future) + */ + protected abstract Future lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception; + + /** + * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}. + * + * @see #lookup(ChannelHandlerContext, ByteBuf) + */ + protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future future) throws Exception; + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + if (suppressRead) { + readPending = true; + } else { + ctx.read(); + } + } +} diff --git a/pom.xml b/pom.xml index 33c65324c0..1869aac793 100644 --- a/pom.xml +++ b/pom.xml @@ -710,7 +710,7 @@ com.github.siom79.japicmp japicmp-maven-plugin - 0.13.1 + 0.14.3 true