From f7b3caeddc5bb1da75aaafa4a66dec88ed585d69 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Fri, 3 Feb 2017 17:54:13 -0800 Subject: [PATCH] OpenSslEngine option to wrap/unwrap multiple packets per call Motivation: The JDK SSLEngine documentation says that a call to wrap/unwrap "will attempt to consume one complete SSL/TLS network packet" [1]. This limitation can result in thrashing in the pipeline to decode and encode data that may be spread amongst multiple SSL/TLS network packets. ReferenceCountedOpenSslEngine also does not correct account for the overhead introduced by each individual SSL_write call if there are multiple ByteBuffers passed to the wrap() method. Modifications: - OpenSslEngine and SslHandler supports a mode to not comply with the limitation to only deal with a single SSL/TLS network packet per call - ReferenceCountedOpenSslEngine correctly accounts for the overhead of each call to SSL_write - SslHandler shouldn't cache maxPacketBufferSize as aggressively because this value may change before/after the handshake. Result: OpenSslEngine and SslHanadler can handle multiple SSL/TLS network packet per call. [1] https://docs.oracle.com/javase/7/docs/api/javax/net/ssl/SSLEngine.html --- .../io/netty/handler/ssl/OpenSslContext.java | 4 +- .../io/netty/handler/ssl/OpenSslEngine.java | 5 +- .../ssl/ReferenceCountedOpenSslContext.java | 16 +- .../ssl/ReferenceCountedOpenSslEngine.java | 293 +++++++++----- .../java/io/netty/handler/ssl/SslContext.java | 35 +- .../java/io/netty/handler/ssl/SslHandler.java | 216 +++++----- .../netty/handler/ssl/OpenSslEngineTest.java | 368 +++++++++++++++++- .../io/netty/handler/ssl/SniHandlerTest.java | 52 ++- .../io/netty/handler/ssl/SslHandlerTest.java | 10 +- 9 files changed, 748 insertions(+), 251 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java index c4ca6b5197..f18c0643fc 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java @@ -45,8 +45,8 @@ public abstract class OpenSslContext extends ReferenceCountedOpenSslContext { } @Override - final SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort) { - return new OpenSslEngine(this, alloc, peerHost, peerPort); + final SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort, boolean jdkCompatibilityMode) { + return new OpenSslEngine(this, alloc, peerHost, peerPort, jdkCompatibilityMode); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java index cbc7ee411c..a700dabf39 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java @@ -27,8 +27,9 @@ import javax.net.ssl.SSLEngine; * and manually release the native memory see {@link ReferenceCountedOpenSslEngine}. */ public final class OpenSslEngine extends ReferenceCountedOpenSslEngine { - OpenSslEngine(OpenSslContext context, ByteBufAllocator alloc, String peerHost, int peerPort) { - super(context, alloc, peerHost, peerPort, false); + OpenSslEngine(OpenSslContext context, ByteBufAllocator alloc, String peerHost, int peerPort, + boolean jdkCompatibilityMode) { + super(context, alloc, peerHost, peerPort, jdkCompatibilityMode, false); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java index ee049ab19d..35839c9d06 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java @@ -409,11 +409,21 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen @Override public final SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) { - return newEngine0(alloc, peerHost, peerPort); + return newEngine0(alloc, peerHost, peerPort, true); } - SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort) { - return new ReferenceCountedOpenSslEngine(this, alloc, peerHost, peerPort, true); + @Override + final SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { + return new SslHandler(newEngine0(alloc, null, -1, false), startTls, false); + } + + @Override + final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls) { + return new SslHandler(newEngine0(alloc, peerHost, peerPort, false), startTls, false); + } + + SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort, boolean jdkCompatibilityMode) { + return new ReferenceCountedOpenSslEngine(this, alloc, peerHost, peerPort, jdkCompatibilityMode, true); } abstract OpenSslKeyMaterialManager keyMaterialManager(); diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 27460c7149..5f877b8201 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -43,8 +43,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; - import java.util.concurrent.locks.Lock; + import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; @@ -58,10 +58,10 @@ import javax.net.ssl.SSLSessionContext; import javax.security.cert.X509Certificate; import static io.netty.handler.ssl.OpenSsl.memoryAddress; -import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH; import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES; import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Integer.MAX_VALUE; import static java.lang.Math.min; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP; @@ -98,25 +98,10 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc */ private static final int DEFAULT_HOSTNAME_VALIDATION_FLAGS = 0; - static final int MAX_PLAINTEXT_LENGTH = 16 * 1024; // 2^14 - /** - * This is the maximum overhead when encrypting plaintext as defined by - * rfc5264, - * rfc5289 and openssl implementation itself. - * - * Please note that we use a padding of 16 here as openssl uses PKC#5 which uses 16 bytes while the spec itself - * allow up to 255 bytes. 16 bytes is the max for PKC#5 (which handles it the same way as PKC#7) as we use a block - * size of 16. See rfc5652#section-6.3. - * - * TLS Header (5) + 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + - * 2 (ProtocolVersion) + 2 (Length) - * - * TODO: We may need to review this calculation once TLS 1.3 becomes available. + * Depends upon tcnative ... only use if tcnative is available! */ - static final int MAX_TLS_RECORD_OVERHEAD_LENGTH = SSL_RECORD_HEADER_LENGTH + 16 + 48 + 1 + 15 + 1 + 2 + 2; - - static final int MAX_ENCRYPTED_PACKET_LENGTH = MAX_PLAINTEXT_LENGTH + MAX_TLS_RECORD_OVERHEAD_LENGTH; + static final int MAX_PLAINTEXT_LENGTH = SSL.SSL_MAX_PLAINTEXT_LENGTH; private static final AtomicIntegerFieldUpdater DESTROYED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed"); @@ -146,7 +131,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc * Started via {@link #beginHandshake()}. */ STARTED_EXPLICITLY, - /** * Handshake is finished. */ @@ -198,6 +182,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc private boolean isInboundDone; private boolean outboundClosed; + private final boolean jdkCompatibilityMode; private final boolean clientMode; private final ByteBufAllocator alloc; private final OpenSslEngineMap engineMap; @@ -209,6 +194,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc private final ByteBuffer[] singleDstBuffer = new ByteBuffer[1]; private final OpenSslKeyMaterialManager keyMaterialManager; private final boolean enableOcsp; + private int maxWrapOverhead; + private int maxWrapBufferSize; // This is package-private as we set it from OpenSslContext if an exception is thrown during // the verification step. @@ -220,10 +207,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc * @param alloc The allocator to use. * @param peerHost The peer host name. * @param peerPort The peer port. + * @param jdkCompatibilityMode {@code true} to behave like described in + * https://docs.oracle.com/javase/7/docs/api/javax/net/ssl/SSLEngine.html. + * {@code false} allows for partial and/or multiple packets to be process in a single + * wrap or unwrap call. * @param leakDetection {@code true} to enable leak detection of this object. */ ReferenceCountedOpenSslEngine(ReferenceCountedOpenSslContext context, ByteBufAllocator alloc, String peerHost, - int peerPort, boolean leakDetection) { + int peerPort, boolean jdkCompatibilityMode, boolean leakDetection) { super(peerHost, peerPort); OpenSsl.ensureAvailability(); leak = leakDetection ? leakDetector.track(this) : null; @@ -236,7 +227,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc localCerts = context.keyCertChain; keyMaterialManager = context.keyMaterialManager(); enableOcsp = context.enableOcsp; - + this.jdkCompatibilityMode = jdkCompatibilityMode; Lock readerLock = context.ctxLock.readLock(); readerLock.lock(); try { @@ -244,6 +235,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } finally { readerLock.unlock(); } + ssl = SSL.newSSL(context.ctx, !context.isClient()); try { networkBIO = SSL.bioNewByteBuffer(ssl, context.getBioNonApplicationBufferSize()); @@ -264,6 +256,13 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc if (enableOcsp) { SSL.enableOcsp(ssl); } + + if (!jdkCompatibilityMode) { + SSL.setMode(ssl, SSL.getMode(ssl) | SSL.SSL_MODE_ENABLE_PARTIAL_WRITE); + } + + // setMode may impact the overhead. + calculateMaxWrapOverhead(); } catch (Throwable cause) { SSL.freeSSL(ssl); PlatformDependent.throwException(cause); @@ -461,7 +460,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } else { final int limit = dst.limit(); - final int len = min(MAX_ENCRYPTED_PACKET_LENGTH, limit - pos); + final int len = min(maxEncryptedPacketLength0(), limit - pos); final ByteBuf buf = alloc.directBuffer(len); try { sslRead = SSL.readFromSSL(ssl, memoryAddress(buf), len); @@ -478,6 +477,65 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc return sslRead; } + /** + * Visible only for testing! + */ + final synchronized int maxWrapOverhead() { + return maxWrapOverhead; + } + + /** + * Visible only for testing! + */ + final synchronized int maxEncryptedPacketLength() { + return maxEncryptedPacketLength0(); + } + + /** + * This method is intentionally not synchronized, only use if you know you are in the EventLoop + * thread and visibility on {@link #maxWrapOverhead} is achieved via other synchronized blocks. + */ + final int maxEncryptedPacketLength0() { + return maxWrapOverhead + MAX_PLAINTEXT_LENGTH; + } + + /** + * This method is intentionally not synchronized, only use if you know you are in the EventLoop + * thread and visibility on {@link #maxWrapBufferSize} and {@link #maxWrapOverhead} is achieved + * via other synchronized blocks. + */ + final int calculateMaxLengthForWrap(int plaintextLength, int numComponents) { + return (int) min(maxWrapBufferSize, plaintextLength + (long) maxWrapOverhead * numComponents); + } + + final synchronized int sslPending() { + return sslPending0(); + } + + /** + * It is assumed this method is called in a synchronized block (or the constructor)! + */ + private void calculateMaxWrapOverhead() { + maxWrapOverhead = SSL.getMaxWrapOverhead(ssl); + + // maxWrapBufferSize must be set after maxWrapOverhead because there is a dependency on this value. + // If jdkCompatibility mode is off we allow enough space to encrypt 16 buffers at a time. This could be + // configurable in the future if necessary. + maxWrapBufferSize = jdkCompatibilityMode ? maxEncryptedPacketLength0() : maxEncryptedPacketLength0() << 4; + } + + private int sslPending0() { + // OpenSSL has a limitation where if you call SSL_pending before the handshake is complete OpenSSL will throw a + // "called a function you should not call" error. Using the TLS_method instead of SSLv23_method may solve this + // issue but this API is only available in 1.1.0+ [1]. + // [1] https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_new.html + return handshakeState != HandshakeState.FINISHED ? 0 : SSL.sslPending(ssl); + } + + private boolean isBytesAvailableEnoughForWrap(int bytesAvailable, int plaintextLength, int numComponents) { + return bytesAvailable - (long) maxWrapOverhead * numComponents >= plaintextLength; + } + @Override public final SSLEngineResult wrap( final ByteBuffer[] srcs, int offset, final int length, final ByteBuffer dst) throws SSLException { @@ -602,32 +660,34 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - int srcsLen = 0; final int endOffset = offset + length; - for (int i = offset; i < endOffset; ++i) { - final ByteBuffer src = srcs[i]; - if (src == null) { - throw new IllegalArgumentException("srcs[" + i + "] is null"); - } - if (srcsLen == MAX_PLAINTEXT_LENGTH) { - continue; + if (jdkCompatibilityMode) { + int srcsLen = 0; + for (int i = offset; i < endOffset; ++i) { + final ByteBuffer src = srcs[i]; + if (src == null) { + throw new IllegalArgumentException("srcs[" + i + "] is null"); + } + if (srcsLen == MAX_PLAINTEXT_LENGTH) { + continue; + } + + srcsLen += src.remaining(); + if (srcsLen > MAX_PLAINTEXT_LENGTH || srcsLen < 0) { + // If srcLen > MAX_PLAINTEXT_LENGTH or secLen < 0 just set it to MAX_PLAINTEXT_LENGTH. + // This also help us to guard against overflow. + // We not break out here as we still need to check for null entries in srcs[]. + srcsLen = MAX_PLAINTEXT_LENGTH; + } } - srcsLen += src.remaining(); - if (srcsLen > MAX_PLAINTEXT_LENGTH || srcsLen < 0) { - // If srcLen > MAX_PLAINTEXT_LENGTH or secLen < 0 just set it to MAX_PLAINTEXT_LENGTH. - // This also help us to guard against overflow. - // We not break out here as we still need to check for null entries in srcs[]. - srcsLen = MAX_PLAINTEXT_LENGTH; + // jdkCompatibilityMode will only produce a single TLS packet, and we don't aggregate src buffers, + // so we always fix the number of buffers to 1 when checking if the dst buffer is large enough. + if (!isBytesAvailableEnoughForWrap(dst.remaining(), srcsLen, 1)) { + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); } } - // we will only produce a single TLS packet, and we don't aggregate src buffers, - // so we always fix the number of buffers to 1 when checking if the dst buffer is large enough. - if (dst.remaining() < calculateOutNetBufSize(srcsLen, 1)) { - return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); - } - // There was no pending data in the network BIO -- encrypt any application data int bytesConsumed = 0; // Flush any data that may have been written implicitly by OpenSSL in case a shutdown/alert occurs. @@ -639,8 +699,23 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc continue; } - // Write plaintext application data to the SSL engine - int bytesWritten = writePlaintextData(src, min(remaining, MAX_PLAINTEXT_LENGTH - bytesConsumed)); + final int bytesWritten; + if (jdkCompatibilityMode) { + // Write plaintext application data to the SSL engine. We don't have to worry about checking + // if there is enough space if jdkCompatibilityMode because we only wrap at most + // MAX_PLAINTEXT_LENGTH and we loop over the input before hand and check if there is space. + bytesWritten = writePlaintextData(src, min(remaining, MAX_PLAINTEXT_LENGTH - bytesConsumed)); + } else { + // OpenSSL's SSL_write keeps state between calls. We should make sure the amount we attempt to + // write is guaranteed to succeed so we don't have to worry about keeping state consistent + // between calls. + final int availableCapacityForWrap = dst.remaining() - bytesProduced - maxWrapOverhead; + if (availableCapacityForWrap <= 0) { + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), bytesConsumed, + bytesProduced); + } + bytesWritten = writePlaintextData(src, min(remaining, availableCapacityForWrap)); + } if (bytesWritten > 0) { bytesConsumed += bytesWritten; @@ -650,7 +725,9 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc bytesProduced += bioLengthBefore - pendingNow; bioLengthBefore = pendingNow; - return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced); + if (jdkCompatibilityMode || bytesProduced == dst.remaining()) { + return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced); + } } else { int sslError = SSL.getError(ssl, bytesWritten); if (sslError == SSL.SSL_ERROR_ZERO_RETURN) { @@ -686,10 +763,10 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc // "When using a buffering BIO, like a BIO pair, data must be written into or retrieved // out of the BIO before being able to continue." // - // So we attempt to drain the BIO buffer below, but if there is no data this condition - // is undefined and we assume their is a fatal error with the openssl engine and close. + // In practice this means the destination buffer doesn't have enough space for OpenSSL + // to write encrypted data to. This is an OVERFLOW condition. // [1] https://www.openssl.org/docs/manmaster/ssl/SSL_write.html - return newResult(NEED_WRAP, bytesConsumed, bytesProduced); + return newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced); } else { // Everything else is considered as error throw shutdownWithError("SSL_write"); @@ -835,26 +912,39 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - if (len < SSL_RECORD_HEADER_LENGTH) { + int sslPending = sslPending0(); + int packetLength; + // The JDK implies that only a single SSL packet should be processed per unwrap call [1]. If we are in + // JDK compatibility mode then we should honor this, but if not we just wrap as much as possible. If there + // are multiple records or partial records this may reduce thrashing events through the pipeline. + // [1] https://docs.oracle.com/javase/7/docs/api/javax/net/ssl/SSLEngine.html + if (jdkCompatibilityMode) { + if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) { + return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0); + } + + packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset); + if (packetLength == SslUtils.NOT_ENCRYPTED) { + throw new NotSslRecordException("not an SSL/TLS record"); + } + + if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) { + // No enough space in the destination buffer so signal the caller + // that the buffer needs to be increased. + return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0); + } + + if (len < packetLength) { + // We either have no enough data to read the packet length at all or not enough for reading + // the whole packet. + return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0); + } + } else if (len == 0 && sslPending <= 0) { return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0); - } - - int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset); - - if (packetLength == SslUtils.NOT_ENCRYPTED) { - throw new NotSslRecordException("not an SSL/TLS record"); - } - - if (packetLength - SSL_RECORD_HEADER_LENGTH > capacity) { - // No enough space in the destination buffer so signal the caller - // that the buffer needs to be increased. + } else if (capacity == 0) { return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0); - } - - if (len < packetLength) { - // We either have no enough data to read the packet length at all or not enough for reading - // the whole packet. - return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0); + } else { + packetLength = (int) min(MAX_VALUE, len); } // This must always be the case when we reached here as if not we returned BUFFER_UNDERFLOW. @@ -867,24 +957,38 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc int bytesProduced = 0; int bytesConsumed = 0; try { - for (; srcsOffset < srcsEndOffset; ++srcsOffset) { + srcLoop: + for (;;) { ByteBuffer src = srcs[srcsOffset]; int remaining = src.remaining(); + final ByteBuf bioWriteCopyBuf; + int pendingEncryptedBytes; if (remaining == 0) { - // We must skip empty buffers as BIO_write will return 0 if asked to write something - // with length 0. - continue; + if (sslPending <= 0) { + // We must skip empty buffers as BIO_write will return 0 if asked to write something + // with length 0. + if (++srcsOffset >= srcsEndOffset) { + break; + } + continue; + } else { + bioWriteCopyBuf = null; + pendingEncryptedBytes = SSL.bioLengthByteBuffer(networkBIO); + } + } else { + // Write more encrypted data into the BIO. Ensure we only read one packet at a time as + // stated in the SSLEngine javadocs. + pendingEncryptedBytes = min(packetLength, remaining); + bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes); } - // Write more encrypted data into the BIO. Ensure we only read one packet at a time as - // stated in the SSLEngine javadocs. - int pendingEncryptedBytes = min(packetLength, remaining); - ByteBuf bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes); try { - readLoop: - for (; dstsOffset < dstsEndOffset; ++dstsOffset) { + for (;;) { ByteBuffer dst = dsts[dstsOffset]; if (!dst.hasRemaining()) { // No space left in the destination buffer, skip it. + if (++dstsOffset >= dstsEndOffset) { + break srcLoop; + } continue; } @@ -903,22 +1007,27 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc bytesProduced += bytesRead; if (!dst.hasRemaining()) { + sslPending = sslPending0(); // Move to the next dst buffer as this one is full. - continue; - } - if (packetLength == 0) { + if (++dstsOffset >= dstsEndOffset) { + return sslPending > 0 ? + newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced) : + newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status, + bytesConsumed, bytesProduced); + } + } else if (packetLength == 0) { // We read everything return now. - return newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status, - bytesConsumed, bytesProduced); + break srcLoop; + } else if (jdkCompatibilityMode) { + // try to write again to the BIO. stop reading from it by break out of the readLoop. + break; } - // try to write again to the BIO. stop reading from it by break out of the readLoop. - break; } else { int sslError = SSL.getError(ssl, bytesRead); if (sslError == SSL.SSL_ERROR_WANT_READ || sslError == SSL.SSL_ERROR_WANT_WRITE) { // break to the outer loop as we want to read more data which means we need to // write more to the BIO. - break readLoop; + break; } else if (sslError == SSL.SSL_ERROR_ZERO_RETURN) { // This means the connection was shutdown correctly, close inbound and outbound if (!receivedShutdown) { @@ -933,8 +1042,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - // Either we have no more dst buffers to put the data, or no more data to generate; we are done. - if (dstsOffset >= dstsEndOffset || packetLength == 0) { + if (++srcsOffset >= srcsEndOffset) { break; } } finally { @@ -1054,7 +1162,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc public final Runnable getDelegatedTask() { // Currently, we do not delegate SSL computation tasks // TODO: in the future, possibly create tasks to do encrypt / decrypt async - return null; } @@ -1329,6 +1436,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc // for renegotiation. handshakeState = HandshakeState.STARTED_EXPLICITLY; // Next time this method is invoked by the user, + calculateMaxWrapOverhead(); // we should raise an exception. break; case STARTED_EXPLICITLY: @@ -1375,6 +1483,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc case NOT_STARTED: handshakeState = HandshakeState.STARTED_EXPLICITLY; handshake(); + calculateMaxWrapOverhead(); break; default: throw new Error(); @@ -1667,11 +1776,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc return destroyed != 0; } - static int calculateOutNetBufSize(int pendingBytes, int numComponents) { - return (int) min(MAX_ENCRYPTED_PACKET_LENGTH, - pendingBytes + (long) MAX_TLS_RECORD_OVERHEAD_LENGTH * numComponents); - } - final boolean checkSniHostnameMatch(String hostname) { return Java8SslUtils.checkSniHostnameMatch(matchers, hostname); } @@ -1819,6 +1923,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc initPeerCerts(); selectApplicationProtocol(); + calculateMaxWrapOverhead(); handshakeState = HandshakeState.FINISHED; } else { @@ -2026,7 +2131,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc @Override public int getPacketBufferSize() { - return MAX_ENCRYPTED_PACKET_LENGTH; + return maxEncryptedPacketLength(); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/SslContext.java b/handler/src/main/java/io/netty/handler/ssl/SslContext.java index 4998d0d446..e3eb973a42 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslContext.java @@ -882,11 +882,11 @@ public abstract class SslContext { * Creates a new {@link SslHandler}. *

If {@link SslProvider#OPENSSL_REFCNT} is used then the returned {@link SslHandler} will release the engine * that is wrapped. If the returned {@link SslHandler} is not inserted into a pipeline then you may leak native - * memory!

+ * memory! *

Beware: the underlying generated {@link SSLEngine} won't have * hostname verification enabled by default. * If you create {@link SslHandler} for the client side and want proper security, we advice that you configure - * the {@link SSLEngine} (see {@link javax.net.ssl.SSLParameters#setEndpointIdentificationAlgorithm(String)}):

+ * the {@link SSLEngine} (see {@link javax.net.ssl.SSLParameters#setEndpointIdentificationAlgorithm(String)}): *
      * SSLEngine sslEngine = sslHandler.engine();
      * SSLParameters sslParameters = sslEngine.getSSLParameters();
@@ -894,12 +894,22 @@ public abstract class SslContext {
      * sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
      * sslEngine.setSSLParameters(sslParameters);
      * 
- * + *

+ * The underlying {@link SSLEngine} may not follow the restrictions imposed by the + * SSLEngine javadocs which + * limits wrap/unwrap to operate on a single SSL/TLS packet. * @param alloc If supported by the SSLEngine then the SSLEngine will use this to allocate ByteBuf objects. - * * @return a new {@link SslHandler} */ public final SslHandler newHandler(ByteBufAllocator alloc) { + return newHandler(alloc, startTls); + } + + /** + * Create a new SslHandler. + * @see #newHandler(io.netty.buffer.ByteBufAllocator) + */ + SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { return new SslHandler(newEngine(alloc), startTls); } @@ -907,11 +917,11 @@ public abstract class SslContext { * Creates a new {@link SslHandler} with advisory peer information. *

If {@link SslProvider#OPENSSL_REFCNT} is used then the returned {@link SslHandler} will release the engine * that is wrapped. If the returned {@link SslHandler} is not inserted into a pipeline then you may leak native - * memory!

+ * memory! *

Beware: the underlying generated {@link SSLEngine} won't have * hostname verification enabled by default. * If you create {@link SslHandler} for the client side and want proper security, we advice that you configure - * the {@link SSLEngine} (see {@link javax.net.ssl.SSLParameters#setEndpointIdentificationAlgorithm(String)}):

+ * the {@link SSLEngine} (see {@link javax.net.ssl.SSLParameters#setEndpointIdentificationAlgorithm(String)}): *
      * SSLEngine sslEngine = sslHandler.engine();
      * SSLParameters sslParameters = sslEngine.getSSLParameters();
@@ -919,7 +929,10 @@ public abstract class SslContext {
      * sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
      * sslEngine.setSSLParameters(sslParameters);
      * 
- * + *

+ * The underlying {@link SSLEngine} may not follow the restrictions imposed by the + * SSLEngine javadocs which + * limits wrap/unwrap to operate on a single SSL/TLS packet. * @param alloc If supported by the SSLEngine then the SSLEngine will use this to allocate ByteBuf objects. * @param peerHost the non-authoritative name of the host * @param peerPort the non-authoritative port @@ -927,6 +940,14 @@ public abstract class SslContext { * @return a new {@link SslHandler} */ public final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort) { + return newHandler(alloc, peerHost, peerPort, startTls); + } + + /** + * Create a new SslHandler. + * @see #newHandler(io.netty.buffer.ByteBufAllocator, String, int, boolean) + */ + SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls) { return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls); } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index c05496466c..6c29e5d4e4 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -46,12 +46,6 @@ import io.netty.util.internal.ThrowableUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLEngineResult; -import javax.net.ssl.SSLEngineResult.HandshakeStatus; -import javax.net.ssl.SSLEngineResult.Status; -import javax.net.ssl.SSLException; -import javax.net.ssl.SSLSession; import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -65,6 +59,12 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; import static io.netty.handler.ssl.SslUtils.getEncryptedPacketLength; @@ -210,9 +210,21 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH return result; } + @Override + int getPacketBufferSize(SslHandler handler) { + return ((ReferenceCountedOpenSslEngine) handler.engine).maxEncryptedPacketLength0(); + } + @Override int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) { - return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents); + return ((ReferenceCountedOpenSslEngine) handler.engine).calculateMaxLengthForWrap(pendingBytes, + numComponents); + } + + @Override + int calculatePendingData(SslHandler handler, int guess) { + int sslPending = ((ReferenceCountedOpenSslEngine) handler.engine).sslPending(); + return sslPending > 0 ? sslPending : guess; } }, CONSCRYPT(true, COMPOSITE_CUMULATOR) { @@ -246,6 +258,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) { return ((ConscryptAlpnSslEngine) handler.engine).calculateOutNetBufSize(pendingBytes, numComponents); } + + @Override + int calculatePendingData(SslHandler handler, int guess) { + return guess; + } }, JDK(false, MERGE_CUMULATOR) { @Override @@ -260,18 +277,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) { - return handler.maxPacketBufferSize; + return handler.engine.getSession().getPacketBufferSize(); + } + + @Override + int calculatePendingData(SslHandler handler, int guess) { + return guess; } }; static SslEngineType forEngine(SSLEngine engine) { - if (engine instanceof ReferenceCountedOpenSslEngine) { - return TCNATIVE; - } - if (engine instanceof ConscryptAlpnSslEngine) { - return CONSCRYPT; - } - return JDK; + return engine instanceof ReferenceCountedOpenSslEngine ? TCNATIVE : + engine instanceof ConscryptAlpnSslEngine ? CONSCRYPT : JDK; } SslEngineType(boolean wantsDirectBuffer, Cumulator cumulator) { @@ -279,11 +296,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH this.cumulator = cumulator; } + int getPacketBufferSize(SslHandler handler) { + return handler.engine.getSession().getPacketBufferSize(); + } + abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException; abstract int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents); + abstract int calculatePendingData(SslHandler handler, int guess); + // BEGIN Platform-dependent flags /** @@ -307,8 +330,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private volatile ChannelHandlerContext ctx; private final SSLEngine engine; private final SslEngineType engineType; - private final int maxPacketBufferSize; private final Executor delegatedTaskExecutor; + private final boolean jdkCompatibilityMode; /** * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer[])} @@ -380,6 +403,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH */ @Deprecated public SslHandler(SSLEngine engine, boolean startTls, Executor delegatedTaskExecutor) { + this(engine, startTls, true, delegatedTaskExecutor); + } + + SslHandler(SSLEngine engine, boolean startTls, boolean jdkCompatibilityMode) { + this(engine, startTls, jdkCompatibilityMode, ImmediateExecutor.INSTANCE); + } + + SslHandler(SSLEngine engine, boolean startTls, boolean jdkCompatibilityMode, Executor delegatedTaskExecutor) { if (engine == null) { throw new NullPointerException("engine"); } @@ -390,7 +421,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH engineType = SslEngineType.forEngine(engine); this.delegatedTaskExecutor = delegatedTaskExecutor; this.startTls = startTls; - maxPacketBufferSize = engine.getSession().getPacketBufferSize(); + this.jdkCompatibilityMode = jdkCompatibilityMode; setCumulator(engineType.cumulator); } @@ -878,7 +909,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH switch (result.getStatus()) { case BUFFER_OVERFLOW: - out.ensureWritable(maxPacketBufferSize); + out.ensureWritable(engine.getSession().getPacketBufferSize()); break; default: return result; @@ -1014,100 +1045,82 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH return getEncryptedPacketLength(buffer, buffer.readerIndex()) != SslUtils.NOT_ENCRYPTED; } - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { - final int startOffset = in.readerIndex(); - final int endOffset = in.writerIndex(); - int offset = startOffset; - int totalLength = 0; - + private void decodeJdkCompatible(ChannelHandlerContext ctx, ByteBuf in) throws NotSslRecordException { + int packetLength = this.packetLength; // If we calculated the length of the current SSL record before, use that information. if (packetLength > 0) { - if (endOffset - startOffset < packetLength) { + if (in.readableBytes() < packetLength) { return; - } else { - offset += packetLength; - totalLength = packetLength; - packetLength = 0; } - } - - boolean nonSslRecord = false; - - while (totalLength < ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { - final int readableBytes = endOffset - offset; + } else { + // Get the packet length and wait until we get a packets worth of data to unwrap. + final int readableBytes = in.readableBytes(); if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { - break; + return; } - - final int packetLength = getEncryptedPacketLength(in, offset); + packetLength = getEncryptedPacketLength(in, in.readerIndex()); if (packetLength == SslUtils.NOT_ENCRYPTED) { - nonSslRecord = true; - break; + // Not an SSL/TLS packet + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(in.readableBytes()); + + // First fail the handshake promise as we may need to have access to the SSLEngine which may + // be released because the user will remove the SslHandler in an exceptionCaught(...) implementation. + setHandshakeFailure(ctx, e); + + throw e; } - assert packetLength > 0; - if (packetLength > readableBytes) { // wait until the whole packet can be read this.packetLength = packetLength; - break; - } - - int newTotalLength = totalLength + packetLength; - if (newTotalLength > ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH) { - // Don't read too much. - break; - } - - // We have a whole packet. - // Increment the offset to handle the next packet. - offset += packetLength; - totalLength = newTotalLength; - } - - if (totalLength > 0) { - // The buffer contains one or more full SSL records. - // Slice out the whole packet so unwrap will only be called with complete packets. - // Also directly reset the packetLength. This is needed as unwrap(..) may trigger - // decode(...) again via: - // 1) unwrap(..) is called - // 2) wrap(...) is called from within unwrap(...) - // 3) wrap(...) calls unwrapLater(...) - // 4) unwrapLater(...) calls decode(...) - // - // See https://github.com/netty/netty/issues/1534 - - in.skipBytes(totalLength); - - try { - firedChannelRead = unwrap(ctx, in, startOffset, totalLength) || firedChannelRead; - } catch (Throwable cause) { - try { - // We need to flush one time as there may be an alert that we should send to the remote peer because - // of the SSLException reported here. - wrapAndFlush(ctx); - } catch (SSLException ex) { - logger.debug("SSLException during trying to call SSLEngine.wrap(...)" + - " because of an previous SSLException, ignoring...", ex); - } finally { - setHandshakeFailure(ctx, cause); - } - PlatformDependent.throwException(cause); + return; } } - if (nonSslRecord) { - // Not an SSL/TLS packet - NotSslRecordException e = new NotSslRecordException( - "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); - in.skipBytes(in.readableBytes()); + // Reset the state of this class so we can get the length of the next packet. We assume the entire packet will + // be consumed by the SSLEngine. + this.packetLength = 0; + try { + int bytesConsumed = unwrap(ctx, in, in.readerIndex(), packetLength); + assert bytesConsumed == packetLength || engine.isInboundDone() : + "we feed the SSLEngine a packets worth of data: " + packetLength + " but it only consumed: " + + bytesConsumed; + in.skipBytes(bytesConsumed); + } catch (Throwable cause) { + handleUnwrapThrowable(ctx, cause); + } + } - // First fail the handshake promise as we may need to have access to the SSLEngine which may - // be released because the user will remove the SslHandler in an exceptionCaught(...) implementation. - setHandshakeFailure(ctx, e); + private void decodeNonJdkCompatible(ChannelHandlerContext ctx, ByteBuf in) { + try { + in.skipBytes(unwrap(ctx, in, in.readerIndex(), in.readableBytes())); + } catch (Throwable cause) { + handleUnwrapThrowable(ctx, cause); + } + } - throw e; + private void handleUnwrapThrowable(ChannelHandlerContext ctx, Throwable cause) { + try { + // We need to flush one time as there may be an alert that we should send to the remote peer because + // of the SSLException reported here. + wrapAndFlush(ctx); + } catch (SSLException ex) { + logger.debug("SSLException during trying to call SSLEngine.wrap(...)" + + " because of an previous SSLException, ignoring...", ex); + } finally { + setHandshakeFailure(ctx, cause); + } + PlatformDependent.throwException(cause); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { + if (jdkCompatibilityMode) { + decodeJdkCompatible(ctx, in); + } else { + decodeNonJdkCompatible(ctx, in); } } @@ -1148,10 +1161,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH /** * Unwraps inbound SSL records. */ - private boolean unwrap( + private int unwrap( ChannelHandlerContext ctx, ByteBuf packet, int offset, int length) throws SSLException { - - boolean decoded = false; + final int originalLength = length; boolean wrapLater = false; boolean notifyClosure = false; ByteBuf decodeOut = allocate(ctx, length); @@ -1174,7 +1186,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH int readableBytes = decodeOut.readableBytes(); int bufferSize = engine.getSession().getApplicationBufferSize() - readableBytes; if (readableBytes > 0) { - decoded = true; + firedChannelRead = true; ctx.fireChannelRead(decodeOut); // This buffer was handled, null it out. @@ -1194,7 +1206,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // Allocate a new buffer which can hold all the rest data and loop again. // TODO: We may want to reconsider how we calculate the length here as we may // have more then one ssl message to decode. - decodeOut = allocate(ctx, bufferSize); + decodeOut = allocate(ctx, engineType.calculatePendingData(this, bufferSize)); continue; case CLOSED: // notify about the CLOSED state of the SSLEngine. See #137 @@ -1268,7 +1280,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } finally { if (decodeOut != null) { if (decodeOut.isReadable()) { - decoded = true; + firedChannelRead = true; ctx.fireChannelRead(decodeOut); } else { @@ -1276,7 +1288,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } } - return decoded; + return originalLength - length; } private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) { diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 5939b66e35..156d69b589 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -43,8 +43,6 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; -import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH; -import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH; import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH; import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED; import static java.lang.Integer.MAX_VALUE; @@ -218,9 +216,9 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src = allocateBuffer(srcLen); ByteBuffer dstTooSmall = allocateBuffer( - src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH - 1); + src.capacity() + ((ReferenceCountedOpenSslEngine) clientEngine).maxWrapOverhead() - 1); ByteBuffer dst = allocateBuffer( - src.capacity() + MAX_TLS_RECORD_OVERHEAD_LENGTH); + src.capacity() + ((ReferenceCountedOpenSslEngine) clientEngine).maxWrapOverhead()); // Check that we fail to wrap if the dst buffers capacity is not at least // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH @@ -267,7 +265,7 @@ public class OpenSslEngineTest extends SSLEngineTest { ByteBuffer src2 = src.duplicate(); ByteBuffer dst = allocateBuffer(src.capacity() - + MAX_TLS_RECORD_OVERHEAD_LENGTH); + + ((ReferenceCountedOpenSslEngine) clientEngine).maxWrapOverhead()); SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -311,8 +309,8 @@ public class OpenSslEngineTest extends SSLEngineTest { } ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]); - - ByteBuffer dst = allocateBuffer(MAX_ENCRYPTED_PACKET_LENGTH - 1); + ByteBuffer dst = allocateBuffer( + ((ReferenceCountedOpenSslEngine) clientEngine).maxEncryptedPacketLength() - 1); SSLEngineResult result = clientEngine.wrap(srcs, dst); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); @@ -330,21 +328,34 @@ public class OpenSslEngineTest extends SSLEngineTest { } @Test - public void testCalculateOutNetBufSizeOverflow() { - assertEquals(MAX_ENCRYPTED_PACKET_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(MAX_VALUE, 1)); + public void testCalculateOutNetBufSizeOverflow() throws SSLException { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine clientEngine = null; + try { + clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + int value = ((ReferenceCountedOpenSslEngine) clientEngine).calculateMaxLengthForWrap(MAX_VALUE, 1); + assertTrue("unexpected value: " + value, value > 0); + } finally { + cleanupClientSslEngine(clientEngine); + } } @Test - public void testCalculateOutNetBufSize0() { - assertEquals(MAX_TLS_RECORD_OVERHEAD_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(0, 1)); - } - - @Test - public void testCalculateOutNetBufSizeMaxEncryptedPacketLength() { - assertEquals(MAX_ENCRYPTED_PACKET_LENGTH, - ReferenceCountedOpenSslEngine.calculateOutNetBufSize(MAX_ENCRYPTED_PACKET_LENGTH + 1, 2)); + public void testCalculateOutNetBufSize0() throws SSLException { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine clientEngine = null; + try { + clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertTrue(((ReferenceCountedOpenSslEngine) clientEngine).calculateMaxLengthForWrap(0, 1) > 0); + } finally { + cleanupClientSslEngine(clientEngine); + } } @Override @@ -533,6 +544,323 @@ public class OpenSslEngineTest extends SSLEngineTest { testWrapWithDifferentSizes(OpenSsl.PROTOCOL_SSL_V3, "ECDHE-RSA-RC4-SHA"); } + @Test + public void testMultipleRecordsInOneBufferWithNonZeroPositionJDKCompatabilityModeOff() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + try { + // Choose buffer size small enough that we can put multiple buffers into one buffer and pass it into the + // unwrap call without exceed MAX_ENCRYPTED_PACKET_LENGTH. + final int plainClientOutLen = 1024; + ByteBuffer plainClientOut = allocateBuffer(plainClientOutLen); + ByteBuffer plainServerOut = allocateBuffer(server.getSession().getApplicationBufferSize()); + + ByteBuffer encClientToServer = allocateBuffer(client.getSession().getPacketBufferSize()); + + int positionOffset = 1; + // We need to be able to hold 2 records + positionOffset + ByteBuffer combinedEncClientToServer = allocateBuffer( + encClientToServer.capacity() * 2 + positionOffset); + combinedEncClientToServer.position(positionOffset); + + handshake(client, server); + + plainClientOut.limit(plainClientOut.capacity()); + SSLEngineResult result = client.wrap(plainClientOut, encClientToServer); + assertEquals(plainClientOut.capacity(), result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + encClientToServer.flip(); + + // Copy the first record into the combined buffer + combinedEncClientToServer.put(encClientToServer); + + plainClientOut.clear(); + encClientToServer.clear(); + + result = client.wrap(plainClientOut, encClientToServer); + assertEquals(plainClientOut.capacity(), result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + encClientToServer.flip(); + + // Copy the first record into the combined buffer + combinedEncClientToServer.put(encClientToServer); + + encClientToServer.clear(); + + combinedEncClientToServer.flip(); + combinedEncClientToServer.position(positionOffset); + + // Make sure the limit takes positionOffset into account to the content we are looking at is correct. + combinedEncClientToServer.limit( + combinedEncClientToServer.limit() - positionOffset); + final int combinedEncClientToServerLen = combinedEncClientToServer.remaining(); + + result = server.unwrap(combinedEncClientToServer, plainServerOut); + assertEquals(0, combinedEncClientToServer.remaining()); + assertEquals(combinedEncClientToServerLen, result.bytesConsumed()); + assertEquals(plainClientOutLen, result.bytesProduced()); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @Test + public void testInputTooBigAndFillsUpBuffersJDKCompatabilityModeOff() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + try { + ByteBuffer plainClient = allocateBuffer(MAX_PLAINTEXT_LENGTH + 100); + ByteBuffer plainClient2 = allocateBuffer(512); + ByteBuffer plainClientTotal = allocateBuffer(plainClient.capacity() + plainClient2.capacity()); + plainClientTotal.put(plainClient); + plainClientTotal.put(plainClient2); + plainClient.clear(); + plainClient2.clear(); + plainClientTotal.flip(); + + // The capacity is designed to trigger an overflow condition. + ByteBuffer encClientToServerTooSmall = allocateBuffer(MAX_PLAINTEXT_LENGTH + 28); + ByteBuffer encClientToServer = allocateBuffer(client.getSession().getApplicationBufferSize()); + ByteBuffer encClientToServerTotal = allocateBuffer(client.getSession().getApplicationBufferSize() << 1); + ByteBuffer plainServer = allocateBuffer(server.getSession().getApplicationBufferSize() << 1); + + handshake(client, server); + + int plainClientRemaining = plainClient.remaining(); + int encClientToServerTooSmallRemaining = encClientToServerTooSmall.remaining(); + SSLEngineResult result = client.wrap(plainClient, encClientToServerTooSmall); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainClientRemaining - plainClient.remaining(), result.bytesConsumed()); + assertEquals(encClientToServerTooSmallRemaining - encClientToServerTooSmall.remaining(), + result.bytesProduced()); + + result = client.wrap(plainClient, encClientToServerTooSmall); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + + plainClientRemaining = plainClient.remaining(); + int encClientToServerRemaining = encClientToServer.remaining(); + result = client.wrap(plainClient, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainClientRemaining, result.bytesConsumed()); + assertEquals(encClientToServerRemaining - encClientToServer.remaining(), result.bytesProduced()); + assertEquals(0, plainClient.remaining()); + + final int plainClient2Remaining = plainClient2.remaining(); + encClientToServerRemaining = encClientToServer.remaining(); + result = client.wrap(plainClient2, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainClient2Remaining, result.bytesConsumed()); + assertEquals(encClientToServerRemaining - encClientToServer.remaining(), result.bytesProduced()); + + // Concatenate the too small buffer + encClientToServerTooSmall.flip(); + encClientToServer.flip(); + encClientToServerTotal.put(encClientToServerTooSmall); + encClientToServerTotal.put(encClientToServer); + encClientToServerTotal.flip(); + + // Unwrap in a single call. + final int encClientToServerTotalRemaining = encClientToServerTotal.remaining(); + result = server.unwrap(encClientToServerTotal, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(encClientToServerTotalRemaining, result.bytesConsumed()); + plainServer.flip(); + assertEquals(plainClientTotal, plainServer); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @Test + public void testPartialPacketUnwrapJDKCompatabilityModeOff() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + try { + ByteBuffer plainClient = allocateBuffer(1024); + ByteBuffer plainClient2 = allocateBuffer(512); + ByteBuffer plainClientTotal = allocateBuffer(plainClient.capacity() + plainClient2.capacity()); + plainClientTotal.put(plainClient); + plainClientTotal.put(plainClient2); + plainClient.clear(); + plainClient2.clear(); + plainClientTotal.flip(); + + ByteBuffer encClientToServer = allocateBuffer(client.getSession().getPacketBufferSize()); + ByteBuffer plainServer = allocateBuffer(server.getSession().getApplicationBufferSize()); + + handshake(client, server); + + SSLEngineResult result = client.wrap(plainClient, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), plainClient.capacity()); + final int encClientLen = result.bytesProduced(); + + result = client.wrap(plainClient2, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), plainClient2.capacity()); + final int encClientLen2 = result.bytesProduced(); + + // Flip so we can read it. + encClientToServer.flip(); + + // Consume a partial TLS packet. + ByteBuffer encClientFirstHalf = encClientToServer.duplicate(); + encClientFirstHalf.limit(encClientLen / 2); + result = server.unwrap(encClientFirstHalf, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), encClientLen / 2); + encClientToServer.position(result.bytesConsumed()); + + // We now have half of the first packet and the whole second packet, so lets decode all but the last byte. + ByteBuffer encClientAllButLastByte = encClientToServer.duplicate(); + final int encClientAllButLastByteLen = encClientAllButLastByte.remaining() - 1; + encClientAllButLastByte.limit(encClientAllButLastByte.limit() - 1); + result = server.unwrap(encClientAllButLastByte, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), encClientAllButLastByteLen); + encClientToServer.position(encClientToServer.position() + result.bytesConsumed()); + + // Read the last byte and verify the original content has been decrypted. + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), 1); + plainServer.flip(); + assertEquals(plainClientTotal, plainServer); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @Test + public void testBufferUnderFlowAvoidedIfJDKCompatabilityModeOff() throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .build(); + SSLEngine client = clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + serverSslCtx = SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine server = serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine(); + + try { + ByteBuffer plainClient = allocateBuffer(1024); + plainClient.limit(plainClient.capacity()); + + ByteBuffer encClientToServer = allocateBuffer(client.getSession().getPacketBufferSize()); + ByteBuffer plainServer = allocateBuffer(server.getSession().getApplicationBufferSize()); + + handshake(client, server); + + SSLEngineResult result = client.wrap(plainClient, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), plainClient.capacity()); + + // Flip so we can read it. + encClientToServer.flip(); + int remaining = encClientToServer.remaining(); + + // We limit the buffer so we have less then the header to read, this should result in an BUFFER_UNDERFLOW. + encClientToServer.limit(SslUtils.SSL_RECORD_HEADER_LENGTH - 1); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(SslUtils.SSL_RECORD_HEADER_LENGTH - 1, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + remaining -= result.bytesConsumed(); + + // We limit the buffer so we can read the header but not the rest, this should result in an + // BUFFER_UNDERFLOW. + encClientToServer.limit(SslUtils.SSL_RECORD_HEADER_LENGTH); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(1, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + remaining -= result.bytesConsumed(); + + // We limit the buffer so we can read the header and partly the rest, this should result in an + // BUFFER_UNDERFLOW. + encClientToServer.limit( + SslUtils.SSL_RECORD_HEADER_LENGTH + remaining - 1 - SslUtils.SSL_RECORD_HEADER_LENGTH); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(encClientToServer.limit() - SslUtils.SSL_RECORD_HEADER_LENGTH, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + remaining -= result.bytesConsumed(); + + // Reset limit so we can read the full record. + encClientToServer.limit(remaining); + assertEquals(0, encClientToServer.remaining()); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.BUFFER_UNDERFLOW, result.getStatus()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + + encClientToServer.position(0); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(remaining, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + private void testWrapWithDifferentSizes(String protocol, String cipher) throws Exception { assumeTrue(OpenSsl.SUPPORTED_PROTOCOLS_SET.contains(protocol)); if (!OpenSsl.isCipherSuiteAvailable(cipher)) { @@ -573,7 +901,7 @@ public class OpenSslEngineTest extends SSLEngineTest { private void testWrapDstBigEnough(SSLEngine engine, int srcLen) throws SSLException { ByteBuffer src = allocateBuffer(srcLen); - ByteBuffer dst = allocateBuffer(srcLen + MAX_TLS_RECORD_OVERHEAD_LENGTH); + ByteBuffer dst = allocateBuffer(srcLen + ((ReferenceCountedOpenSslEngine) engine).maxWrapOverhead()); SSLEngineResult result = engine.wrap(src, dst); assertEquals(SSLEngineResult.Status.OK, result.getStatus()); diff --git a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java index 07c87c694d..521cb79123 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java @@ -16,26 +16,9 @@ package io.netty.handler.ssl; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeTrue; - -import java.io.File; -import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import javax.net.ssl.SSLEngine; - -import org.junit.Test; - import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -63,9 +46,27 @@ import io.netty.util.ReferenceCounted; import io.netty.util.concurrent.Promise; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.StringUtil; +import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.io.File; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLEngine; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; + @RunWith(Parameterized.class) public class SniHandlerTest { @@ -251,14 +252,27 @@ public class SniHandlerTest { // invalid byte[] message = {22, 3, 1, 0, 0}; - try { // Push the handshake message. ch.writeInbound(Unpooled.wrappedBuffer(message)); + // TODO(scott): This should fail becasue the engine should reject zero length records during handshake. + // See https://github.com/netty/netty/issues/6348. + // fail(); } catch (Exception e) { // expected } + ch.close(); + + // When the channel is closed the SslHandler will write an empty buffer to the channel. + ByteBuf buf = ch.readOutbound(); + // TODO(scott): if the engine is shutdown correctly then this buffer shouldn't be null! + // See https://github.com/netty/netty/issues/6348. + if (buf != null) { + assertFalse(buf.isReadable()); + buf.release(); + } + assertThat(ch.finish(), is(false)); assertThat(handler.hostname(), nullValue()); assertThat(handler.sslContext(), is(nettyContext)); diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 5ef43de248..7cfd050c21 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -388,8 +388,14 @@ public class SslHandlerTest { SslHandler handler = new SslHandler(SSLContext.getDefault().createSSLEngine()); EmbeddedChannel ch = new EmbeddedChannel(handler); - // Closing the Channel will also produce a close_notify so it is expected to return true. - assertTrue(ch.finishAndReleaseAll()); + ch.close(); + + // When the channel is closed the SslHandler will write an empty buffer to the channel. + ByteBuf buf = ch.readOutbound(); + assertFalse(buf.isReadable()); + buf.release(); + + assertFalse(ch.finishAndReleaseAll()); assertTrue(handler.handshakeFuture().cause() instanceof ClosedChannelException); assertTrue(handler.sslCloseFuture().cause() instanceof ClosedChannelException);