diff --git a/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java b/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java index e031bd5441..54ee2befbf 100644 --- a/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java @@ -16,9 +16,7 @@ package io.netty.handler.ssl; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufInputStream; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -28,8 +26,6 @@ import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLSessionContext; -import javax.net.ssl.TrustManagerFactory; -import javax.security.auth.x500.X500Principal; import java.io.File; import java.io.IOException; import java.security.InvalidAlgorithmParameterException; @@ -40,8 +36,6 @@ import java.security.NoSuchAlgorithmException; import java.security.Security; import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.util.ArrayList; import java.util.Arrays; @@ -325,39 +319,4 @@ public abstract class JdkSslContext extends SslContext { return kmf; } - - /** - * Build a {@link TrustManagerFactory} from a certificate chain file. - * @param certChainFile The certificate file to build from. - * @param trustManagerFactory The existing {@link TrustManagerFactory} that will be used if not {@code null}. - * @return A {@link TrustManagerFactory} which contains the certificates in {@code certChainFile} - */ - protected static TrustManagerFactory buildTrustManagerFactory(File certChainFile, - TrustManagerFactory trustManagerFactory) - throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException { - KeyStore ks = KeyStore.getInstance("JKS"); - ks.load(null, null); - CertificateFactory cf = CertificateFactory.getInstance("X.509"); - - ByteBuf[] certs = PemReader.readCertificates(certChainFile); - try { - for (ByteBuf buf: certs) { - X509Certificate cert = (X509Certificate) cf.generateCertificate(new ByteBufInputStream(buf)); - X500Principal principal = cert.getSubjectX500Principal(); - ks.setCertificateEntry(principal.getName("RFC2253"), cert); - } - } finally { - for (ByteBuf buf: certs) { - buf.release(); - } - } - - // Set up trust manager factory to use our key store. - if (trustManagerFactory == null) { - trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - } - trustManagerFactory.init(ks); - - return trustManagerFactory; - } } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java index d88e289b3e..85f15d103b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java @@ -20,6 +20,7 @@ import io.netty.buffer.ByteBufInputStream; import org.apache.tomcat.jni.SSL; import org.apache.tomcat.jni.SSLContext; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLException; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; @@ -39,13 +40,12 @@ import java.security.cert.X509Certificate; */ public final class OpenSslClientContext extends OpenSslContext { private final OpenSslSessionContext sessionContext; - private final OpenSslEngineMap engineMap; /** * Creates a new instance. */ public OpenSslClientContext() throws SSLException { - this(null, null, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0); + this(null, null, null, null, null, null, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0); } /** @@ -79,7 +79,8 @@ public final class OpenSslClientContext extends OpenSslContext { * {@code null} to use the default. */ public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory) throws SSLException { - this(certChainFile, trustManagerFactory, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0); + this(certChainFile, trustManagerFactory, null, null, null, null, null, + IdentityCipherSuiteFilter.INSTANCE, null, 0, 0); } /** @@ -104,11 +105,14 @@ public final class OpenSslClientContext extends OpenSslContext { public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory, Iterable ciphers, ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException { - this(certChainFile, trustManagerFactory, ciphers, IdentityCipherSuiteFilter.INSTANCE, + this(certChainFile, trustManagerFactory, null, null, null, null, ciphers, IdentityCipherSuiteFilter.INSTANCE, apn, sessionCacheSize, sessionTimeout); } /** + * @deprecated use {@link #OpenSslClientContext(File, TrustManagerFactory, File, File, String, + * KeyManagerFactory, Iterable, CipherSuiteFilter, ApplicationProtocolConfig,long, long)} + * * Creates a new instance. * * @param certChainFile an X.509 certificate chain file in PEM format @@ -124,29 +128,98 @@ public final class OpenSslClientContext extends OpenSslContext { * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. * {@code 0} to use the default value. */ + @Deprecated public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory, Iterable ciphers, + CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(certChainFile, trustManagerFactory, null, null, null, null, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * @param trustCertChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default or the results of parsing {@code trustCertChainFile} + * @param keyCertChainFile an X.509 certificate chain file in PEM format. + * This provides the public key for mutual authentication. + * {@code null} to use the system default + * @param keyFile a PKCS#8 private key file in PEM format. + * This provides the private key for mutual authentication. + * {@code null} for no mutual authentication. + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * Ignored if {@code keyFile} is {@code null}. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link javax.net.ssl.KeyManager}s + * that is used to encrypt data being sent to servers. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Application Protocol Negotiator object. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + */ + public OpenSslClientContext(File trustCertChainFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, + KeyManagerFactory keyManagerFactory, Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException { super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT); boolean success = false; try { - if (certChainFile != null && !certChainFile.isFile()) { - throw new IllegalArgumentException("certChainFile is not a file: " + certChainFile); + if (trustCertChainFile != null && !trustCertChainFile.isFile()) { + throw new IllegalArgumentException("trustCertChainFile is not a file: " + trustCertChainFile); } + if (keyCertChainFile != null && !keyCertChainFile.isFile()) { + throw new IllegalArgumentException("keyCertChainFile is not a file: " + keyCertChainFile); + } + + if (keyFile != null && !keyFile.isFile()) { + throw new IllegalArgumentException("keyFile is not a file: " + keyFile); + } + if (keyFile == null && keyCertChainFile != null || keyFile != null && keyCertChainFile == null) { + throw new IllegalArgumentException( + "Either both keyCertChainFile and keyFile needs to be null or none of them"); + } synchronized (OpenSslContext.class) { - if (certChainFile != null) { + if (trustCertChainFile != null) { /* Load the certificate chain. We must skip the first cert when server mode */ - if (!SSLContext.setCertificateChainFile(ctx, certChainFile.getPath(), true)) { + if (!SSLContext.setCertificateChainFile(ctx, trustCertChainFile.getPath(), true)) { long error = SSL.getLastErrorNumber(); if (OpenSsl.isError(error)) { throw new SSLException( "failed to set certificate chain: " - + certChainFile + " (" + SSL.getErrorString(error) + ')'); + + trustCertChainFile + " (" + SSL.getErrorString(error) + ')'); } } } + if (keyCertChainFile != null && keyFile != null) { + /* Load the certificate file and private key. */ + try { + if (!SSLContext.setCertificate( + ctx, keyCertChainFile.getPath(), keyFile.getPath(), keyPassword, SSL.SSL_AIDX_RSA)) { + long error = SSL.getLastErrorNumber(); + if (OpenSsl.isError(error)) { + throw new SSLException("failed to set certificate: " + + keyCertChainFile + " and " + keyFile + + " (" + SSL.getErrorString(error) + ')'); + } + } + } catch (SSLException e) { + throw e; + } catch (Exception e) { + throw new SSLException("failed to set certificate: " + keyCertChainFile + " and " + keyFile, e); + } + } + SSLContext.setVerify(ctx, SSL.SSL_VERIFY_NONE, VERIFY_DEPTH); try { @@ -155,25 +228,24 @@ public final class OpenSslClientContext extends OpenSslContext { trustManagerFactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm()); } - initTrustManagerFactory(certChainFile, trustManagerFactory); + initTrustManagerFactory(trustCertChainFile, trustManagerFactory); final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); - engineMap = newEngineMap(manager); - // Use this to prevent an error when running on java < 7 if (useExtendedTrustManager(manager)) { final X509ExtendedTrustManager extendedManager = (X509ExtendedTrustManager) manager; SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() { @Override - void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception { - OpenSslEngine engine = engineMap.remove(ssl); + void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth) + throws Exception { extendedManager.checkServerTrusted(peerCerts, auth, engine); } }); } else { SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() { @Override - void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception { + void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth) + throws Exception { manager.checkServerTrusted(peerCerts, auth); } }); @@ -218,11 +290,6 @@ public final class OpenSslClientContext extends OpenSslContext { return sessionContext; } - @Override - OpenSslEngineMap engineMap() { - return engineMap; - } - // No cache is currently supported for client side mode. private static final class OpenSslClientSessionContext extends OpenSslSessionContext { private OpenSslClientSessionContext(long context) { diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java index 06f9d510cc..d642e22386 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslContext.java @@ -26,6 +26,7 @@ import org.apache.tomcat.jni.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.TrustManager; import javax.net.ssl.X509ExtendedTrustManager; import javax.net.ssl.X509TrustManager; @@ -56,6 +57,8 @@ public abstract class OpenSslContext extends SslContext { private final List unmodifiableCiphers; private final long sessionCacheSize; private final long sessionTimeout; + private final OpenSslEngineMap engineMap = new DefaultOpenSslEngineMap(); + private final OpenSslApplicationProtocolNegotiator apn; /** The OpenSSL SSL_CTX object */ protected final long ctx; @@ -277,14 +280,11 @@ public abstract class OpenSslContext extends SslContext { */ @Override public final SSLEngine newEngine(ByteBufAllocator alloc) { - OpenSslEngineMap engineMap = engineMap(); final OpenSslEngine engine = new OpenSslEngine(ctx, alloc, isClient(), sessionContext(), apn, engineMap); engineMap.add(engine); return engine; } - abstract OpenSslEngineMap engineMap(); - /** * Returns the {@code SSL_CTX} object of this context. */ @@ -392,31 +392,28 @@ public abstract class OpenSslContext extends SslContext { } } - static OpenSslEngineMap newEngineMap(X509TrustManager trustManager) { - if (useExtendedTrustManager(trustManager)) { - return new DefaultOpenSslEngineMap(); - } - return OpenSslEngineMap.EMPTY; - } - static boolean useExtendedTrustManager(X509TrustManager trustManager) { return PlatformDependent.javaVersion() >= 7 && trustManager instanceof X509ExtendedTrustManager; } - abstract static class AbstractCertificateVerifier implements CertificateVerifier { + abstract class AbstractCertificateVerifier implements CertificateVerifier { @Override public final boolean verify(long ssl, byte[][] chain, String auth) { X509Certificate[] peerCerts = certificates(chain); + final OpenSslEngine engine = engineMap.remove(ssl); try { - verify(ssl, peerCerts, auth); + verify(engine, peerCerts, auth); return true; - } catch (Exception e) { - logger.debug("verification of certificate failed", e); + } catch (Throwable cause) { + logger.debug("verification of certificate failed", cause); + SSLHandshakeException e = new SSLHandshakeException("General OpenSslEngine problem"); + e.initCause(cause); + engine.handshakeException = e; } return false; } - abstract void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception; + abstract void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth) throws Exception; } private static final class DefaultOpenSslEngineMap implements OpenSslEngineMap { diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java index ded15e9386..94dc55a9a1 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslEngine.java @@ -29,6 +29,7 @@ import org.apache.tomcat.jni.SSL; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSessionBindingEvent; @@ -160,6 +161,10 @@ public final class OpenSslEngine extends SSLEngine { private final OpenSslApplicationProtocolNegotiator apn; private final SSLSession session = new OpenSslSession(); + // This is package-private as we set it from OpenSslContext if an exception is thrown during + // the verification step. + SSLHandshakeException handshakeException; + /** * Creates a new instance * @@ -489,6 +494,22 @@ public final class OpenSslEngine extends SSLEngine { return new SSLEngineResult(getEngineStatus(), handshakeStatus0(), bytesConsumed, bytesProduced); } + private SSLException newSSLException(String msg) { + if (!handshakeFinished) { + return new SSLHandshakeException(msg); + } + return new SSLException(msg); + } + + private void checkPendingHandshakeException() throws SSLHandshakeException { + if (handshakeException != null) { + SSLHandshakeException exception = handshakeException; + handshakeException = null; + shutdown(); + throw exception; + } + } + public synchronized SSLEngineResult unwrap( final ByteBuffer[] srcs, int srcsOffset, final int srcsLength, final ByteBuffer[] dsts, final int dstsOffset, final int dstsLength) throws SSLException { @@ -608,7 +629,9 @@ public final class OpenSslEngine extends SSLEngine { // There was an internal error -- shutdown shutdown(); - throw new SSLException(err); + throw newSSLException(err); + } else { + checkPendingHandshakeException(); } } } else { @@ -954,8 +977,9 @@ public final class OpenSslEngine extends SSLEngine { // There was an internal error -- shutdown shutdown(); - throw new SSLException(err); + throw newSSLException(err); } + checkPendingHandshakeException(); } else { // if SSL_do_handshake returns > 0 it means the handshake was finished. This means we can update // handshakeFinished directly and so eliminate uncessary calls to SSL.isInInit(...) @@ -1037,6 +1061,7 @@ public final class OpenSslEngine extends SSLEngine { if (status == FINISHED) { handshakeFinished(); } + checkPendingHandshakeException(); return status; } diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java index 207c0f0355..83ee505eaa 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java @@ -19,7 +19,10 @@ import io.netty.util.internal.EmptyArrays; import org.apache.tomcat.jni.SSL; import org.apache.tomcat.jni.SSLContext; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509ExtendedTrustManager; import javax.net.ssl.X509TrustManager; @@ -34,7 +37,6 @@ import static io.netty.util.internal.ObjectUtil.*; */ public final class OpenSslServerContext extends OpenSslContext { private final OpenSslServerSessionContext sessionContext; - private final OpenSslEngineMap engineMap; /** * Creates a new instance. @@ -55,8 +57,8 @@ public final class OpenSslServerContext extends OpenSslContext { * {@code null} if it's not password-protected. */ public OpenSslServerContext(File certChainFile, File keyFile, String keyPassword) throws SSLException { - this(certChainFile, keyFile, keyPassword, null, null, IdentityCipherSuiteFilter.INSTANCE, - NONE_PROTOCOL_NEGOTIATOR, 0, 0); + this(certChainFile, keyFile, keyPassword, null, IdentityCipherSuiteFilter.INSTANCE, + ApplicationProtocolConfig.DISABLED, 0, 0); } /** @@ -82,8 +84,8 @@ public final class OpenSslServerContext extends OpenSslContext { File certChainFile, File keyFile, String keyPassword, Iterable ciphers, ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException { - this(certChainFile, keyFile, keyPassword, null, ciphers, - toNegotiator(apn), sessionCacheSize, sessionTimeout); + this(certChainFile, keyFile, keyPassword, ciphers, IdentityCipherSuiteFilter.INSTANCE, + apn, sessionCacheSize, sessionTimeout); } /** @@ -128,9 +130,8 @@ public final class OpenSslServerContext extends OpenSslContext { * {@code 0} to use the default value. * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. * {@code 0} to use the default value. - * @deprecated use {@link #OpenSslServerContext( - * File, File, String, TrustManagerFactory, Iterable, - * CipherSuiteFilter, ApplicationProtocolConfig, long, long)} + * @deprecated use {@link #OpenSslServerContext(File, TrustManagerFactory, File, File, String, KeyManagerFactory, + * Iterable, CipherSuiteFilter, ApplicationProtocolConfig, long, long)} */ @Deprecated public OpenSslServerContext( @@ -155,17 +156,16 @@ public final class OpenSslServerContext extends OpenSslContext { * {@code 0} to use the default value. * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. * {@code 0} to use the default value. - * @deprecated use {@link #OpenSslServerContext( - * File, File, String, TrustManagerFactory, Iterable, - * CipherSuiteFilter, OpenSslApplicationProtocolNegotiator, long, long)} + * @deprecated use {@link #OpenSslServerContext(File, TrustManagerFactory, File, File, String, KeyManagerFactory, + * Iterable, CipherSuiteFilter, ApplicationProtocolConfig, long, long)} */ @Deprecated public OpenSslServerContext( File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, Iterable ciphers, OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize, long sessionTimeout) throws SSLException { - this(certChainFile, keyFile, keyPassword, trustManagerFactory, ciphers, - IdentityCipherSuiteFilter.INSTANCE, apn, sessionCacheSize, sessionTimeout); + this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, + ciphers, null, apn, sessionCacheSize, sessionTimeout); } /** @@ -188,8 +188,44 @@ public final class OpenSslServerContext extends OpenSslContext { File certChainFile, File keyFile, String keyPassword, Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException { - this(certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter, - toNegotiator(apn), sessionCacheSize, sessionTimeout); + this(null, null, certChainFile, keyFile, keyPassword, null, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param trustCertChainFile an X.509 certificate chain file in PEM format. + * This provides the certificate chains used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing {@code trustCertChainFile}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param config Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + */ + public OpenSslServerContext( + File trustCertChainFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(trustCertChainFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword, keyManagerFactory, + ciphers, cipherFilter, toNegotiator(config), sessionCacheSize, sessionTimeout); } /** @@ -208,12 +244,13 @@ public final class OpenSslServerContext extends OpenSslContext { * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. * {@code 0} to use the default value. */ - public OpenSslServerContext( - File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, - Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config, - long sessionCacheSize, long sessionTimeout) throws SSLException { - this(certChainFile, keyFile, keyPassword, trustManagerFactory, ciphers, cipherFilter, - toNegotiator(config), sessionCacheSize, sessionTimeout); + @Deprecated + public OpenSslServerContext(File certChainFile, File keyFile, String keyPassword, + TrustManagerFactory trustManagerFactory, Iterable ciphers, + CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter, + toNegotiator(config), sessionCacheSize, sessionTimeout); } /** @@ -231,21 +268,61 @@ public final class OpenSslServerContext extends OpenSslContext { * {@code 0} to use the default value. * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. * {@code 0} to use the default value. + * @deprecated use {@link #OpenSslServerContext(File, TrustManagerFactory, File, File, String, KeyManagerFactory, + * Iterable, CipherSuiteFilter, OpenSslApplicationProtocolNegotiator, long, long)} */ + @Deprecated public OpenSslServerContext( File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, Iterable ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter, + apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * + * @param trustCertChainFile an X.509 certificate chain file in PEM format. + * This provides the certificate chains used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing {@code trustCertChainFile}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Application Protocol Negotiator object + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + */ + public OpenSslServerContext( + File trustCertChainFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER); OpenSsl.ensureAvailability(); - checkNotNull(certChainFile, "certChainFile"); - if (!certChainFile.isFile()) { - throw new IllegalArgumentException("certChainFile is not a file: " + certChainFile); + checkNotNull(keyCertChainFile, "keyCertChainFile"); + if (!keyCertChainFile.isFile()) { + throw new IllegalArgumentException("keyCertChainFile is not a file: " + keyCertChainFile); } checkNotNull(keyFile, "keyFile"); if (!keyFile.isFile()) { - throw new IllegalArgumentException("keyPath is not a file: " + keyFile); + throw new IllegalArgumentException("keyFile is not a file: " + keyFile); } if (keyPassword == null) { keyPassword = ""; @@ -259,59 +336,64 @@ public final class OpenSslServerContext extends OpenSslContext { SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH); /* Load the certificate chain. We must skip the first cert when server mode */ - if (!SSLContext.setCertificateChainFile(ctx, certChainFile.getPath(), true)) { + if (!SSLContext.setCertificateChainFile(ctx, keyCertChainFile.getPath(), true)) { long error = SSL.getLastErrorNumber(); if (OpenSsl.isError(error)) { String err = SSL.getErrorString(error); throw new SSLException( - "failed to set certificate chain: " + certChainFile + " (" + err + ')'); + "failed to set certificate chain: " + keyCertChainFile + " (" + err + ')'); } } /* Load the certificate file and private key. */ try { if (!SSLContext.setCertificate( - ctx, certChainFile.getPath(), keyFile.getPath(), keyPassword, SSL.SSL_AIDX_RSA)) { + ctx, keyCertChainFile.getPath(), keyFile.getPath(), keyPassword, SSL.SSL_AIDX_RSA)) { long error = SSL.getLastErrorNumber(); if (OpenSsl.isError(error)) { String err = SSL.getErrorString(error); throw new SSLException("failed to set certificate: " + - certChainFile + " and " + keyFile + " (" + err + ')'); + keyCertChainFile + " and " + keyFile + " (" + err + ')'); } } } catch (SSLException e) { throw e; } catch (Exception e) { - throw new SSLException("failed to set certificate: " + certChainFile + " and " + keyFile, e); + throw new SSLException("failed to set certificate: " + keyCertChainFile + " and " + keyFile, e); } try { - char[] keyPasswordChars = keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray(); - - KeyStore ks = buildKeyStore(certChainFile, keyFile, keyPasswordChars); if (trustManagerFactory == null) { // Mimic the way SSLContext.getInstance(KeyManager[], null, null) works trustManagerFactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm()); } - trustManagerFactory.init(ks); + if (trustCertChainFile != null) { + trustManagerFactory = buildTrustManagerFactory(trustCertChainFile, trustManagerFactory); + } else { + char[] keyPasswordChars = + keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray(); + + KeyStore ks = buildKeyStore(keyCertChainFile, keyFile, keyPasswordChars); + trustManagerFactory.init(ks); + } final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); - engineMap = newEngineMap(manager); // Use this to prevent an error when running on java < 7 if (useExtendedTrustManager(manager)) { final X509ExtendedTrustManager extendedManager = (X509ExtendedTrustManager) manager; SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() { @Override - void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception { - OpenSslEngine engine = engineMap.remove(ssl); + void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth) + throws Exception { extendedManager.checkClientTrusted(peerCerts, auth, engine); } }); } else { SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() { @Override - void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception { + void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth) + throws Exception { manager.checkClientTrusted(peerCerts, auth); } }); @@ -333,9 +415,4 @@ public final class OpenSslServerContext extends OpenSslContext { public OpenSslServerSessionContext sessionContext() { return sessionContext; } - - @Override - OpenSslEngineMap engineMap() { - return engineMap; - } } diff --git a/handler/src/main/java/io/netty/handler/ssl/SslContext.java b/handler/src/main/java/io/netty/handler/ssl/SslContext.java index 69bd2caedd..05f55274ef 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslContext.java @@ -39,6 +39,7 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLSessionContext; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; +import javax.security.auth.x500.X500Principal; import java.io.File; import java.io.IOException; import java.security.InvalidAlgorithmParameterException; @@ -52,6 +53,7 @@ import java.security.PrivateKey; import java.security.cert.Certificate; import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.util.ArrayList; @@ -399,8 +401,8 @@ public abstract class SslContext { keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); case OPENSSL: return new OpenSslServerContext( - keyCertChainFile, keyFile, keyPassword, trustManagerFactory, - ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + trustCertChainFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); default: throw new Error(provider.toString()); } @@ -729,8 +731,8 @@ public abstract class SslContext { keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); case OPENSSL: return new OpenSslClientContext( - trustCertChainFile, trustManagerFactory, ciphers, cipherFilter, apn, - sessionCacheSize, sessionTimeout); + trustCertChainFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); } // Should never happen!! throw new Error(); @@ -925,4 +927,39 @@ public abstract class SslContext { ks.setKeyEntry("key", key, keyPasswordChars, certChain.toArray(new Certificate[certChain.size()])); return ks; } + + /** + * Build a {@link TrustManagerFactory} from a certificate chain file. + * @param certChainFile The certificate file to build from. + * @param trustManagerFactory The existing {@link TrustManagerFactory} that will be used if not {@code null}. + * @return A {@link TrustManagerFactory} which contains the certificates in {@code certChainFile} + */ + protected static TrustManagerFactory buildTrustManagerFactory(File certChainFile, + TrustManagerFactory trustManagerFactory) + throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException { + KeyStore ks = KeyStore.getInstance("JKS"); + ks.load(null, null); + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + + ByteBuf[] certs = PemReader.readCertificates(certChainFile); + try { + for (ByteBuf buf: certs) { + X509Certificate cert = (X509Certificate) cf.generateCertificate(new ByteBufInputStream(buf)); + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + } finally { + for (ByteBuf buf: certs) { + buf.release(); + } + } + + // Set up trust manager factory to use our key store. + if (trustManagerFactory == null) { + trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + } + trustManagerFactory.init(ks); + + return trustManagerFactory; + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java index 4df0fc00a8..9482f2b445 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java @@ -15,23 +15,17 @@ */ package io.netty.handler.ssl; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeNoException; -import static org.mockito.Mockito.verify; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; -import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -40,94 +34,24 @@ import io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelectorFac import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.NetUtil; -import io.netty.util.concurrent.Future; -import java.io.File; import java.net.InetSocketAddress; import java.security.cert.CertificateException; import java.util.List; import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -public class JdkSslEngineTest { +public class JdkSslEngineTest extends SSLEngineTest { private static final String PREFERRED_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http2"; private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1"; private static final String APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE = "my-protocol-FOO"; - @Mock - private MessageReciever serverReceiver; - @Mock - private MessageReciever clientReceiver; - - private Throwable serverException; - private Throwable clientException; - private SslContext serverSslCtx; - private SslContext clientSslCtx; - private ServerBootstrap sb; - private Bootstrap cb; - private Channel serverChannel; - private Channel serverConnectedChannel; - private Channel clientChannel; - private CountDownLatch serverLatch; - private CountDownLatch clientLatch; - - private interface MessageReciever { - void messageReceived(ByteBuf msg); - } - - private final class MessageDelegatorChannelHandler extends SimpleChannelInboundHandler { - private final MessageReciever receiver; - private final CountDownLatch latch; - - public MessageDelegatorChannelHandler(MessageReciever receiver, CountDownLatch latch) { - super(false); - this.receiver = receiver; - this.latch = latch; - } - - @Override - protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { - receiver.messageReceived(msg); - latch.countDown(); - } - } - - @Before - public void setup() { - MockitoAnnotations.initMocks(this); - serverLatch = new CountDownLatch(1); - clientLatch = new CountDownLatch(1); - } - - @After - public void tearDown() throws InterruptedException { - if (serverChannel != null) { - serverChannel.close().sync(); - Future serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); - Future serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); - Future clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); - serverGroup.sync(); - serverChildGroup.sync(); - clientGroup.sync(); - } - clientChannel = null; - serverChannel = null; - serverConnectedChannel = null; - serverException = null; - } - @Test public void testNpn() throws Exception { try { @@ -344,57 +268,6 @@ public class JdkSslEngineTest { } } - @Test - public void testMutualAuthSameCerts() throws Exception { - mySetupMutualAuth(new File(getClass().getResource("test_unencrypted.pem").getFile()), - new File(getClass().getResource("test.crt").getFile()), - null); - runTest(null); - } - - @Test - public void testMutualAuthDiffCerts() throws Exception { - File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); - String serverKeyPassword = "12345"; - File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); - File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); - String clientKeyPassword = "12345"; - mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, - serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); - runTest(null); - } - - @Test - public void testMutualAuthDiffCertsServerFailure() throws Exception { - File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); - String serverKeyPassword = "12345"; - File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); - File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); - String clientKeyPassword = "12345"; - // Client trusts server but server only trusts itself - mySetupMutualAuth(serverCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, - serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); - assertTrue(serverLatch.await(2, TimeUnit.SECONDS)); - assertTrue(serverException instanceof SSLHandshakeException); - } - - @Test - public void testMutualAuthDiffCertsClientFailure() throws Exception { - File serverKeyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); - File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); - String serverKeyPassword = null; - File clientKeyFile = new File(getClass().getResource("test2_unencrypted.pem").getFile()); - File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); - String clientKeyPassword = null; - // Server trusts client but client only trusts itself - mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, - clientCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); - assertTrue(clientLatch.await(2, TimeUnit.SECONDS)); - assertTrue(clientException instanceof SSLHandshakeException); - } - private void mySetup(JdkApplicationProtocolNegotiator apn) throws InterruptedException, SSLException, CertificateException { mySetup(apn, apn); @@ -465,132 +338,12 @@ public class JdkSslEngineTest { clientChannel = ccf.channel(); } - private void mySetupMutualAuth(File keyFile, File crtFile, String keyPassword) - throws SSLException, CertificateException, InterruptedException { - mySetupMutualAuth(crtFile, keyFile, crtFile, keyPassword, crtFile, keyFile, crtFile, keyPassword); - } - - private void mySetupMutualAuth( - File servertTrustCrtFile, File serverKeyFile, File serverCrtFile, String serverKeyPassword, - File clientTrustCrtFile, File clientKeyFile, File clientCrtFile, String clientKeyPassword) - throws InterruptedException, SSLException, CertificateException { - serverSslCtx = new JdkSslServerContext(servertTrustCrtFile, null, - serverCrtFile, serverKeyFile, serverKeyPassword, null, - null, IdentityCipherSuiteFilter.INSTANCE, (ApplicationProtocolConfig) null, 0, 0); - clientSslCtx = new JdkSslClientContext(clientTrustCrtFile, null, - clientCrtFile, clientKeyFile, clientKeyPassword, null, - null, IdentityCipherSuiteFilter.INSTANCE, (ApplicationProtocolConfig) null, 0, 0); - - serverConnectedChannel = null; - sb = new ServerBootstrap(); - cb = new Bootstrap(); - - sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); - sb.channel(NioServerSocketChannel.class); - sb.childHandler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ChannelPipeline p = ch.pipeline(); - SSLEngine engine = serverSslCtx.newEngine(ch.alloc()); - engine.setUseClientMode(false); - engine.setNeedClientAuth(true); - p.addLast(new SslHandler(engine)); - p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); - p.addLast(new ChannelHandlerAdapter() { - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause.getCause() instanceof SSLHandshakeException) { - serverException = cause.getCause(); - serverLatch.countDown(); - } else { - ctx.fireExceptionCaught(cause); - } - } - }); - serverConnectedChannel = ch; - } - }); - - cb.group(new NioEventLoopGroup()); - cb.channel(NioSocketChannel.class); - cb.handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ChannelPipeline p = ch.pipeline(); - p.addLast(clientSslCtx.newHandler(ch.alloc())); - p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); - p.addLast(new ChannelHandlerAdapter() { - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause.getCause() instanceof SSLHandshakeException) { - clientException = cause.getCause(); - clientLatch.countDown(); - } else { - ctx.fireExceptionCaught(cause); - } - } - }); - } - }); - - serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); - int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); - - ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port)); - assertTrue(ccf.awaitUninterruptibly().isSuccess()); - clientChannel = ccf.channel(); - } - private void runTest() throws Exception { runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL); } - private void runTest(String expectedApplicationProtocol) throws Exception { - final ByteBuf clientMessage = Unpooled.copiedBuffer("I am a client".getBytes()); - final ByteBuf serverMessage = Unpooled.copiedBuffer("I am a server".getBytes()); - try { - writeAndVerifyReceived(clientMessage.retain(), clientChannel, serverLatch, serverReceiver); - writeAndVerifyReceived(serverMessage.retain(), serverConnectedChannel, clientLatch, clientReceiver); - if (expectedApplicationProtocol != null) { - verifyApplicationLevelProtocol(clientChannel, expectedApplicationProtocol); - verifyApplicationLevelProtocol(serverConnectedChannel, expectedApplicationProtocol); - } - } finally { - clientMessage.release(); - serverMessage.release(); - } - } - - private void verifyApplicationLevelProtocol(Channel channel, String expectedApplicationProtocol) { - SslHandler handler = channel.pipeline().get(SslHandler.class); - assertNotNull(handler); - String[] protocol = handler.engine().getSession().getProtocol().split(":"); - assertNotNull(protocol); - if (expectedApplicationProtocol != null && !expectedApplicationProtocol.isEmpty()) { - assertTrue("protocol.length must be greater than 1 but is " + protocol.length, protocol.length > 1); - assertEquals(expectedApplicationProtocol, protocol[1]); - } else { - assertEquals(1, protocol.length); - } - } - - private static void writeAndVerifyReceived(ByteBuf message, Channel sendChannel, CountDownLatch receiverLatch, - MessageReciever receiver) throws Exception { - List dataCapture = null; - try { - sendChannel.writeAndFlush(message); - receiverLatch.await(5, TimeUnit.SECONDS); - message.resetReaderIndex(); - ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); - verify(receiver).messageReceived(captor.capture()); - dataCapture = captor.getAllValues(); - assertEquals(message, dataCapture.get(0)); - } finally { - if (dataCapture != null) { - for (ByteBuf data : dataCapture) { - data.release(); - } - } - } + @Override + protected SslProvider sslProvider() { + return SslProvider.JDK; } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java new file mode 100644 index 0000000000..0d6a76ca59 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -0,0 +1,23 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +public class OpenSslEngineTest extends SSLEngineTest { + @Override + protected SslProvider sslProvider() { + return SslProvider.OPENSSL; + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java new file mode 100644 index 0000000000..490f3db0c0 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -0,0 +1,296 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.NetUtil; +import io.netty.util.concurrent.Future; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import java.io.File; +import java.net.InetSocketAddress; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.verify; + +public abstract class SSLEngineTest { + + @Mock + protected MessageReciever serverReceiver; + @Mock + protected MessageReciever clientReceiver; + + protected Throwable serverException; + protected Throwable clientException; + protected SslContext serverSslCtx; + protected SslContext clientSslCtx; + protected ServerBootstrap sb; + protected Bootstrap cb; + protected Channel serverChannel; + protected Channel serverConnectedChannel; + protected Channel clientChannel; + protected CountDownLatch serverLatch; + protected CountDownLatch clientLatch; + + interface MessageReciever { + void messageReceived(ByteBuf msg); + } + + protected static final class MessageDelegatorChannelHandler extends SimpleChannelInboundHandler { + private final MessageReciever receiver; + private final CountDownLatch latch; + + public MessageDelegatorChannelHandler(MessageReciever receiver, CountDownLatch latch) { + super(false); + this.receiver = receiver; + this.latch = latch; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + receiver.messageReceived(msg); + latch.countDown(); + } + } + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + serverLatch = new CountDownLatch(1); + clientLatch = new CountDownLatch(1); + } + + @After + public void tearDown() throws InterruptedException { + if (serverChannel != null) { + serverChannel.close().sync(); + Future serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + Future serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + Future clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + serverGroup.sync(); + serverChildGroup.sync(); + clientGroup.sync(); + } + clientChannel = null; + serverChannel = null; + serverConnectedChannel = null; + serverException = null; + } + + @Test + public void testMutualAuthSameCerts() throws Exception { + mySetupMutualAuth(new File(getClass().getResource("test_unencrypted.pem").getFile()), + new File(getClass().getResource("test.crt").getFile()), + null); + runTest(null); + } + + @Test + public void testMutualAuthDiffCerts() throws Exception { + File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); + File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + String serverKeyPassword = "12345"; + File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); + File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); + String clientKeyPassword = "12345"; + mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, + serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); + runTest(null); + } + + @Test + public void testMutualAuthDiffCertsServerFailure() throws Exception { + File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile()); + File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + String serverKeyPassword = "12345"; + File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile()); + File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); + String clientKeyPassword = "12345"; + // Client trusts server but server only trusts itself + mySetupMutualAuth(serverCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, + serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); + assertTrue(serverLatch.await(2, TimeUnit.SECONDS)); + assertTrue(serverException instanceof SSLHandshakeException); + } + + @Test + public void testMutualAuthDiffCertsClientFailure() throws Exception { + File serverKeyFile = new File(getClass().getResource("test_unencrypted.pem").getFile()); + File serverCrtFile = new File(getClass().getResource("test.crt").getFile()); + String serverKeyPassword = null; + File clientKeyFile = new File(getClass().getResource("test2_unencrypted.pem").getFile()); + File clientCrtFile = new File(getClass().getResource("test2.crt").getFile()); + String clientKeyPassword = null; + // Server trusts client but client only trusts itself + mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, + clientCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); + assertTrue(clientLatch.await(2, TimeUnit.SECONDS)); + assertTrue(clientException instanceof SSLHandshakeException); + } + + private void mySetupMutualAuth(File keyFile, File crtFile, String keyPassword) + throws SSLException, InterruptedException { + mySetupMutualAuth(crtFile, keyFile, crtFile, keyPassword, crtFile, keyFile, crtFile, keyPassword); + } + + private void mySetupMutualAuth( + File servertTrustCrtFile, File serverKeyFile, File serverCrtFile, String serverKeyPassword, + File clientTrustCrtFile, File clientKeyFile, File clientCrtFile, String clientKeyPassword) + throws InterruptedException, SSLException { + serverSslCtx = SslContext.newServerContext(sslProvider(), servertTrustCrtFile, null, + serverCrtFile, serverKeyFile, serverKeyPassword, null, + null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0); + clientSslCtx = SslContext.newClientContext(sslProvider(), clientTrustCrtFile, null, + clientCrtFile, clientKeyFile, clientKeyPassword, null, + null, IdentityCipherSuiteFilter.INSTANCE, + null, 0, 0); + + serverConnectedChannel = null; + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + SSLEngine engine = serverSslCtx.newEngine(ch.alloc()); + engine.setUseClientMode(false); + engine.setNeedClientAuth(true); + p.addLast(new SslHandler(engine)); + p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); + p.addLast(new ChannelHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + serverException = cause.getCause(); + serverLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + }); + serverConnectedChannel = ch; + } + }); + + cb.group(new NioEventLoopGroup()); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + p.addLast(clientSslCtx.newHandler(ch.alloc())); + p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); + p.addLast(new ChannelHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + cause.printStackTrace(); + if (cause.getCause() instanceof SSLHandshakeException) { + clientException = cause.getCause(); + clientLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); + + ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port)); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + } + + protected void runTest(String expectedApplicationProtocol) throws Exception { + final ByteBuf clientMessage = Unpooled.copiedBuffer("I am a client".getBytes()); + final ByteBuf serverMessage = Unpooled.copiedBuffer("I am a server".getBytes()); + try { + writeAndVerifyReceived(clientMessage.retain(), clientChannel, serverLatch, serverReceiver); + writeAndVerifyReceived(serverMessage.retain(), serverConnectedChannel, clientLatch, clientReceiver); + if (expectedApplicationProtocol != null) { + verifyApplicationLevelProtocol(clientChannel, expectedApplicationProtocol); + verifyApplicationLevelProtocol(serverConnectedChannel, expectedApplicationProtocol); + } + } finally { + clientMessage.release(); + serverMessage.release(); + } + } + + private static void verifyApplicationLevelProtocol(Channel channel, String expectedApplicationProtocol) { + SslHandler handler = channel.pipeline().get(SslHandler.class); + assertNotNull(handler); + String[] protocol = handler.engine().getSession().getProtocol().split(":"); + assertNotNull(protocol); + if (expectedApplicationProtocol != null && !expectedApplicationProtocol.isEmpty()) { + assertTrue("protocol.length must be greater than 1 but is " + protocol.length, protocol.length > 1); + assertEquals(expectedApplicationProtocol, protocol[1]); + } else { + assertEquals(1, protocol.length); + } + } + + private static void writeAndVerifyReceived(ByteBuf message, Channel sendChannel, CountDownLatch receiverLatch, + MessageReciever receiver) throws Exception { + List dataCapture = null; + try { + sendChannel.writeAndFlush(message); + receiverLatch.await(5, TimeUnit.SECONDS); + message.resetReaderIndex(); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(receiver).messageReceived(captor.capture()); + dataCapture = captor.getAllValues(); + assertEquals(message, dataCapture.get(0)); + } finally { + if (dataCapture != null) { + for (ByteBuf data : dataCapture) { + data.release(); + } + } + } + } + + protected abstract SslProvider sslProvider(); +}