Adding client auth to SslContextBuilder

Motivation:

To simplify the use of client auth, we need to add it to the SslContextBuilder.

Modifications:

Added a ClientAuth enum and plumbed it through the builder, down into the contexts/engines.

Result:

Client auth can be configured when building an SslContext.
This commit is contained in:
nmittler 2015-09-18 09:50:30 -07:00
parent 2a200385a0
commit 94bf412edb
11 changed files with 140 additions and 65 deletions

View File

@ -0,0 +1,38 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
/**
* Indicates the state of the {@link javax.net.ssl.SSLEngine} with respect to client authentication.
* This configuration item really only applies when building the server-side {@link SslContext}.
*/
public enum ClientAuth {
/**
* Indicates that the {@link javax.net.ssl.SSLEngine} will not request client authentication.
*/
NONE,
/**
* Indicates that the {@link javax.net.ssl.SSLEngine} will request client authentication.
*/
OPTIONAL,
/**
* Indicates that the {@link javax.net.ssl.SSLEngine} will *require* client authentication.
*/
REQUIRE
}

View File

@ -213,7 +213,7 @@ public final class JdkSslClientContext extends JdkSslContext {
File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, apn);
super(ciphers, cipherFilter, apn, ClientAuth.NONE);
try {
ctx = newSSLContext(toX509Certificates(trustCertChainFile), trustManagerFactory,
toX509Certificates(keyCertChainFile), toPrivateKey(keyFile, keyPassword),
@ -230,7 +230,7 @@ public final class JdkSslClientContext extends JdkSslContext {
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, toNegotiator(apn, false));
super(ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE);
ctx = newSSLContext(trustCertChain, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, sessionCacheSize, sessionTimeout);
}

View File

@ -139,14 +139,12 @@ public abstract class JdkSslContext extends SslContext {
private final String[] cipherSuites;
private final List<String> unmodifiableCipherSuites;
private final JdkApplicationProtocolNegotiator apn;
private final ClientAuth clientAuth;
JdkSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config,
boolean isServer) {
this(ciphers, cipherFilter, toNegotiator(config, isServer));
}
JdkSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn) {
JdkSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn,
ClientAuth clientAuth) {
this.apn = checkNotNull(apn, "apn");
this.clientAuth = checkNotNull(clientAuth, "clientAuth");
cipherSuites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites(
ciphers, DEFAULT_CIPHERS, SUPPORTED_CIPHERS);
unmodifiableCipherSuites = Collections.unmodifiableList(Arrays.asList(cipherSuites));
@ -186,23 +184,28 @@ public abstract class JdkSslContext extends SslContext {
@Override
public final SSLEngine newEngine(ByteBufAllocator alloc) {
SSLEngine engine = context().createSSLEngine();
engine.setEnabledCipherSuites(cipherSuites);
engine.setEnabledProtocols(PROTOCOLS);
engine.setUseClientMode(isClient());
return wrapEngine(engine);
return configureAndWrapEngine(context().createSSLEngine());
}
@Override
public final SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) {
SSLEngine engine = context().createSSLEngine(peerHost, peerPort);
return configureAndWrapEngine(context().createSSLEngine(peerHost, peerPort));
}
private SSLEngine configureAndWrapEngine(SSLEngine engine) {
engine.setEnabledCipherSuites(cipherSuites);
engine.setEnabledProtocols(PROTOCOLS);
engine.setUseClientMode(isClient());
return wrapEngine(engine);
if (isServer()) {
switch (clientAuth) {
case OPTIONAL:
engine.setWantClientAuth(true);
break;
case REQUIRE:
engine.setNeedClientAuth(true);
break;
}
}
private SSLEngine wrapEngine(SSLEngine engine) {
return apn.wrapperFactory().wrapSslEngine(engine, apn, isServer());
}

View File

@ -182,7 +182,7 @@ public final class JdkSslServerContext extends JdkSslContext {
File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, apn);
super(ciphers, cipherFilter, apn, ClientAuth.NONE);
try {
ctx = newSSLContext(toX509Certificates(trustCertChainFile), trustManagerFactory,
toX509Certificates(keyCertChainFile), toPrivateKey(keyFile, keyPassword),
@ -198,8 +198,9 @@ public final class JdkSslServerContext extends JdkSslContext {
JdkSslServerContext(X509Certificate[] trustCertChain, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, toNegotiator(apn, true));
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout,
ClientAuth clientAuth) throws SSLException {
super(ciphers, cipherFilter, toNegotiator(apn, true), clientAuth);
ctx = newSSLContext(trustCertChain, trustManagerFactory, keyCertChain, key,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout);
}

View File

@ -171,7 +171,8 @@ public final class OpenSslClientContext extends OpenSslContext {
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, null);
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, null,
ClientAuth.NONE);
boolean success = false;
try {
if (trustCertChainFile != null && !trustCertChainFile.isFile()) {
@ -271,7 +272,8 @@ public final class OpenSslClientContext extends OpenSslContext {
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain);
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain,
ClientAuth.NONE);
boolean success = false;
try {
if (key == null && keyCertChain != null || key != null && keyCertChain == null) {

View File

@ -83,6 +83,7 @@ public abstract class OpenSslContext extends SslContext {
private final OpenSslApplicationProtocolNegotiator apn;
private final int mode;
private final Certificate[] keyCertChain;
private final ClientAuth clientAuth;
static final OpenSslApplicationProtocolNegotiator NONE_PROTOCOL_NEGOTIATOR =
new OpenSslApplicationProtocolNegotiator() {
@ -127,20 +128,24 @@ public abstract class OpenSslContext extends SslContext {
}
OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apnCfg,
long sessionCacheSize, long sessionTimeout, int mode, Certificate[] keyCertChain)
long sessionCacheSize, long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth)
throws SSLException {
this(ciphers, cipherFilter, toNegotiator(apnCfg), sessionCacheSize, sessionTimeout, mode, keyCertChain);
this(ciphers, cipherFilter, toNegotiator(apnCfg), sessionCacheSize, sessionTimeout, mode, keyCertChain,
clientAuth);
}
OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize,
long sessionTimeout, int mode, Certificate[] keyCertChain) throws SSLException {
long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth) throws SSLException {
OpenSsl.ensureAvailability();
if (mode != SSL.SSL_MODE_SERVER && mode != SSL.SSL_MODE_CLIENT) {
throw new IllegalArgumentException("mode most be either SSL.SSL_MODE_SERVER or SSL.SSL_MODE_CLIENT");
}
this.mode = mode;
this.clientAuth = isServer() ? checkNotNull(clientAuth, "clientAuth") : ClientAuth.NONE;
if (mode == SSL.SSL_MODE_SERVER) {
rejectRemoteInitiatedRenegotiation =
@ -291,7 +296,7 @@ public abstract class OpenSslContext extends SslContext {
@Override
public final SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) {
final OpenSslEngine engine = new OpenSslEngine(ctx, alloc, isClient(), sessionContext(), apn, engineMap,
rejectRemoteInitiatedRenegotiation, peerHost, peerPort, keyCertChain);
rejectRemoteInitiatedRenegotiation, peerHost, peerPort, keyCertChain, clientAuth);
engineMap.add(engine);
return engine;
}

View File

@ -110,12 +110,6 @@ public final class OpenSslEngine extends SSLEngine {
static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = MAX_ENCRYPTED_PACKET_LENGTH - MAX_PLAINTEXT_LENGTH;
enum ClientAuthMode {
NONE,
OPTIONAL,
REQUIRE,
}
private static final AtomicIntegerFieldUpdater<OpenSslEngine> DESTROYED_UPDATER;
private static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL";
@ -157,7 +151,7 @@ public final class OpenSslEngine extends SSLEngine {
@SuppressWarnings("UnusedDeclaration")
private volatile int destroyed;
private volatile ClientAuthMode clientAuth = ClientAuthMode.NONE;
private volatile ClientAuth clientAuth = ClientAuth.NONE;
private volatile String endPointIdentificationAlgorithm;
// Store as object as AlgorithmConstraints only exists since java 7.
@ -189,30 +183,25 @@ public final class OpenSslEngine extends SSLEngine {
@Deprecated
public OpenSslEngine(long sslCtx, ByteBufAllocator alloc,
@SuppressWarnings("unused") String fallbackApplicationProtocol) {
this(sslCtx, alloc, false, null, OpenSslContext.NONE_PROTOCOL_NEGOTIATOR, OpenSslEngineMap.EMPTY, false);
this(sslCtx, alloc, false, null, OpenSslContext.NONE_PROTOCOL_NEGOTIATOR, OpenSslEngineMap.EMPTY, false,
ClientAuth.NONE);
}
/**
* Creates a new instance
*
* @param sslCtx an OpenSSL {@code SSL_CTX} object
* @param alloc the {@link ByteBufAllocator} that will be used by this engine
* @param clientMode {@code true} if this is used for clients, {@code false} otherwise
* @param sessionContext the {@link OpenSslSessionContext} this {@link SSLEngine} belongs to.
*/
OpenSslEngine(long sslCtx, ByteBufAllocator alloc,
boolean clientMode, OpenSslSessionContext sessionContext,
OpenSslApplicationProtocolNegotiator apn, OpenSslEngineMap engineMap,
boolean rejectRemoteInitiatedRenegation) {
boolean rejectRemoteInitiatedRenegation,
ClientAuth clientAuth) {
this(sslCtx, alloc, clientMode, sessionContext, apn, engineMap, rejectRemoteInitiatedRenegation, null, -1,
null);
null, clientAuth);
}
OpenSslEngine(long sslCtx, ByteBufAllocator alloc,
boolean clientMode, OpenSslSessionContext sessionContext,
OpenSslApplicationProtocolNegotiator apn, OpenSslEngineMap engineMap,
boolean rejectRemoteInitiatedRenegation, String peerHost, int peerPort,
java.security.cert.Certificate[] localCerts) {
java.security.cert.Certificate[] localCerts,
ClientAuth clientAuth) {
super(peerHost, peerPort);
OpenSsl.ensureAvailability();
if (sslCtx == 0) {
@ -221,6 +210,7 @@ public final class OpenSslEngine extends SSLEngine {
this.alloc = checkNotNull(alloc, "alloc");
this.apn = checkNotNull(apn, "apn");
this.clientAuth = clientMode ? ClientAuth.NONE : checkNotNull(clientAuth, "clientAuth");
ssl = SSL.newSSL(sslCtx, !clientMode);
session = new OpenSslSession(ssl, sessionContext);
networkBIO = SSL.makeNetworkBIO(ssl);
@ -1194,25 +1184,25 @@ public final class OpenSslEngine extends SSLEngine {
@Override
public void setNeedClientAuth(boolean b) {
setClientAuth(b ? ClientAuthMode.REQUIRE : ClientAuthMode.NONE);
setClientAuth(b ? ClientAuth.REQUIRE : ClientAuth.NONE);
}
@Override
public boolean getNeedClientAuth() {
return clientAuth == ClientAuthMode.REQUIRE;
return clientAuth == ClientAuth.REQUIRE;
}
@Override
public void setWantClientAuth(boolean b) {
setClientAuth(b ? ClientAuthMode.OPTIONAL : ClientAuthMode.NONE);
setClientAuth(b ? ClientAuth.OPTIONAL : ClientAuth.NONE);
}
@Override
public boolean getWantClientAuth() {
return clientAuth == ClientAuthMode.OPTIONAL;
return clientAuth == ClientAuth.OPTIONAL;
}
private void setClientAuth(ClientAuthMode mode) {
private void setClientAuth(ClientAuth mode) {
if (clientMode) {
return;
}

View File

@ -291,7 +291,8 @@ public final class OpenSslServerContext extends OpenSslContext {
File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, null);
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, null,
ClientAuth.NONE);
OpenSsl.ensureAvailability();
checkNotNull(keyCertChainFile, "keyCertChainFile");
@ -391,8 +392,9 @@ public final class OpenSslServerContext extends OpenSslContext {
X509Certificate[] trustCertChain, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain);
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain,
clientAuth);
OpenSsl.ensureAvailability();
checkNotNull(keyCertChain, "keyCertChainFile");

View File

@ -278,7 +278,7 @@ public abstract class SslContext {
toX509Certificates(keyCertChainFile),
toPrivateKey(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, apn,
sessionCacheSize, sessionTimeout);
sessionCacheSize, sessionTimeout, ClientAuth.NONE);
} catch (Exception e) {
if (e instanceof SSLException) {
throw (SSLException) e;
@ -292,7 +292,8 @@ public abstract class SslContext {
X509Certificate[] trustCertChain, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout) throws SSLException {
long sessionCacheSize, long sessionTimeout,
ClientAuth clientAuth) throws SSLException {
if (provider == null) {
provider = defaultServerProvider();
@ -302,11 +303,13 @@ public abstract class SslContext {
case JDK:
return new JdkSslServerContext(
trustCertChain, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth);
case OPENSSL:
return new OpenSslServerContext(
trustCertChain, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth);
default:
throw new Error(provider.toString());
}

View File

@ -110,6 +110,7 @@ public final class SslContextBuilder {
private ApplicationProtocolConfig apn;
private long sessionCacheSize;
private long sessionTimeout;
private ClientAuth clientAuth = ClientAuth.NONE;
private SslContextBuilder(boolean forServer) {
this.forServer = forServer;
@ -298,6 +299,14 @@ public final class SslContextBuilder {
return this;
}
/**
* Sets the client authentication mode.
*/
public SslContextBuilder clientAuth(ClientAuth clientAuth) {
this.clientAuth = checkNotNull(clientAuth, "clientAuth");
return this;
}
/**
* Create new {@code SslContext} instance with configured settings.
*/
@ -305,7 +314,7 @@ public final class SslContextBuilder {
if (forServer) {
return SslContext.newServerContextInternal(provider, trustCertChain,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth);
} else {
return SslContext.newClientContextInternal(provider, trustCertChain,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,

View File

@ -15,6 +15,9 @@
*/
package io.netty.handler.ssl;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import org.junit.Assume;
@ -70,20 +73,31 @@ public class SslContextBuilderTest {
private static void testClientContextFromFile(SslProvider provider) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(provider).keyManager(
cert.certificate(), cert.privateKey()).trustManager(cert.certificate());
SslContextBuilder builder = SslContextBuilder.forClient()
.sslProvider(provider)
.keyManager(cert.certificate(),
cert.privateKey())
.trustManager(cert.certificate())
.clientAuth(ClientAuth.OPTIONAL);
SslContext context = builder.build();
SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT);
assertFalse(engine.getWantClientAuth());
assertFalse(engine.getNeedClientAuth());
engine.closeInbound();
engine.closeOutbound();
}
private static void testClientContext(SslProvider provider) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(provider).keyManager(
cert.key(), cert.cert()).trustManager(cert.cert());
SslContextBuilder builder = SslContextBuilder.forClient()
.sslProvider(provider)
.keyManager(cert.key(), cert.cert())
.trustManager(cert.cert())
.clientAuth(ClientAuth.OPTIONAL);
SslContext context = builder.build();
SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT);
assertFalse(engine.getWantClientAuth());
assertFalse(engine.getNeedClientAuth());
engine.closeInbound();
engine.closeOutbound();
}
@ -91,9 +105,13 @@ public class SslContextBuilderTest {
private static void testServerContextFromFile(SslProvider provider) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forServer(cert.certificate(), cert.privateKey())
.sslProvider(provider).trustManager(cert.certificate());
.sslProvider(provider)
.trustManager(cert.certificate())
.clientAuth(ClientAuth.OPTIONAL);
SslContext context = builder.build();
SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT);
assertTrue(engine.getWantClientAuth());
assertFalse(engine.getNeedClientAuth());
engine.closeInbound();
engine.closeOutbound();
}
@ -101,9 +119,13 @@ public class SslContextBuilderTest {
private static void testServerContext(SslProvider provider) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forServer(cert.key(), cert.cert())
.sslProvider(provider).trustManager(cert.cert());
.sslProvider(provider)
.trustManager(cert.cert())
.clientAuth(ClientAuth.REQUIRE);
SslContext context = builder.build();
SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT);
assertFalse(engine.getWantClientAuth());
assertTrue(engine.getNeedClientAuth());
engine.closeInbound();
engine.closeOutbound();
}