SslContext to support TLS/SSL protocols

Motivation:
SslContext and SslContextBuilder do not support a way to specify the desired TLS protocols. This currently requires that the user extracts the SSLEngine once a context is built and manually call SSLEngine#setEnabledProtocols(String[]). Something this critical should be supported at the SslContext level.

Modifications:
- SslContextBuilder should accept a list of protocols to configure for each SslEngine

Result:
SslContext consistently sets the supported TLS/SSL protocols.
This commit is contained in:
Scott Mitchell 2017-03-07 19:39:31 -08:00
parent e18c85b768
commit a2304287a1
13 changed files with 123 additions and 62 deletions

View File

@ -248,16 +248,17 @@ public final class JdkSslClientContext extends JdkSslContext {
trustCertCollectionFile), trustManagerFactory, trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), true, keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), true,
ciphers, cipherFilter, apn, ClientAuth.NONE, false); ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
} }
JdkSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, JdkSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException { ApplicationProtocolConfig apn, String[] protocols, long sessionCacheSize, long sessionTimeout)
throws SSLException {
super(newSSLContext(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, super(newSSLContext(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, sessionCacheSize, sessionTimeout), true, keyManagerFactory, sessionCacheSize, sessionTimeout), true,
ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE, false); ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE, protocols, false);
} }
private static SSLContext newSSLContext(X509Certificate[] trustCertCollection, private static SSLContext newSSLContext(X509Certificate[] trustCertCollection,

View File

@ -52,7 +52,7 @@ public class JdkSslContext extends SslContext {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(JdkSslContext.class); private static final InternalLogger logger = InternalLoggerFactory.getInstance(JdkSslContext.class);
static final String PROTOCOL = "TLS"; static final String PROTOCOL = "TLS";
static final String[] PROTOCOLS; static final String[] DEFAULT_PROTOCOLS;
static final List<String> DEFAULT_CIPHERS; static final List<String> DEFAULT_CIPHERS;
static final Set<String> SUPPORTED_CIPHERS; static final Set<String> SUPPORTED_CIPHERS;
@ -80,9 +80,9 @@ public class JdkSslContext extends SslContext {
"TLSv1.2", "TLSv1.1", "TLSv1"); "TLSv1.2", "TLSv1.1", "TLSv1");
if (!protocols.isEmpty()) { if (!protocols.isEmpty()) {
PROTOCOLS = protocols.toArray(new String[protocols.size()]); DEFAULT_PROTOCOLS = protocols.toArray(new String[protocols.size()]);
} else { } else {
PROTOCOLS = engine.getEnabledProtocols(); DEFAULT_PROTOCOLS = engine.getEnabledProtocols();
} }
// Choose the sensible default list of cipher suites. // Choose the sensible default list of cipher suites.
@ -120,7 +120,7 @@ public class JdkSslContext extends SslContext {
DEFAULT_CIPHERS = Collections.unmodifiableList(ciphers); DEFAULT_CIPHERS = Collections.unmodifiableList(ciphers);
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Default protocols (JDK): {} ", Arrays.asList(PROTOCOLS)); logger.debug("Default protocols (JDK): {} ", Arrays.asList(DEFAULT_PROTOCOLS));
logger.debug("Default cipher suites (JDK): {}", DEFAULT_CIPHERS); logger.debug("Default cipher suites (JDK): {}", DEFAULT_CIPHERS);
} }
} }
@ -133,6 +133,7 @@ public class JdkSslContext extends SslContext {
} }
} }
private final String[] protocols;
private final String[] cipherSuites; private final String[] cipherSuites;
private final List<String> unmodifiableCipherSuites; private final List<String> unmodifiableCipherSuites;
private final JdkApplicationProtocolNegotiator apn; private final JdkApplicationProtocolNegotiator apn;
@ -150,7 +151,7 @@ public class JdkSslContext extends SslContext {
public JdkSslContext(SSLContext sslContext, boolean isClient, public JdkSslContext(SSLContext sslContext, boolean isClient,
ClientAuth clientAuth) { ClientAuth clientAuth) {
this(sslContext, isClient, null, IdentityCipherSuiteFilter.INSTANCE, this(sslContext, isClient, null, IdentityCipherSuiteFilter.INSTANCE,
JdkDefaultApplicationProtocolNegotiator.INSTANCE, clientAuth, false); JdkDefaultApplicationProtocolNegotiator.INSTANCE, clientAuth, null, false);
} }
/** /**
@ -166,16 +167,17 @@ public class JdkSslContext extends SslContext {
public JdkSslContext(SSLContext sslContext, boolean isClient, Iterable<String> ciphers, public JdkSslContext(SSLContext sslContext, boolean isClient, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
ClientAuth clientAuth) { ClientAuth clientAuth) {
this(sslContext, isClient, ciphers, cipherFilter, toNegotiator(apn, !isClient), clientAuth, false); this(sslContext, isClient, ciphers, cipherFilter, toNegotiator(apn, !isClient), clientAuth, null, false);
} }
JdkSslContext(SSLContext sslContext, boolean isClient, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkSslContext(SSLContext sslContext, boolean isClient, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
JdkApplicationProtocolNegotiator apn, ClientAuth clientAuth, boolean startTls) { JdkApplicationProtocolNegotiator apn, ClientAuth clientAuth, String[] protocols, boolean startTls) {
super(startTls); super(startTls);
this.apn = checkNotNull(apn, "apn"); this.apn = checkNotNull(apn, "apn");
this.clientAuth = checkNotNull(clientAuth, "clientAuth"); this.clientAuth = checkNotNull(clientAuth, "clientAuth");
cipherSuites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites( cipherSuites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites(
ciphers, DEFAULT_CIPHERS, SUPPORTED_CIPHERS); ciphers, DEFAULT_CIPHERS, SUPPORTED_CIPHERS);
this.protocols = protocols == null ? DEFAULT_PROTOCOLS : protocols;
unmodifiableCipherSuites = Collections.unmodifiableList(Arrays.asList(cipherSuites)); unmodifiableCipherSuites = Collections.unmodifiableList(Arrays.asList(cipherSuites));
this.sslContext = checkNotNull(sslContext, "sslContext"); this.sslContext = checkNotNull(sslContext, "sslContext");
this.isClient = isClient; this.isClient = isClient;
@ -232,7 +234,7 @@ public class JdkSslContext extends SslContext {
private SSLEngine configureAndWrapEngine(SSLEngine engine) { private SSLEngine configureAndWrapEngine(SSLEngine engine) {
engine.setEnabledCipherSuites(cipherSuites); engine.setEnabledCipherSuites(cipherSuites);
engine.setEnabledProtocols(PROTOCOLS); engine.setEnabledProtocols(protocols);
engine.setUseClientMode(isClient()); engine.setUseClientMode(isClient());
if (isServer()) { if (isServer()) {
switch (clientAuth) { switch (clientAuth) {

View File

@ -215,17 +215,17 @@ public final class JdkSslServerContext extends JdkSslContext {
super(newSSLContext(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory, super(newSSLContext(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), false, keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), false,
ciphers, cipherFilter, apn, ClientAuth.NONE, false); ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
} }
JdkSslServerContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, JdkSslServerContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout, ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout,
ClientAuth clientAuth, boolean startTls) throws SSLException { ClientAuth clientAuth, String[] protocols, boolean startTls) throws SSLException {
super(newSSLContext(trustCertCollection, trustManagerFactory, keyCertChain, key, super(newSSLContext(trustCertCollection, trustManagerFactory, keyCertChain, key,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), false, keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), false,
ciphers, cipherFilter, toNegotiator(apn, true), clientAuth, startTls); ciphers, cipherFilter, toNegotiator(apn, true), clientAuth, protocols, startTls);
} }
private static SSLContext newSSLContext(X509Certificate[] trustCertCollection, private static SSLContext newSSLContext(X509Certificate[] trustCertCollection,

View File

@ -175,17 +175,17 @@ public final class OpenSslClientContext extends OpenSslContext {
throws SSLException { throws SSLException {
this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory, this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); keyPassword, keyManagerFactory, ciphers, cipherFilter, apn, null, sessionCacheSize, sessionTimeout);
} }
OpenSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, OpenSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, String[] protocols,
long sessionCacheSize, long sessionTimeout) long sessionCacheSize, long sessionTimeout)
throws SSLException { throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain, super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain,
ClientAuth.NONE, false); ClientAuth.NONE, protocols, false);
boolean success = false; boolean success = false;
try { try {
sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory,

View File

@ -29,18 +29,18 @@ import javax.net.ssl.SSLException;
public abstract class OpenSslContext extends ReferenceCountedOpenSslContext { public abstract class OpenSslContext extends ReferenceCountedOpenSslContext {
OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apnCfg, OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apnCfg,
long sessionCacheSize, long sessionTimeout, int mode, Certificate[] keyCertChain, long sessionCacheSize, long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth, boolean startTls) ClientAuth clientAuth, String[] protocols, boolean startTls)
throws SSLException { throws SSLException {
super(ciphers, cipherFilter, apnCfg, sessionCacheSize, sessionTimeout, mode, keyCertChain, super(ciphers, cipherFilter, apnCfg, sessionCacheSize, sessionTimeout, mode, keyCertChain,
clientAuth, startTls, false); clientAuth, protocols, startTls, false);
} }
OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize, OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize,
long sessionTimeout, int mode, Certificate[] keyCertChain, long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth, boolean startTls) throws SSLException { ClientAuth clientAuth, String[] protocols, boolean startTls) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, mode, keyCertChain, clientAuth, startTls, super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, mode, keyCertChain, clientAuth, protocols,
false); startTls, false);
} }
@Override @Override

View File

@ -323,16 +323,17 @@ public final class OpenSslServerContext extends OpenSslContext {
this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory, this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, keyPassword, keyManagerFactory, ciphers, cipherFilter,
apn, sessionCacheSize, sessionTimeout, ClientAuth.NONE, false); apn, sessionCacheSize, sessionTimeout, ClientAuth.NONE, null, false);
} }
OpenSslServerContext( OpenSslServerContext(
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException { long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls)
throws SSLException {
this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers, this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers,
cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, startTls); cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls);
} }
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
@ -340,9 +341,10 @@ public final class OpenSslServerContext extends OpenSslContext {
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException { long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain, super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain,
clientAuth, startTls); clientAuth, protocols, startTls);
// Create a new SSL_CTX and configure it. // Create a new SSL_CTX and configure it.
boolean success = false; boolean success = false;
try { try {

View File

@ -54,10 +54,10 @@ public final class ReferenceCountedOpenSslClientContext extends ReferenceCounted
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout) String[] protocols, long sessionCacheSize, long sessionTimeout)
throws SSLException { throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain, super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain,
ClientAuth.NONE, false, true); ClientAuth.NONE, protocols, false, true);
boolean success = false; boolean success = false;
try { try {
sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory,

View File

@ -140,6 +140,7 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
final Certificate[] keyCertChain; final Certificate[] keyCertChain;
final ClientAuth clientAuth; final ClientAuth clientAuth;
final String[] protocols;
final OpenSslEngineMap engineMap = new DefaultOpenSslEngineMap(); final OpenSslEngineMap engineMap = new DefaultOpenSslEngineMap();
private volatile boolean rejectRemoteInitiatedRenegotiation; private volatile boolean rejectRemoteInitiatedRenegotiation;
private volatile int bioNonApplicationBufferSize = DEFAULT_BIO_NON_APPLICATION_BUFFER_SIZE; private volatile int bioNonApplicationBufferSize = DEFAULT_BIO_NON_APPLICATION_BUFFER_SIZE;
@ -211,16 +212,17 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
ReferenceCountedOpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ReferenceCountedOpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apnCfg, long sessionCacheSize, long sessionTimeout, ApplicationProtocolConfig apnCfg, long sessionCacheSize, long sessionTimeout,
int mode, Certificate[] keyCertChain, ClientAuth clientAuth, boolean startTls, int mode, Certificate[] keyCertChain, ClientAuth clientAuth, String[] protocols,
boolean leakDetection) throws SSLException { boolean startTls, boolean leakDetection) throws SSLException {
this(ciphers, cipherFilter, toNegotiator(apnCfg), sessionCacheSize, sessionTimeout, mode, keyCertChain, this(ciphers, cipherFilter, toNegotiator(apnCfg), sessionCacheSize, sessionTimeout, mode, keyCertChain,
clientAuth, startTls, leakDetection); clientAuth, protocols, startTls, leakDetection);
} }
ReferenceCountedOpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ReferenceCountedOpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize, OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize,
long sessionTimeout, int mode, Certificate[] keyCertChain, long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth, boolean startTls, boolean leakDetection) throws SSLException { ClientAuth clientAuth, String[] protocols, boolean startTls, boolean leakDetection)
throws SSLException {
super(startTls); super(startTls);
OpenSsl.ensureAvailability(); OpenSsl.ensureAvailability();
@ -231,6 +233,7 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
leak = leakDetection ? leakDetector.track(this) : null; leak = leakDetection ? leakDetector.track(this) : null;
this.mode = mode; this.mode = mode;
this.clientAuth = isServer() ? checkNotNull(clientAuth, "clientAuth") : ClientAuth.NONE; this.clientAuth = isServer() ? checkNotNull(clientAuth, "clientAuth") : ClientAuth.NONE;
this.protocols = protocols;
if (mode == SSL.SSL_MODE_SERVER) { if (mode == SSL.SSL_MODE_SERVER) {
rejectRemoteInitiatedRenegotiation = rejectRemoteInitiatedRenegotiation =
@ -305,19 +308,19 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
List<String> nextProtoList = apn.protocols(); List<String> nextProtoList = apn.protocols();
/* Set next protocols for next protocol negotiation extension, if specified */ /* Set next protocols for next protocol negotiation extension, if specified */
if (!nextProtoList.isEmpty()) { if (!nextProtoList.isEmpty()) {
String[] protocols = nextProtoList.toArray(new String[nextProtoList.size()]); String[] appProtocols = nextProtoList.toArray(new String[nextProtoList.size()]);
int selectorBehavior = opensslSelectorFailureBehavior(apn.selectorFailureBehavior()); int selectorBehavior = opensslSelectorFailureBehavior(apn.selectorFailureBehavior());
switch (apn.protocol()) { switch (apn.protocol()) {
case NPN: case NPN:
SSLContext.setNpnProtos(ctx, protocols, selectorBehavior); SSLContext.setNpnProtos(ctx, appProtocols, selectorBehavior);
break; break;
case ALPN: case ALPN:
SSLContext.setAlpnProtos(ctx, protocols, selectorBehavior); SSLContext.setAlpnProtos(ctx, appProtocols, selectorBehavior);
break; break;
case NPN_AND_ALPN: case NPN_AND_ALPN:
SSLContext.setNpnProtos(ctx, protocols, selectorBehavior); SSLContext.setNpnProtos(ctx, appProtocols, selectorBehavior);
SSLContext.setAlpnProtos(ctx, protocols, selectorBehavior); SSLContext.setAlpnProtos(ctx, appProtocols, selectorBehavior);
break; break;
default: default:
throw new Error(); throw new Error();

View File

@ -237,6 +237,10 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
// needed JNI methods. // needed JNI methods.
setClientAuth(clientMode ? ClientAuth.NONE : context.clientAuth); setClientAuth(clientMode ? ClientAuth.NONE : context.clientAuth);
if (context.protocols != null) {
setEnabledProtocols(context.protocols);
}
// Use SNI if peerHost was specified // Use SNI if peerHost was specified
// See https://github.com/netty/netty/issues/4746 // See https://github.com/netty/netty/issues/4746
if (clientMode && peerHost != null) { if (clientMode && peerHost != null) {

View File

@ -49,18 +49,20 @@ public final class ReferenceCountedOpenSslServerContext extends ReferenceCounted
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException { long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls)
throws SSLException {
this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers, this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers,
cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, startTls); cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls);
} }
private ReferenceCountedOpenSslServerContext( private ReferenceCountedOpenSslServerContext(
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException { long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain, super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain,
clientAuth, startTls, true); clientAuth, protocols, startTls, true);
// Create a new SSL_CTX and configure it. // Create a new SSL_CTX and configure it.
boolean success = false; boolean success = false;
try { try {

View File

@ -384,7 +384,7 @@ public abstract class SslContext {
toX509Certificates(keyCertChainFile), toX509Certificates(keyCertChainFile),
toPrivateKey(keyFile, keyPassword), toPrivateKey(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, apn, keyPassword, keyManagerFactory, ciphers, cipherFilter, apn,
sessionCacheSize, sessionTimeout, ClientAuth.NONE, false); sessionCacheSize, sessionTimeout, ClientAuth.NONE, null, false);
} catch (Exception e) { } catch (Exception e) {
if (e instanceof SSLException) { if (e instanceof SSLException) {
throw (SSLException) e; throw (SSLException) e;
@ -398,7 +398,8 @@ public abstract class SslContext {
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException { long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls)
throws SSLException {
if (provider == null) { if (provider == null) {
provider = defaultServerProvider(); provider = defaultServerProvider();
@ -409,17 +410,17 @@ public abstract class SslContext {
return new JdkSslServerContext( return new JdkSslServerContext(
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth, startTls); clientAuth, protocols, startTls);
case OPENSSL: case OPENSSL:
return new OpenSslServerContext( return new OpenSslServerContext(
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth, startTls); clientAuth, protocols, startTls);
case OPENSSL_REFCNT: case OPENSSL_REFCNT:
return new ReferenceCountedOpenSslServerContext( return new ReferenceCountedOpenSslServerContext(
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth, startTls); clientAuth, protocols, startTls);
default: default:
throw new Error(provider.toString()); throw new Error(provider.toString());
} }
@ -727,8 +728,7 @@ public abstract class SslContext {
return newClientContextInternal(provider, toX509Certificates(trustCertCollectionFile), trustManagerFactory, return newClientContextInternal(provider, toX509Certificates(trustCertCollectionFile), trustManagerFactory,
toX509Certificates(keyCertChainFile), toPrivateKey(keyFile, keyPassword), toX509Certificates(keyCertChainFile), toPrivateKey(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, keyPassword, keyManagerFactory, ciphers, cipherFilter,
apn, apn, null, sessionCacheSize, sessionTimeout);
sessionCacheSize, sessionTimeout);
} catch (Exception e) { } catch (Exception e) {
if (e instanceof SSLException) { if (e instanceof SSLException) {
throw (SSLException) e; throw (SSLException) e;
@ -741,7 +741,7 @@ public abstract class SslContext {
SslProvider provider, SslProvider provider,
X509Certificate[] trustCert, TrustManagerFactory trustManagerFactory, X509Certificate[] trustCert, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, String[] protocols,
long sessionCacheSize, long sessionTimeout) throws SSLException { long sessionCacheSize, long sessionTimeout) throws SSLException {
if (provider == null) { if (provider == null) {
provider = defaultClientProvider(); provider = defaultClientProvider();
@ -750,15 +750,15 @@ public abstract class SslContext {
case JDK: case JDK:
return new JdkSslClientContext( return new JdkSslClientContext(
trustCert, trustManagerFactory, keyCertChain, key, keyPassword, trustCert, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout);
case OPENSSL: case OPENSSL:
return new OpenSslClientContext( return new OpenSslClientContext(
trustCert, trustManagerFactory, keyCertChain, key, keyPassword, trustCert, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout);
case OPENSSL_REFCNT: case OPENSSL_REFCNT:
return new ReferenceCountedOpenSslClientContext( return new ReferenceCountedOpenSslClientContext(
trustCert, trustManagerFactory, keyCertChain, key, keyPassword, trustCert, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout);
default: default:
throw new Error(provider.toString()); throw new Error(provider.toString());
} }

View File

@ -16,15 +16,16 @@
package io.netty.handler.ssl; package io.netty.handler.ssl;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import java.io.File; import java.io.File;
import java.io.InputStream; import java.io.InputStream;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
/** /**
* Builder for configuring a new SslContext for creation. * Builder for configuring a new SslContext for creation.
@ -137,6 +138,7 @@ public final class SslContextBuilder {
private long sessionCacheSize; private long sessionCacheSize;
private long sessionTimeout; private long sessionTimeout;
private ClientAuth clientAuth = ClientAuth.NONE; private ClientAuth clientAuth = ClientAuth.NONE;
private String[] protocols;
private boolean startTls; private boolean startTls;
private SslContextBuilder(boolean forServer) { private SslContextBuilder(boolean forServer) {
@ -384,6 +386,16 @@ public final class SslContextBuilder {
return this; return this;
} }
/**
* The TLS protocol versions to enable.
* @param protocols The protocols to enable, or {@code null} to enable the default protocols.
* @see SSLEngine#setEnabledCipherSuites(String[])
*/
public SslContextBuilder protocols(String... protocols) {
this.protocols = protocols == null ? null : protocols.clone();
return this;
}
/** /**
* {@code true} if the first write request shouldn't be encrypted. * {@code true} if the first write request shouldn't be encrypted.
*/ */
@ -401,11 +413,11 @@ public final class SslContextBuilder {
if (forServer) { if (forServer) {
return SslContext.newServerContextInternal(provider, trustCertCollection, return SslContext.newServerContextInternal(provider, trustCertCollection,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth, startTls); ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls);
} else { } else {
return SslContext.newClientContextInternal(provider, trustCertCollection, return SslContext.newClientContextInternal(provider, trustCertCollection,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout);
} }
} }
} }

View File

@ -63,13 +63,11 @@ import java.security.cert.Certificate;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
@ -1633,6 +1631,43 @@ public abstract class SSLEngineTest {
} }
} }
@Test
public void testProtocolMatch() throws Exception {
testProtocol(new String[] {"TLSv1.2"}, new String[] {"TLSv1", "TLSv1.1", "TLSv1.2"});
}
@Test(expected = SSLHandshakeException.class)
public void testProtocolNoMatch() throws Exception {
testProtocol(new String[] {"TLSv1.2"}, new String[] {"TLSv1", "TLSv1.1"});
}
private void testProtocol(String[] clientProtocols, String[] serverProtocols) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
clientSslCtx = SslContextBuilder
.forClient()
.trustManager(cert.cert())
.sslProvider(sslClientProvider())
.protocols(clientProtocols)
.build();
SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
serverSslCtx = SslContextBuilder
.forServer(cert.certificate(), cert.privateKey())
.sslProvider(sslServerProvider())
.protocols(serverProtocols)
.build();
SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
try {
handshake(client, server);
} finally {
cleanupClientSslEngine(client);
cleanupServerSslEngine(server);
cert.delete();
}
}
@Test @Test
public void testPacketBufferSizeLimit() throws Exception { public void testPacketBufferSizeLimit() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate(); SelfSignedCertificate cert = new SelfSignedCertificate();