Add support for mutual auth when using OpenSslEngine.

Motivation:

Currently mutual auth is not supported when using OpenSslEngine.

Modification:

- Add support to OpenSslClientContext
- Correctly throw SSLHandshakeException when an error during handshake is detected

Result:

Mutual auth can be used with OpenSslEngine
This commit is contained in:
Norman Maurer 2015-04-29 21:37:50 +02:00
parent 0196dfa144
commit e33ef7a80e
9 changed files with 610 additions and 376 deletions

View File

@ -16,9 +16,7 @@
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufInputStream;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -28,8 +26,6 @@ import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
import java.io.File;
import java.io.IOException;
import java.security.InvalidAlgorithmParameterException;
@ -40,8 +36,6 @@ import java.security.NoSuchAlgorithmException;
import java.security.Security;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.util.ArrayList;
import java.util.Arrays;
@ -325,39 +319,4 @@ public abstract class JdkSslContext extends SslContext {
return kmf;
}
/**
* Build a {@link TrustManagerFactory} from a certificate chain file.
* @param certChainFile The certificate file to build from.
* @param trustManagerFactory The existing {@link TrustManagerFactory} that will be used if not {@code null}.
* @return A {@link TrustManagerFactory} which contains the certificates in {@code certChainFile}
*/
protected static TrustManagerFactory buildTrustManagerFactory(File certChainFile,
TrustManagerFactory trustManagerFactory)
throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException {
KeyStore ks = KeyStore.getInstance("JKS");
ks.load(null, null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
ByteBuf[] certs = PemReader.readCertificates(certChainFile);
try {
for (ByteBuf buf: certs) {
X509Certificate cert = (X509Certificate) cf.generateCertificate(new ByteBufInputStream(buf));
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
}
} finally {
for (ByteBuf buf: certs) {
buf.release();
}
}
// Set up trust manager factory to use our key store.
if (trustManagerFactory == null) {
trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
}
trustManagerFactory.init(ks);
return trustManagerFactory;
}
}

View File

@ -20,6 +20,7 @@ import io.netty.buffer.ByteBufInputStream;
import org.apache.tomcat.jni.SSL;
import org.apache.tomcat.jni.SSLContext;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
@ -39,13 +40,12 @@ import java.security.cert.X509Certificate;
*/
public final class OpenSslClientContext extends OpenSslContext {
private final OpenSslSessionContext sessionContext;
private final OpenSslEngineMap engineMap;
/**
* Creates a new instance.
*/
public OpenSslClientContext() throws SSLException {
this(null, null, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
this(null, null, null, null, null, null, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
}
/**
@ -79,7 +79,8 @@ public final class OpenSslClientContext extends OpenSslContext {
* {@code null} to use the default.
*/
public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory) throws SSLException {
this(certChainFile, trustManagerFactory, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
this(certChainFile, trustManagerFactory, null, null, null, null, null,
IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
}
/**
@ -104,11 +105,14 @@ public final class OpenSslClientContext extends OpenSslContext {
public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory, Iterable<String> ciphers,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout)
throws SSLException {
this(certChainFile, trustManagerFactory, ciphers, IdentityCipherSuiteFilter.INSTANCE,
this(certChainFile, trustManagerFactory, null, null, null, null, ciphers, IdentityCipherSuiteFilter.INSTANCE,
apn, sessionCacheSize, sessionTimeout);
}
/**
* @deprecated use {@link #OpenSslClientContext(File, TrustManagerFactory, File, File, String,
* KeyManagerFactory, Iterable, CipherSuiteFilter, ApplicationProtocolConfig,long, long)}
*
* Creates a new instance.
*
* @param certChainFile an X.509 certificate chain file in PEM format
@ -124,29 +128,98 @@ public final class OpenSslClientContext extends OpenSslContext {
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
*/
@Deprecated
public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(certChainFile, trustManagerFactory, null, null, null, null,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
}
/**
* Creates a new instance.
* @param trustCertChainFile an X.509 certificate chain file in PEM format.
* {@code null} to use the system default
* @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
* that verifies the certificates sent from servers.
* {@code null} to use the default or the results of parsing {@code trustCertChainFile}
* @param keyCertChainFile an X.509 certificate chain file in PEM format.
* This provides the public key for mutual authentication.
* {@code null} to use the system default
* @param keyFile a PKCS#8 private key file in PEM format.
* This provides the private key for mutual authentication.
* {@code null} for no mutual authentication.
* @param keyPassword the password of the {@code keyFile}.
* {@code null} if it's not password-protected.
* Ignored if {@code keyFile} is {@code null}.
* @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link javax.net.ssl.KeyManager}s
* that is used to encrypt data being sent to servers.
* {@code null} to use the default or the results of parsing
* {@code keyCertChainFile} and {@code keyFile}.
* @param ciphers the cipher suites to enable, in the order of preference.
* {@code null} to use the default cipher suites.
* @param cipherFilter a filter to apply over the supplied list of ciphers
* @param apn Application Protocol Negotiator object.
* @param sessionCacheSize the size of the cache used for storing SSL session objects.
* {@code 0} to use the default value.
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
*/
public OpenSslClientContext(File trustCertChainFile, TrustManagerFactory trustManagerFactory,
File keyCertChainFile, File keyFile, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT);
boolean success = false;
try {
if (certChainFile != null && !certChainFile.isFile()) {
throw new IllegalArgumentException("certChainFile is not a file: " + certChainFile);
if (trustCertChainFile != null && !trustCertChainFile.isFile()) {
throw new IllegalArgumentException("trustCertChainFile is not a file: " + trustCertChainFile);
}
if (keyCertChainFile != null && !keyCertChainFile.isFile()) {
throw new IllegalArgumentException("keyCertChainFile is not a file: " + keyCertChainFile);
}
if (keyFile != null && !keyFile.isFile()) {
throw new IllegalArgumentException("keyFile is not a file: " + keyFile);
}
if (keyFile == null && keyCertChainFile != null || keyFile != null && keyCertChainFile == null) {
throw new IllegalArgumentException(
"Either both keyCertChainFile and keyFile needs to be null or none of them");
}
synchronized (OpenSslContext.class) {
if (certChainFile != null) {
if (trustCertChainFile != null) {
/* Load the certificate chain. We must skip the first cert when server mode */
if (!SSLContext.setCertificateChainFile(ctx, certChainFile.getPath(), true)) {
if (!SSLContext.setCertificateChainFile(ctx, trustCertChainFile.getPath(), true)) {
long error = SSL.getLastErrorNumber();
if (OpenSsl.isError(error)) {
throw new SSLException(
"failed to set certificate chain: "
+ certChainFile + " (" + SSL.getErrorString(error) + ')');
+ trustCertChainFile + " (" + SSL.getErrorString(error) + ')');
}
}
}
if (keyCertChainFile != null && keyFile != null) {
/* Load the certificate file and private key. */
try {
if (!SSLContext.setCertificate(
ctx, keyCertChainFile.getPath(), keyFile.getPath(), keyPassword, SSL.SSL_AIDX_RSA)) {
long error = SSL.getLastErrorNumber();
if (OpenSsl.isError(error)) {
throw new SSLException("failed to set certificate: " +
keyCertChainFile + " and " + keyFile +
" (" + SSL.getErrorString(error) + ')');
}
}
} catch (SSLException e) {
throw e;
} catch (Exception e) {
throw new SSLException("failed to set certificate: " + keyCertChainFile + " and " + keyFile, e);
}
}
SSLContext.setVerify(ctx, SSL.SSL_VERIFY_NONE, VERIFY_DEPTH);
try {
@ -155,25 +228,24 @@ public final class OpenSslClientContext extends OpenSslContext {
trustManagerFactory = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
}
initTrustManagerFactory(certChainFile, trustManagerFactory);
initTrustManagerFactory(trustCertChainFile, trustManagerFactory);
final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers());
engineMap = newEngineMap(manager);
// Use this to prevent an error when running on java < 7
if (useExtendedTrustManager(manager)) {
final X509ExtendedTrustManager extendedManager = (X509ExtendedTrustManager) manager;
SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() {
@Override
void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception {
OpenSslEngine engine = engineMap.remove(ssl);
void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth)
throws Exception {
extendedManager.checkServerTrusted(peerCerts, auth, engine);
}
});
} else {
SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() {
@Override
void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception {
void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth)
throws Exception {
manager.checkServerTrusted(peerCerts, auth);
}
});
@ -218,11 +290,6 @@ public final class OpenSslClientContext extends OpenSslContext {
return sessionContext;
}
@Override
OpenSslEngineMap engineMap() {
return engineMap;
}
// No cache is currently supported for client side mode.
private static final class OpenSslClientSessionContext extends OpenSslSessionContext {
private OpenSslClientSessionContext(long context) {

View File

@ -26,6 +26,7 @@ import org.apache.tomcat.jni.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.net.ssl.X509TrustManager;
@ -56,6 +57,8 @@ public abstract class OpenSslContext extends SslContext {
private final List<String> unmodifiableCiphers;
private final long sessionCacheSize;
private final long sessionTimeout;
private final OpenSslEngineMap engineMap = new DefaultOpenSslEngineMap();
private final OpenSslApplicationProtocolNegotiator apn;
/** The OpenSSL SSL_CTX object */
protected final long ctx;
@ -277,14 +280,11 @@ public abstract class OpenSslContext extends SslContext {
*/
@Override
public final SSLEngine newEngine(ByteBufAllocator alloc) {
OpenSslEngineMap engineMap = engineMap();
final OpenSslEngine engine = new OpenSslEngine(ctx, alloc, isClient(), sessionContext(), apn, engineMap);
engineMap.add(engine);
return engine;
}
abstract OpenSslEngineMap engineMap();
/**
* Returns the {@code SSL_CTX} object of this context.
*/
@ -392,31 +392,28 @@ public abstract class OpenSslContext extends SslContext {
}
}
static OpenSslEngineMap newEngineMap(X509TrustManager trustManager) {
if (useExtendedTrustManager(trustManager)) {
return new DefaultOpenSslEngineMap();
}
return OpenSslEngineMap.EMPTY;
}
static boolean useExtendedTrustManager(X509TrustManager trustManager) {
return PlatformDependent.javaVersion() >= 7 && trustManager instanceof X509ExtendedTrustManager;
}
abstract static class AbstractCertificateVerifier implements CertificateVerifier {
abstract class AbstractCertificateVerifier implements CertificateVerifier {
@Override
public final boolean verify(long ssl, byte[][] chain, String auth) {
X509Certificate[] peerCerts = certificates(chain);
final OpenSslEngine engine = engineMap.remove(ssl);
try {
verify(ssl, peerCerts, auth);
verify(engine, peerCerts, auth);
return true;
} catch (Exception e) {
logger.debug("verification of certificate failed", e);
} catch (Throwable cause) {
logger.debug("verification of certificate failed", cause);
SSLHandshakeException e = new SSLHandshakeException("General OpenSslEngine problem");
e.initCause(cause);
engine.handshakeException = e;
}
return false;
}
abstract void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception;
abstract void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth) throws Exception;
}
private static final class DefaultOpenSslEngineMap implements OpenSslEngineMap {

View File

@ -29,6 +29,7 @@ import org.apache.tomcat.jni.SSL;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionBindingEvent;
@ -160,6 +161,10 @@ public final class OpenSslEngine extends SSLEngine {
private final OpenSslApplicationProtocolNegotiator apn;
private final SSLSession session = new OpenSslSession();
// This is package-private as we set it from OpenSslContext if an exception is thrown during
// the verification step.
SSLHandshakeException handshakeException;
/**
* Creates a new instance
*
@ -489,6 +494,22 @@ public final class OpenSslEngine extends SSLEngine {
return new SSLEngineResult(getEngineStatus(), handshakeStatus0(), bytesConsumed, bytesProduced);
}
private SSLException newSSLException(String msg) {
if (!handshakeFinished) {
return new SSLHandshakeException(msg);
}
return new SSLException(msg);
}
private void checkPendingHandshakeException() throws SSLHandshakeException {
if (handshakeException != null) {
SSLHandshakeException exception = handshakeException;
handshakeException = null;
shutdown();
throw exception;
}
}
public synchronized SSLEngineResult unwrap(
final ByteBuffer[] srcs, int srcsOffset, final int srcsLength,
final ByteBuffer[] dsts, final int dstsOffset, final int dstsLength) throws SSLException {
@ -608,7 +629,9 @@ public final class OpenSslEngine extends SSLEngine {
// There was an internal error -- shutdown
shutdown();
throw new SSLException(err);
throw newSSLException(err);
} else {
checkPendingHandshakeException();
}
}
} else {
@ -954,8 +977,9 @@ public final class OpenSslEngine extends SSLEngine {
// There was an internal error -- shutdown
shutdown();
throw new SSLException(err);
throw newSSLException(err);
}
checkPendingHandshakeException();
} else {
// if SSL_do_handshake returns > 0 it means the handshake was finished. This means we can update
// handshakeFinished directly and so eliminate uncessary calls to SSL.isInInit(...)
@ -1037,6 +1061,7 @@ public final class OpenSslEngine extends SSLEngine {
if (status == FINISHED) {
handshakeFinished();
}
checkPendingHandshakeException();
return status;
}

View File

@ -19,7 +19,10 @@ import io.netty.util.internal.EmptyArrays;
import org.apache.tomcat.jni.SSL;
import org.apache.tomcat.jni.SSLContext;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.net.ssl.X509TrustManager;
@ -34,7 +37,6 @@ import static io.netty.util.internal.ObjectUtil.*;
*/
public final class OpenSslServerContext extends OpenSslContext {
private final OpenSslServerSessionContext sessionContext;
private final OpenSslEngineMap engineMap;
/**
* Creates a new instance.
@ -55,8 +57,8 @@ public final class OpenSslServerContext extends OpenSslContext {
* {@code null} if it's not password-protected.
*/
public OpenSslServerContext(File certChainFile, File keyFile, String keyPassword) throws SSLException {
this(certChainFile, keyFile, keyPassword, null, null, IdentityCipherSuiteFilter.INSTANCE,
NONE_PROTOCOL_NEGOTIATOR, 0, 0);
this(certChainFile, keyFile, keyPassword, null, IdentityCipherSuiteFilter.INSTANCE,
ApplicationProtocolConfig.DISABLED, 0, 0);
}
/**
@ -82,8 +84,8 @@ public final class OpenSslServerContext extends OpenSslContext {
File certChainFile, File keyFile, String keyPassword,
Iterable<String> ciphers, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(certChainFile, keyFile, keyPassword, null, ciphers,
toNegotiator(apn), sessionCacheSize, sessionTimeout);
this(certChainFile, keyFile, keyPassword, ciphers, IdentityCipherSuiteFilter.INSTANCE,
apn, sessionCacheSize, sessionTimeout);
}
/**
@ -100,9 +102,8 @@ public final class OpenSslServerContext extends OpenSslContext {
* {@code 0} to use the default value.
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
* @deprecated use {@link #OpenSslServerContext(
* File, File, String, TrustManagerFactory, Iterable,
* CipherSuiteFilter, ApplicationProtocolConfig, long, long)}
* @deprecated use {@link #OpenSslServerContext(File, TrustManagerFactory, File, File, String, KeyManagerFactory,
* Iterable, CipherSuiteFilter, ApplicationProtocolConfig, long, long)}
*/
@Deprecated
public OpenSslServerContext(
@ -127,17 +128,16 @@ public final class OpenSslServerContext extends OpenSslContext {
* {@code 0} to use the default value.
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
* @deprecated use {@link #OpenSslServerContext(
* File, File, String, TrustManagerFactory, Iterable,
* CipherSuiteFilter, OpenSslApplicationProtocolNegotiator, long, long)}
* @deprecated use {@link #OpenSslServerContext(File, TrustManagerFactory, File, File, String, KeyManagerFactory,
* Iterable, CipherSuiteFilter, ApplicationProtocolConfig, long, long)}
*/
@Deprecated
public OpenSslServerContext(
File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory,
Iterable<String> ciphers, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(certChainFile, keyFile, keyPassword, trustManagerFactory, ciphers,
IdentityCipherSuiteFilter.INSTANCE, apn, sessionCacheSize, sessionTimeout);
this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null,
ciphers, null, apn, sessionCacheSize, sessionTimeout);
}
/**
@ -160,8 +160,44 @@ public final class OpenSslServerContext extends OpenSslContext {
File certChainFile, File keyFile, String keyPassword,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter,
toNegotiator(apn), sessionCacheSize, sessionTimeout);
this(null, null, certChainFile, keyFile, keyPassword, null,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
}
/**
* Creates a new instance.
*
* @param trustCertChainFile an X.509 certificate chain file in PEM format.
* This provides the certificate chains used for mutual authentication.
* {@code null} to use the system default
* @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
* that verifies the certificates sent from clients.
* {@code null} to use the default or the results of parsing {@code trustCertChainFile}.
* @param keyCertChainFile an X.509 certificate chain file in PEM format
* @param keyFile a PKCS#8 private key file in PEM format
* @param keyPassword the password of the {@code keyFile}.
* {@code null} if it's not password-protected.
* @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s
* that is used to encrypt data being sent to clients.
* {@code null} to use the default or the results of parsing
* {@code keyCertChainFile} and {@code keyFile}.
* @param ciphers the cipher suites to enable, in the order of preference.
* {@code null} to use the default cipher suites.
* @param cipherFilter a filter to apply over the supplied list of ciphers
* Only required if {@code provider} is {@link SslProvider#JDK}
* @param config Provides a means to configure parameters related to application protocol negotiation.
* @param sessionCacheSize the size of the cache used for storing SSL session objects.
* {@code 0} to use the default value.
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
*/
public OpenSslServerContext(
File trustCertChainFile, TrustManagerFactory trustManagerFactory,
File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(trustCertChainFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword, keyManagerFactory,
ciphers, cipherFilter, toNegotiator(config), sessionCacheSize, sessionTimeout);
}
/**
@ -180,12 +216,13 @@ public final class OpenSslServerContext extends OpenSslContext {
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
*/
public OpenSslServerContext(
File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(certChainFile, keyFile, keyPassword, trustManagerFactory, ciphers, cipherFilter,
toNegotiator(config), sessionCacheSize, sessionTimeout);
@Deprecated
public OpenSslServerContext(File certChainFile, File keyFile, String keyPassword,
TrustManagerFactory trustManagerFactory, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter,
toNegotiator(config), sessionCacheSize, sessionTimeout);
}
/**
@ -203,21 +240,61 @@ public final class OpenSslServerContext extends OpenSslContext {
* {@code 0} to use the default value.
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
* @deprecated use {@link #OpenSslServerContext(File, TrustManagerFactory, File, File, String, KeyManagerFactory,
* Iterable, CipherSuiteFilter, OpenSslApplicationProtocolNegotiator, long, long)}
*/
@Deprecated
public OpenSslServerContext(
File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter,
apn, sessionCacheSize, sessionTimeout);
}
/**
* Creates a new instance.
*
*
* @param trustCertChainFile an X.509 certificate chain file in PEM format.
* This provides the certificate chains used for mutual authentication.
* {@code null} to use the system default
* @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
* that verifies the certificates sent from clients.
* {@code null} to use the default or the results of parsing {@code trustCertChainFile}.
* @param keyCertChainFile an X.509 certificate chain file in PEM format
* @param keyFile a PKCS#8 private key file in PEM format
* @param keyPassword the password of the {@code keyFile}.
* {@code null} if it's not password-protected.
* @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s
* that is used to encrypt data being sent to clients.
* {@code null} to use the default or the results of parsing
* {@code keyCertChainFile} and {@code keyFile}.
* @param ciphers the cipher suites to enable, in the order of preference.
* {@code null} to use the default cipher suites.
* @param cipherFilter a filter to apply over the supplied list of ciphers
* Only required if {@code provider} is {@link SslProvider#JDK}
* @param apn Application Protocol Negotiator object
* @param sessionCacheSize the size of the cache used for storing SSL session objects.
* {@code 0} to use the default value.
* @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
* {@code 0} to use the default value.
*/
public OpenSslServerContext(
File trustCertChainFile, TrustManagerFactory trustManagerFactory,
File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER);
OpenSsl.ensureAvailability();
checkNotNull(certChainFile, "certChainFile");
if (!certChainFile.isFile()) {
throw new IllegalArgumentException("certChainFile is not a file: " + certChainFile);
checkNotNull(keyCertChainFile, "keyCertChainFile");
if (!keyCertChainFile.isFile()) {
throw new IllegalArgumentException("keyCertChainFile is not a file: " + keyCertChainFile);
}
checkNotNull(keyFile, "keyFile");
if (!keyFile.isFile()) {
throw new IllegalArgumentException("keyPath is not a file: " + keyFile);
throw new IllegalArgumentException("keyFile is not a file: " + keyFile);
}
if (keyPassword == null) {
keyPassword = "";
@ -231,59 +308,64 @@ public final class OpenSslServerContext extends OpenSslContext {
SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH);
/* Load the certificate chain. We must skip the first cert when server mode */
if (!SSLContext.setCertificateChainFile(ctx, certChainFile.getPath(), true)) {
if (!SSLContext.setCertificateChainFile(ctx, keyCertChainFile.getPath(), true)) {
long error = SSL.getLastErrorNumber();
if (OpenSsl.isError(error)) {
String err = SSL.getErrorString(error);
throw new SSLException(
"failed to set certificate chain: " + certChainFile + " (" + err + ')');
"failed to set certificate chain: " + keyCertChainFile + " (" + err + ')');
}
}
/* Load the certificate file and private key. */
try {
if (!SSLContext.setCertificate(
ctx, certChainFile.getPath(), keyFile.getPath(), keyPassword, SSL.SSL_AIDX_RSA)) {
ctx, keyCertChainFile.getPath(), keyFile.getPath(), keyPassword, SSL.SSL_AIDX_RSA)) {
long error = SSL.getLastErrorNumber();
if (OpenSsl.isError(error)) {
String err = SSL.getErrorString(error);
throw new SSLException("failed to set certificate: " +
certChainFile + " and " + keyFile + " (" + err + ')');
keyCertChainFile + " and " + keyFile + " (" + err + ')');
}
}
} catch (SSLException e) {
throw e;
} catch (Exception e) {
throw new SSLException("failed to set certificate: " + certChainFile + " and " + keyFile, e);
throw new SSLException("failed to set certificate: " + keyCertChainFile + " and " + keyFile, e);
}
try {
char[] keyPasswordChars = keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray();
KeyStore ks = buildKeyStore(certChainFile, keyFile, keyPasswordChars);
if (trustManagerFactory == null) {
// Mimic the way SSLContext.getInstance(KeyManager[], null, null) works
trustManagerFactory = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
}
trustManagerFactory.init(ks);
if (trustCertChainFile != null) {
trustManagerFactory = buildTrustManagerFactory(trustCertChainFile, trustManagerFactory);
} else {
char[] keyPasswordChars =
keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray();
KeyStore ks = buildKeyStore(keyCertChainFile, keyFile, keyPasswordChars);
trustManagerFactory.init(ks);
}
final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers());
engineMap = newEngineMap(manager);
// Use this to prevent an error when running on java < 7
if (useExtendedTrustManager(manager)) {
final X509ExtendedTrustManager extendedManager = (X509ExtendedTrustManager) manager;
SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() {
@Override
void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception {
OpenSslEngine engine = engineMap.remove(ssl);
void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth)
throws Exception {
extendedManager.checkClientTrusted(peerCerts, auth, engine);
}
});
} else {
SSLContext.setCertVerifyCallback(ctx, new AbstractCertificateVerifier() {
@Override
void verify(long ssl, X509Certificate[] peerCerts, String auth) throws Exception {
void verify(OpenSslEngine engine, X509Certificate[] peerCerts, String auth)
throws Exception {
manager.checkClientTrusted(peerCerts, auth);
}
});
@ -305,9 +387,4 @@ public final class OpenSslServerContext extends OpenSslContext {
public OpenSslServerSessionContext sessionContext() {
return sessionContext;
}
@Override
OpenSslEngineMap engineMap() {
return engineMap;
}
}

View File

@ -36,6 +36,7 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
import java.io.File;
import java.io.IOException;
import java.security.InvalidAlgorithmParameterException;
@ -49,6 +50,7 @@ import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.ArrayList;
@ -296,8 +298,8 @@ public abstract class SslContext {
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
case OPENSSL:
return new OpenSslServerContext(
keyCertChainFile, keyFile, keyPassword, trustManagerFactory,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
trustCertChainFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
default:
throw new Error(provider.toString());
}
@ -555,8 +557,8 @@ public abstract class SslContext {
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
case OPENSSL:
return new OpenSslClientContext(
trustCertChainFile, trustManagerFactory, ciphers, cipherFilter, apn,
sessionCacheSize, sessionTimeout);
trustCertChainFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
}
// Should never happen!!
throw new Error();
@ -731,4 +733,39 @@ public abstract class SslContext {
ks.setKeyEntry("key", key, keyPasswordChars, certChain.toArray(new Certificate[certChain.size()]));
return ks;
}
/**
* Build a {@link TrustManagerFactory} from a certificate chain file.
* @param certChainFile The certificate file to build from.
* @param trustManagerFactory The existing {@link TrustManagerFactory} that will be used if not {@code null}.
* @return A {@link TrustManagerFactory} which contains the certificates in {@code certChainFile}
*/
protected static TrustManagerFactory buildTrustManagerFactory(File certChainFile,
TrustManagerFactory trustManagerFactory)
throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException {
KeyStore ks = KeyStore.getInstance("JKS");
ks.load(null, null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
ByteBuf[] certs = PemReader.readCertificates(certChainFile);
try {
for (ByteBuf buf: certs) {
X509Certificate cert = (X509Certificate) cf.generateCertificate(new ByteBufInputStream(buf));
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
}
} finally {
for (ByteBuf buf: certs) {
buf.release();
}
}
// Set up trust manager factory to use our key store.
if (trustManagerFactory == null) {
trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
}
trustManagerFactory.init(ks);
return trustManagerFactory;
}
}

View File

@ -15,23 +15,17 @@
*/
package io.netty.handler.ssl;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeNoException;
import static org.mockito.Mockito.verify;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
@ -40,94 +34,24 @@ import io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelectorFac
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future;
import java.io.File;
import java.net.InetSocketAddress;
import java.security.cert.CertificateException;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
public class JdkSslEngineTest {
public class JdkSslEngineTest extends SSLEngineTest {
private static final String PREFERRED_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http2";
private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1";
private static final String APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE = "my-protocol-FOO";
@Mock
private MessageReciever serverReceiver;
@Mock
private MessageReciever clientReceiver;
private Throwable serverException;
private Throwable clientException;
private SslContext serverSslCtx;
private SslContext clientSslCtx;
private ServerBootstrap sb;
private Bootstrap cb;
private Channel serverChannel;
private Channel serverConnectedChannel;
private Channel clientChannel;
private CountDownLatch serverLatch;
private CountDownLatch clientLatch;
private interface MessageReciever {
void messageReceived(ByteBuf msg);
}
private final class MessageDelegatorChannelHandler extends SimpleChannelInboundHandler<ByteBuf> {
private final MessageReciever receiver;
private final CountDownLatch latch;
public MessageDelegatorChannelHandler(MessageReciever receiver, CountDownLatch latch) {
super(false);
this.receiver = receiver;
this.latch = latch;
}
@Override
protected void messageReceived(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
receiver.messageReceived(msg);
latch.countDown();
}
}
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
serverLatch = new CountDownLatch(1);
clientLatch = new CountDownLatch(1);
}
@After
public void tearDown() throws InterruptedException {
if (serverChannel != null) {
serverChannel.close().sync();
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
serverGroup.sync();
serverChildGroup.sync();
clientGroup.sync();
}
clientChannel = null;
serverChannel = null;
serverConnectedChannel = null;
serverException = null;
}
@Test
public void testNpn() throws Exception {
try {
@ -344,57 +268,6 @@ public class JdkSslEngineTest {
}
}
@Test
public void testMutualAuthSameCerts() throws Exception {
mySetupMutualAuth(new File(getClass().getResource("test_unencrypted.pem").getFile()),
new File(getClass().getResource("test.crt").getFile()),
null);
runTest(null);
}
@Test
public void testMutualAuthDiffCerts() throws Exception {
File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile());
File serverCrtFile = new File(getClass().getResource("test.crt").getFile());
String serverKeyPassword = "12345";
File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile());
File clientCrtFile = new File(getClass().getResource("test2.crt").getFile());
String clientKeyPassword = "12345";
mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword,
serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword);
runTest(null);
}
@Test
public void testMutualAuthDiffCertsServerFailure() throws Exception {
File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile());
File serverCrtFile = new File(getClass().getResource("test.crt").getFile());
String serverKeyPassword = "12345";
File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile());
File clientCrtFile = new File(getClass().getResource("test2.crt").getFile());
String clientKeyPassword = "12345";
// Client trusts server but server only trusts itself
mySetupMutualAuth(serverCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword,
serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword);
assertTrue(serverLatch.await(2, TimeUnit.SECONDS));
assertTrue(serverException instanceof SSLHandshakeException);
}
@Test
public void testMutualAuthDiffCertsClientFailure() throws Exception {
File serverKeyFile = new File(getClass().getResource("test_unencrypted.pem").getFile());
File serverCrtFile = new File(getClass().getResource("test.crt").getFile());
String serverKeyPassword = null;
File clientKeyFile = new File(getClass().getResource("test2_unencrypted.pem").getFile());
File clientCrtFile = new File(getClass().getResource("test2.crt").getFile());
String clientKeyPassword = null;
// Server trusts client but client only trusts itself
mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword,
clientCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword);
assertTrue(clientLatch.await(2, TimeUnit.SECONDS));
assertTrue(clientException instanceof SSLHandshakeException);
}
private void mySetup(JdkApplicationProtocolNegotiator apn) throws InterruptedException, SSLException,
CertificateException {
mySetup(apn, apn);
@ -465,132 +338,12 @@ public class JdkSslEngineTest {
clientChannel = ccf.channel();
}
private void mySetupMutualAuth(File keyFile, File crtFile, String keyPassword)
throws SSLException, CertificateException, InterruptedException {
mySetupMutualAuth(crtFile, keyFile, crtFile, keyPassword, crtFile, keyFile, crtFile, keyPassword);
}
private void mySetupMutualAuth(
File servertTrustCrtFile, File serverKeyFile, File serverCrtFile, String serverKeyPassword,
File clientTrustCrtFile, File clientKeyFile, File clientCrtFile, String clientKeyPassword)
throws InterruptedException, SSLException, CertificateException {
serverSslCtx = new JdkSslServerContext(servertTrustCrtFile, null,
serverCrtFile, serverKeyFile, serverKeyPassword, null,
null, IdentityCipherSuiteFilter.INSTANCE, (ApplicationProtocolConfig) null, 0, 0);
clientSslCtx = new JdkSslClientContext(clientTrustCrtFile, null,
clientCrtFile, clientKeyFile, clientKeyPassword, null,
null, IdentityCipherSuiteFilter.INSTANCE, (ApplicationProtocolConfig) null, 0, 0);
serverConnectedChannel = null;
sb = new ServerBootstrap();
cb = new Bootstrap();
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
SSLEngine engine = serverSslCtx.newEngine(ch.alloc());
engine.setUseClientMode(false);
engine.setNeedClientAuth(true);
p.addLast(new SslHandler(engine));
p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch));
p.addLast(new ChannelHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause.getCause() instanceof SSLHandshakeException) {
serverException = cause.getCause();
serverLatch.countDown();
} else {
ctx.fireExceptionCaught(cause);
}
}
});
serverConnectedChannel = ch;
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(clientSslCtx.newHandler(ch.alloc()));
p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch));
p.addLast(new ChannelHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause.getCause() instanceof SSLHandshakeException) {
clientException = cause.getCause();
clientLatch.countDown();
} else {
ctx.fireExceptionCaught(cause);
}
}
});
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
private void runTest() throws Exception {
runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL);
}
private void runTest(String expectedApplicationProtocol) throws Exception {
final ByteBuf clientMessage = Unpooled.copiedBuffer("I am a client".getBytes());
final ByteBuf serverMessage = Unpooled.copiedBuffer("I am a server".getBytes());
try {
writeAndVerifyReceived(clientMessage.retain(), clientChannel, serverLatch, serverReceiver);
writeAndVerifyReceived(serverMessage.retain(), serverConnectedChannel, clientLatch, clientReceiver);
if (expectedApplicationProtocol != null) {
verifyApplicationLevelProtocol(clientChannel, expectedApplicationProtocol);
verifyApplicationLevelProtocol(serverConnectedChannel, expectedApplicationProtocol);
}
} finally {
clientMessage.release();
serverMessage.release();
}
}
private void verifyApplicationLevelProtocol(Channel channel, String expectedApplicationProtocol) {
SslHandler handler = channel.pipeline().get(SslHandler.class);
assertNotNull(handler);
String[] protocol = handler.engine().getSession().getProtocol().split(":");
assertNotNull(protocol);
if (expectedApplicationProtocol != null && !expectedApplicationProtocol.isEmpty()) {
assertTrue("protocol.length must be greater than 1 but is " + protocol.length, protocol.length > 1);
assertEquals(expectedApplicationProtocol, protocol[1]);
} else {
assertEquals(1, protocol.length);
}
}
private static void writeAndVerifyReceived(ByteBuf message, Channel sendChannel, CountDownLatch receiverLatch,
MessageReciever receiver) throws Exception {
List<ByteBuf> dataCapture = null;
try {
sendChannel.writeAndFlush(message);
receiverLatch.await(5, TimeUnit.SECONDS);
message.resetReaderIndex();
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(receiver).messageReceived(captor.capture());
dataCapture = captor.getAllValues();
assertEquals(message, dataCapture.get(0));
} finally {
if (dataCapture != null) {
for (ByteBuf data : dataCapture) {
data.release();
}
}
}
@Override
protected SslProvider sslProvider() {
return SslProvider.JDK;
}
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
public class OpenSslEngineTest extends SSLEngineTest {
@Override
protected SslProvider sslProvider() {
return SslProvider.OPENSSL;
}
}

View File

@ -0,0 +1,296 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import java.io.File;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.*;
import static org.mockito.Mockito.verify;
public abstract class SSLEngineTest {
@Mock
protected MessageReciever serverReceiver;
@Mock
protected MessageReciever clientReceiver;
protected Throwable serverException;
protected Throwable clientException;
protected SslContext serverSslCtx;
protected SslContext clientSslCtx;
protected ServerBootstrap sb;
protected Bootstrap cb;
protected Channel serverChannel;
protected Channel serverConnectedChannel;
protected Channel clientChannel;
protected CountDownLatch serverLatch;
protected CountDownLatch clientLatch;
interface MessageReciever {
void messageReceived(ByteBuf msg);
}
protected static final class MessageDelegatorChannelHandler extends SimpleChannelInboundHandler<ByteBuf> {
private final MessageReciever receiver;
private final CountDownLatch latch;
public MessageDelegatorChannelHandler(MessageReciever receiver, CountDownLatch latch) {
super(false);
this.receiver = receiver;
this.latch = latch;
}
@Override
protected void messageReceived(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
receiver.messageReceived(msg);
latch.countDown();
}
}
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
serverLatch = new CountDownLatch(1);
clientLatch = new CountDownLatch(1);
}
@After
public void tearDown() throws InterruptedException {
if (serverChannel != null) {
serverChannel.close().sync();
Future<?> serverGroup = sb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> serverChildGroup = sb.childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
Future<?> clientGroup = cb.group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
serverGroup.sync();
serverChildGroup.sync();
clientGroup.sync();
}
clientChannel = null;
serverChannel = null;
serverConnectedChannel = null;
serverException = null;
}
@Test
public void testMutualAuthSameCerts() throws Exception {
mySetupMutualAuth(new File(getClass().getResource("test_unencrypted.pem").getFile()),
new File(getClass().getResource("test.crt").getFile()),
null);
runTest(null);
}
@Test
public void testMutualAuthDiffCerts() throws Exception {
File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile());
File serverCrtFile = new File(getClass().getResource("test.crt").getFile());
String serverKeyPassword = "12345";
File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile());
File clientCrtFile = new File(getClass().getResource("test2.crt").getFile());
String clientKeyPassword = "12345";
mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword,
serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword);
runTest(null);
}
@Test
public void testMutualAuthDiffCertsServerFailure() throws Exception {
File serverKeyFile = new File(getClass().getResource("test_encrypted.pem").getFile());
File serverCrtFile = new File(getClass().getResource("test.crt").getFile());
String serverKeyPassword = "12345";
File clientKeyFile = new File(getClass().getResource("test2_encrypted.pem").getFile());
File clientCrtFile = new File(getClass().getResource("test2.crt").getFile());
String clientKeyPassword = "12345";
// Client trusts server but server only trusts itself
mySetupMutualAuth(serverCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword,
serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword);
assertTrue(serverLatch.await(2, TimeUnit.SECONDS));
assertTrue(serverException instanceof SSLHandshakeException);
}
@Test
public void testMutualAuthDiffCertsClientFailure() throws Exception {
File serverKeyFile = new File(getClass().getResource("test_unencrypted.pem").getFile());
File serverCrtFile = new File(getClass().getResource("test.crt").getFile());
String serverKeyPassword = null;
File clientKeyFile = new File(getClass().getResource("test2_unencrypted.pem").getFile());
File clientCrtFile = new File(getClass().getResource("test2.crt").getFile());
String clientKeyPassword = null;
// Server trusts client but client only trusts itself
mySetupMutualAuth(clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword,
clientCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword);
assertTrue(clientLatch.await(2, TimeUnit.SECONDS));
assertTrue(clientException instanceof SSLHandshakeException);
}
private void mySetupMutualAuth(File keyFile, File crtFile, String keyPassword)
throws SSLException, InterruptedException {
mySetupMutualAuth(crtFile, keyFile, crtFile, keyPassword, crtFile, keyFile, crtFile, keyPassword);
}
private void mySetupMutualAuth(
File servertTrustCrtFile, File serverKeyFile, File serverCrtFile, String serverKeyPassword,
File clientTrustCrtFile, File clientKeyFile, File clientCrtFile, String clientKeyPassword)
throws InterruptedException, SSLException {
serverSslCtx = SslContext.newServerContext(sslProvider(), servertTrustCrtFile, null,
serverCrtFile, serverKeyFile, serverKeyPassword, null,
null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
clientSslCtx = SslContext.newClientContext(sslProvider(), clientTrustCrtFile, null,
clientCrtFile, clientKeyFile, clientKeyPassword, null,
null, IdentityCipherSuiteFilter.INSTANCE,
null, 0, 0);
serverConnectedChannel = null;
sb = new ServerBootstrap();
cb = new Bootstrap();
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
SSLEngine engine = serverSslCtx.newEngine(ch.alloc());
engine.setUseClientMode(false);
engine.setNeedClientAuth(true);
p.addLast(new SslHandler(engine));
p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch));
p.addLast(new ChannelHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (cause.getCause() instanceof SSLHandshakeException) {
serverException = cause.getCause();
serverLatch.countDown();
} else {
ctx.fireExceptionCaught(cause);
}
}
});
serverConnectedChannel = ch;
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(clientSslCtx.newHandler(ch.alloc()));
p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch));
p.addLast(new ChannelHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
if (cause.getCause() instanceof SSLHandshakeException) {
clientException = cause.getCause();
clientLatch.countDown();
} else {
ctx.fireExceptionCaught(cause);
}
}
});
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
}
protected void runTest(String expectedApplicationProtocol) throws Exception {
final ByteBuf clientMessage = Unpooled.copiedBuffer("I am a client".getBytes());
final ByteBuf serverMessage = Unpooled.copiedBuffer("I am a server".getBytes());
try {
writeAndVerifyReceived(clientMessage.retain(), clientChannel, serverLatch, serverReceiver);
writeAndVerifyReceived(serverMessage.retain(), serverConnectedChannel, clientLatch, clientReceiver);
if (expectedApplicationProtocol != null) {
verifyApplicationLevelProtocol(clientChannel, expectedApplicationProtocol);
verifyApplicationLevelProtocol(serverConnectedChannel, expectedApplicationProtocol);
}
} finally {
clientMessage.release();
serverMessage.release();
}
}
private static void verifyApplicationLevelProtocol(Channel channel, String expectedApplicationProtocol) {
SslHandler handler = channel.pipeline().get(SslHandler.class);
assertNotNull(handler);
String[] protocol = handler.engine().getSession().getProtocol().split(":");
assertNotNull(protocol);
if (expectedApplicationProtocol != null && !expectedApplicationProtocol.isEmpty()) {
assertTrue("protocol.length must be greater than 1 but is " + protocol.length, protocol.length > 1);
assertEquals(expectedApplicationProtocol, protocol[1]);
} else {
assertEquals(1, protocol.length);
}
}
private static void writeAndVerifyReceived(ByteBuf message, Channel sendChannel, CountDownLatch receiverLatch,
MessageReciever receiver) throws Exception {
List<ByteBuf> dataCapture = null;
try {
sendChannel.writeAndFlush(message);
receiverLatch.await(5, TimeUnit.SECONDS);
message.resetReaderIndex();
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
verify(receiver).messageReceived(captor.capture());
dataCapture = captor.getAllValues();
assertEquals(message, dataCapture.get(0));
} finally {
if (dataCapture != null) {
for (ByteBuf data : dataCapture) {
data.release();
}
}
}
}
protected abstract SslProvider sslProvider();
}