From 8304069a30ebc0b2f464895ae5e1098119232ec3 Mon Sep 17 00:00:00 2001
From: Norman Maurer
Date: Wed, 5 Feb 2020 14:41:51 +0100
Subject: [PATCH] =?UTF-8?q?Add=20SslClientHelloHandler=20which=20allows=20?=
=?UTF-8?q?to=20do=20something=20based=20on=20the=20S=E2=80=A6=20(#9827)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Motivation:
Sometimes it is useful to do something depending on the Ssl ClientHello (like for example select a SslContext to use). At the moment we only allow to hook into the SNI extension but this is not enough.
Modifications:
Add SslClientHelloHandler which allows to hook into ClientHello messages. This class is now also the super class of AbstractSniHandler
Result:
More flexible processing of SSL handshakes
---
.../netty/handler/ssl/AbstractSniHandler.java | 254 ++---------------
.../handler/ssl/SslClientHelloHandler.java | 264 ++++++++++++++++++
pom.xml | 2 +-
3 files changed, 283 insertions(+), 237 deletions(-)
create mode 100644 handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java
diff --git a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java
index e9325144a6..46f3a35e51 100644
--- a/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java
+++ b/handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java
@@ -16,18 +16,11 @@
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.ByteToMessageDecoder;
-import io.netty.handler.codec.DecoderException;
+
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.Future;
-import io.netty.util.concurrent.FutureListener;
-import io.netty.util.internal.logging.InternalLogger;
-import io.netty.util.internal.logging.InternalLoggerFactory;
-import java.net.SocketAddress;
import java.util.Locale;
/**
@@ -37,143 +30,9 @@ import java.util.Locale;
* The client will send host name in the handshake data so server could decide
* which certificate to choose for the host name.
*/
-public abstract class AbstractSniHandler extends ByteToMessageDecoder {
+public abstract class AbstractSniHandler extends SslClientHelloHandler {
- private static final InternalLogger logger =
- InternalLoggerFactory.getInstance(AbstractSniHandler.class);
-
- private boolean handshakeFailed;
- private boolean suppressRead;
- private boolean readPending;
- private ByteBuf handshakeBuffer;
-
- @Override
- protected void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
- if (!suppressRead && !handshakeFailed) {
- try {
- int readerIndex = in.readerIndex();
- int readableBytes = in.readableBytes();
- int handshakeLength = -1;
-
- // Check if we have enough data to determine the record type and length.
- while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
- final int contentType = in.getUnsignedByte(readerIndex);
- switch (contentType) {
- case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
- // fall-through
- case SslUtils.SSL_CONTENT_TYPE_ALERT:
- final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
-
- // Not an SSL/TLS packet
- if (len == SslUtils.NOT_ENCRYPTED) {
- handshakeFailed = true;
- NotSslRecordException e = new NotSslRecordException(
- "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
- in.skipBytes(in.readableBytes());
- ctx.fireUserEventTriggered(new SniCompletionEvent(e));
- SslUtils.handleHandshakeFailure(ctx, e, true);
- throw e;
- }
- if (len == SslUtils.NOT_ENOUGH_DATA) {
- // Not enough data
- return;
- }
- // SNI can't be present in an ALERT or CHANGE_CIPHER_SPEC record, so we'll fall back and
- // assume no SNI is present. Let's let the actual TLS implementation sort this out.
- // Just select the default SslContext
- select(ctx, null);
- return;
- case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
- final int majorVersion = in.getUnsignedByte(readerIndex + 1);
- // SSLv3 or TLS
- if (majorVersion == 3) {
- int packetLength = in.getUnsignedShort(readerIndex + 3) +
- SslUtils.SSL_RECORD_HEADER_LENGTH;
-
- if (readableBytes < packetLength) {
- // client hello incomplete; try again to decode once more data is ready.
- return;
- } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
- select(ctx, null);
- return;
- }
-
- final int endOffset = readerIndex + packetLength;
-
- // Let's check if we already parsed the handshake length or not.
- if (handshakeLength == -1) {
- if (readerIndex + 4 > endOffset) {
- // Need more data to read HandshakeType and handshakeLength (4 bytes)
- return;
- }
-
- final int handshakeType = in.getUnsignedByte(readerIndex +
- SslUtils.SSL_RECORD_HEADER_LENGTH);
-
- // Check if this is a clientHello(1)
- // See https://tools.ietf.org/html/rfc5246#section-7.4
- if (handshakeType != 1) {
- select(ctx, null);
- return;
- }
-
- // Read the length of the handshake as it may arrive in fragments
- // See https://tools.ietf.org/html/rfc5246#section-7.4
- handshakeLength = in.getUnsignedMedium(readerIndex +
- SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
-
- // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
- readerIndex += 4;
- packetLength -= 4;
-
- if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
- // We have everything we need in one packet.
- // Skip the record header
- readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
- select(ctx, extractSniHostname(in, readerIndex, readerIndex + handshakeLength));
- return;
- } else {
- if (handshakeBuffer == null) {
- handshakeBuffer = ctx.alloc().buffer(handshakeLength);
- } else {
- // Clear the buffer so we can aggregate into it again.
- handshakeBuffer.clear();
- }
- }
- }
-
- // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
- handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
- packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
- readerIndex += packetLength;
- readableBytes -= packetLength;
- if (handshakeLength <= handshakeBuffer.readableBytes()) {
- select(ctx, extractSniHostname(handshakeBuffer, 0, handshakeLength));
- return;
- }
- break;
- }
- // fall-through
- default:
- // not tls, ssl or application data, do not try sni
- select(ctx, null);
- return;
- }
- }
- } catch (NotSslRecordException e) {
- // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
- throw e;
- } catch (Exception e) {
- // unexpected encoding, ignore sni and use default
- if (logger.isDebugEnabled()) {
- logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
- }
- select(ctx, null);
- }
- }
- }
-
- private static String extractSniHostname(ByteBuf in, int offset, int endOffset) {
+ private static String extractSniHostname(ByteBuf in) {
// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
@@ -194,6 +53,8 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder {
//
// We have to skip bytes until SessionID (which sum to 34 bytes in this case).
+ int offset = in.readerIndex();
+ int endOffset = in.writerIndex();
offset += 34;
if (endOffset - offset >= 6) {
@@ -257,57 +118,19 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder {
return null;
}
- private void releaseHandshakeBuffer() {
- if (handshakeBuffer != null) {
- handshakeBuffer.release();
- handshakeBuffer = null;
- }
- }
+ private String hostname;
- private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception {
- releaseHandshakeBuffer();
+ @Override
+ protected Future lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
+ hostname = clientHello == null ? null : extractSniHostname(clientHello);
- Future future = lookup(ctx, hostname);
- if (future.isDone()) {
- fireSniCompletionEvent(ctx, hostname, future);
- onLookupComplete(ctx, hostname, future);
- } else {
- suppressRead = true;
- future.addListener((FutureListener) future1 -> {
- suppressRead = false;
- try {
- fireSniCompletionEvent(ctx, hostname, future1);
- onLookupComplete(ctx, hostname, future1);
- } catch (DecoderException err) {
- ctx.fireExceptionCaught(err);
- } catch (Exception cause) {
- ctx.fireExceptionCaught(new DecoderException(cause));
- } catch (Throwable cause) {
- ctx.fireExceptionCaught(cause);
- } finally {
- if (readPending) {
- readPending = false;
- ctx.read();
- }
- }
- });
- }
+ return lookup(ctx, hostname);
}
@Override
- protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
- releaseHandshakeBuffer();
-
- super.handlerRemoved0(ctx);
- }
-
- private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future future) {
- Throwable cause = future.cause();
- if (cause == null) {
- ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
- } else {
- ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
- }
+ protected void onLookupComplete(ChannelHandlerContext ctx, Future future) throws Exception {
+ fireSniCompletionEvent(ctx, hostname, future);
+ onLookupComplete(ctx, hostname, future);
}
/**
@@ -326,53 +149,12 @@ public abstract class AbstractSniHandler extends ByteToMessageDecoder {
protected abstract void onLookupComplete(ChannelHandlerContext ctx,
String hostname, Future future) throws Exception;
- @Override
- public void read(ChannelHandlerContext ctx) throws Exception {
- if (suppressRead) {
- readPending = true;
+ private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future future) {
+ Throwable cause = future.cause();
+ if (cause == null) {
+ ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
} else {
- ctx.read();
+ ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
}
}
-
- @Override
- public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
- ctx.bind(localAddress, promise);
- }
-
- @Override
- public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
- ChannelPromise promise) throws Exception {
- ctx.connect(remoteAddress, localAddress, promise);
- }
-
- @Override
- public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
- ctx.disconnect(promise);
- }
-
- @Override
- public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
- ctx.close(promise);
- }
-
- @Override
- public void register(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
- ctx.register(promise);
- }
-
- @Override
- public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
- ctx.deregister(promise);
- }
-
- @Override
- public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
- ctx.write(msg, promise);
- }
-
- @Override
- public void flush(ChannelHandlerContext ctx) throws Exception {
- ctx.flush();
- }
}
diff --git a/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java
new file mode 100644
index 0000000000..7ff65edb05
--- /dev/null
+++ b/handler/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java
@@ -0,0 +1,264 @@
+/*
+ * Copyright 2017 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.handler.ssl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufUtil;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.DecoderException;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.FutureListener;
+import io.netty.util.internal.PlatformDependent;
+import io.netty.util.internal.logging.InternalLogger;
+import io.netty.util.internal.logging.InternalLoggerFactory;
+
+/**
+ * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
+ */
+public abstract class SslClientHelloHandler extends ByteToMessageDecoder {
+
+ private static final InternalLogger logger =
+ InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
+
+ private boolean handshakeFailed;
+ private boolean suppressRead;
+ private boolean readPending;
+ private ByteBuf handshakeBuffer;
+
+ @Override
+ protected void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
+ if (!suppressRead && !handshakeFailed) {
+ try {
+ int readerIndex = in.readerIndex();
+ int readableBytes = in.readableBytes();
+ int handshakeLength = -1;
+
+ // Check if we have enough data to determine the record type and length.
+ while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
+ final int contentType = in.getUnsignedByte(readerIndex);
+ switch (contentType) {
+ case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
+ // fall-through
+ case SslUtils.SSL_CONTENT_TYPE_ALERT:
+ final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
+
+ // Not an SSL/TLS packet
+ if (len == SslUtils.NOT_ENCRYPTED) {
+ handshakeFailed = true;
+ NotSslRecordException e = new NotSslRecordException(
+ "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
+ in.skipBytes(in.readableBytes());
+ ctx.fireUserEventTriggered(new SniCompletionEvent(e));
+ SslUtils.handleHandshakeFailure(ctx, e, true);
+ throw e;
+ }
+ if (len == SslUtils.NOT_ENOUGH_DATA) {
+ // Not enough data
+ return;
+ }
+ // No ClientHello
+ select(ctx, null);
+ return;
+ case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
+ final int majorVersion = in.getUnsignedByte(readerIndex + 1);
+ // SSLv3 or TLS
+ if (majorVersion == 3) {
+ int packetLength = in.getUnsignedShort(readerIndex + 3) +
+ SslUtils.SSL_RECORD_HEADER_LENGTH;
+
+ if (readableBytes < packetLength) {
+ // client hello incomplete; try again to decode once more data is ready.
+ return;
+ } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
+ select(ctx, null);
+ return;
+ }
+
+ final int endOffset = readerIndex + packetLength;
+
+ // Let's check if we already parsed the handshake length or not.
+ if (handshakeLength == -1) {
+ if (readerIndex + 4 > endOffset) {
+ // Need more data to read HandshakeType and handshakeLength (4 bytes)
+ return;
+ }
+
+ final int handshakeType = in.getUnsignedByte(readerIndex +
+ SslUtils.SSL_RECORD_HEADER_LENGTH);
+
+ // Check if this is a clientHello(1)
+ // See https://tools.ietf.org/html/rfc5246#section-7.4
+ if (handshakeType != 1) {
+ select(ctx, null);
+ return;
+ }
+
+ // Read the length of the handshake as it may arrive in fragments
+ // See https://tools.ietf.org/html/rfc5246#section-7.4
+ handshakeLength = in.getUnsignedMedium(readerIndex +
+ SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
+
+ // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
+ readerIndex += 4;
+ packetLength -= 4;
+
+ if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
+ // We have everything we need in one packet.
+ // Skip the record header
+ readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
+ select(ctx, in.retainedSlice(readerIndex, handshakeLength));
+ return;
+ } else {
+ if (handshakeBuffer == null) {
+ handshakeBuffer = ctx.alloc().buffer(handshakeLength);
+ } else {
+ // Clear the buffer so we can aggregate into it again.
+ handshakeBuffer.clear();
+ }
+ }
+ }
+
+ // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
+ handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
+ packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
+ readerIndex += packetLength;
+ readableBytes -= packetLength;
+ if (handshakeLength <= handshakeBuffer.readableBytes()) {
+ ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
+ handshakeBuffer = null;
+
+ select(ctx, clientHello);
+ return;
+ }
+ break;
+ }
+ // fall-through
+ default:
+ // not tls, ssl or application data
+ select(ctx, null);
+ return;
+ }
+ }
+ } catch (NotSslRecordException e) {
+ // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
+ throw e;
+ } catch (Exception e) {
+ // unexpected encoding, ignore sni and use default
+ if (logger.isDebugEnabled()) {
+ logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
+ }
+ select(ctx, null);
+ }
+ }
+ }
+
+ private void releaseHandshakeBuffer() {
+ releaseIfNotNull(handshakeBuffer);
+ handshakeBuffer = null;
+ }
+
+ private static void releaseIfNotNull(ByteBuf buffer) {
+ if (buffer != null) {
+ buffer.release();
+ }
+ }
+
+ private void select(final ChannelHandlerContext ctx, final ByteBuf clientHello) {
+ final Future future;
+ try {
+ future = lookup(ctx, clientHello);
+ if (future.isDone()) {
+ releaseIfNotNull(clientHello);
+ onLookupComplete(ctx, future);
+ } else {
+ suppressRead = true;
+ future.addListener((FutureListener) f -> {
+ releaseIfNotNull(clientHello);
+ try {
+ suppressRead = false;
+ try {
+ onLookupComplete(ctx, f);
+ } catch (DecoderException err) {
+ ctx.fireExceptionCaught(err);
+ } catch (Exception cause) {
+ ctx.fireExceptionCaught(new DecoderException(cause));
+ } catch (Throwable cause) {
+ ctx.fireExceptionCaught(cause);
+ }
+ } finally {
+ if (readPending) {
+ readPending = false;
+ ctx.read();
+ }
+ }
+ });
+ }
+ } catch (Throwable cause) {
+ releaseIfNotNull(clientHello);
+ PlatformDependent.throwException(cause);
+ }
+ }
+
+ @Override
+ protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
+ releaseHandshakeBuffer();
+
+ super.handlerRemoved0(ctx);
+ }
+
+ /**
+ * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
+ * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
+ *
+ * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
+ *
+ *
+ * struct {
+ * ProtocolVersion client_version;
+ * Random random;
+ * SessionID session_id;
+ * CipherSuite cipher_suites<2..2^16-2>;
+ * CompressionMethod compression_methods<1..2^8-1>;
+ * select (extensions_present) {
+ * case false:
+ * struct {};
+ * case true:
+ * Extension extensions<0..2^16-1>;
+ * };
+ * } ClientHello;
+ *
+ *
+ * @see #onLookupComplete(ChannelHandlerContext, Future)
+ */
+ protected abstract Future lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
+
+ /**
+ * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}.
+ *
+ * @see #lookup(ChannelHandlerContext, ByteBuf)
+ */
+ protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future future) throws Exception;
+
+ @Override
+ public void read(ChannelHandlerContext ctx) throws Exception {
+ if (suppressRead) {
+ readPending = true;
+ } else {
+ ctx.read();
+ }
+ }
+}
diff --git a/pom.xml b/pom.xml
index 33c65324c0..1869aac793 100644
--- a/pom.xml
+++ b/pom.xml
@@ -710,7 +710,7 @@
com.github.siom79.japicmp
japicmp-maven-plugin
- 0.13.1
+ 0.14.3
true