ReferenceCountedOpenSslEngine SSLSession.getLocalCertificates() / getLocalPrincipial() did not work when KeyManagerFactory was used. (#8560)

Motivation:

The SSLSession.getLocalCertificates() / getLocalPrincipial() methods did not correctly return the local configured certificate / principal if a KeyManagerFactory was used when configure the SslContext.

Modifications:

- Correctly update the local certificates / principial when the key material is selected.
- Add test case that verifies the SSLSession after the handshake to ensure we correctly return all values.

Result:

SSLSession returns correct values also when KeyManagerFactory is used with the OpenSSL provider.
This commit is contained in:
Norman Maurer 2018-11-16 07:38:32 +01:00 committed by GitHub
parent 20d4fda55e
commit 8d4d76d216
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 240 additions and 10 deletions

View File

@ -22,20 +22,29 @@ import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetectorFactory; import io.netty.util.ResourceLeakDetectorFactory;
import io.netty.util.ResourceLeakTracker; import io.netty.util.ResourceLeakTracker;
import java.security.cert.X509Certificate;
final class DefaultOpenSslKeyMaterial extends AbstractReferenceCounted implements OpenSslKeyMaterial { final class DefaultOpenSslKeyMaterial extends AbstractReferenceCounted implements OpenSslKeyMaterial {
private static final ResourceLeakDetector<DefaultOpenSslKeyMaterial> leakDetector = private static final ResourceLeakDetector<DefaultOpenSslKeyMaterial> leakDetector =
ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DefaultOpenSslKeyMaterial.class); ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DefaultOpenSslKeyMaterial.class);
private final ResourceLeakTracker<DefaultOpenSslKeyMaterial> leak; private final ResourceLeakTracker<DefaultOpenSslKeyMaterial> leak;
private final X509Certificate[] x509CertificateChain;
private long chain; private long chain;
private long privateKey; private long privateKey;
DefaultOpenSslKeyMaterial(long chain, long privateKey) { DefaultOpenSslKeyMaterial(long chain, long privateKey, X509Certificate[] x509CertificateChain) {
this.chain = chain; this.chain = chain;
this.privateKey = privateKey; this.privateKey = privateKey;
this.x509CertificateChain = x509CertificateChain;
leak = leakDetector.track(this); leak = leakDetector.track(this);
} }
@Override
public X509Certificate[] certificateChain() {
return x509CertificateChain.clone();
}
@Override @Override
public long certificateChainAddress() { public long certificateChainAddress() {
if (refCnt() <= 0) { if (refCnt() <= 0) {

View File

@ -17,11 +17,18 @@ package io.netty.handler.ssl;
import io.netty.util.ReferenceCounted; import io.netty.util.ReferenceCounted;
import java.security.cert.X509Certificate;
/** /**
* Holds references to the native key-material that is used by OpenSSL. * Holds references to the native key-material that is used by OpenSSL.
*/ */
interface OpenSslKeyMaterial extends ReferenceCounted { interface OpenSslKeyMaterial extends ReferenceCounted {
/**
* Returns the configured {@link X509Certificate}s.
*/
X509Certificate[] certificateChain();
/** /**
* Returns the pointer to the {@code STACK_OF(X509)} which holds the certificate chain. * Returns the pointer to the {@code STACK_OF(X509)} which holds the certificate chain.
*/ */

View File

@ -96,8 +96,7 @@ final class OpenSslKeyMaterialManager {
try { try {
keyMaterial = provider.chooseKeyMaterial(engine.alloc, alias); keyMaterial = provider.chooseKeyMaterial(engine.alloc, alias);
if (keyMaterial != null) { if (keyMaterial != null) {
SSL.setKeyMaterial(engine.sslPointer(), engine.setKeyMaterial(keyMaterial);
keyMaterial.certificateChainAddress(), keyMaterial.privateKeyAddress());
} }
} catch (SSLException e) { } catch (SSLException e) {
throw e; throw e;

View File

@ -66,11 +66,11 @@ class OpenSslKeyMaterialProvider {
OpenSslKeyMaterial keyMaterial; OpenSslKeyMaterial keyMaterial;
if (key instanceof OpenSslPrivateKey) { if (key instanceof OpenSslPrivateKey) {
keyMaterial = ((OpenSslPrivateKey) key).toKeyMaterial(chain); keyMaterial = ((OpenSslPrivateKey) key).toKeyMaterial(chain, certificates);
} else { } else {
pkeyBio = toBIO(allocator, key); pkeyBio = toBIO(allocator, key);
pkey = key == null ? 0 : SSL.parsePrivateKey(pkeyBio, password); pkey = key == null ? 0 : SSL.parsePrivateKey(pkeyBio, password);
keyMaterial = new DefaultOpenSslKeyMaterial(chain, pkey); keyMaterial = new DefaultOpenSslKeyMaterial(chain, pkey, certificates);
} }
// See the chain and pkey to 0 so we will not release it as the ownership was // See the chain and pkey to 0 so we will not release it as the ownership was

View File

@ -18,9 +18,11 @@ package io.netty.handler.ssl;
import io.netty.internal.tcnative.SSL; import io.netty.internal.tcnative.SSL;
import io.netty.util.AbstractReferenceCounted; import io.netty.util.AbstractReferenceCounted;
import io.netty.util.IllegalReferenceCountException; import io.netty.util.IllegalReferenceCountException;
import io.netty.util.internal.EmptyArrays;
import javax.security.auth.Destroyable; import javax.security.auth.Destroyable;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.cert.X509Certificate;
final class OpenSslPrivateKey extends AbstractReferenceCounted implements PrivateKey { final class OpenSslPrivateKey extends AbstractReferenceCounted implements PrivateKey {
@ -110,16 +112,24 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat
/** /**
* Convert to a {@link OpenSslKeyMaterial}. Reference count of both is shared. * Convert to a {@link OpenSslKeyMaterial}. Reference count of both is shared.
*/ */
OpenSslKeyMaterial toKeyMaterial(long certificateChain) { OpenSslKeyMaterial toKeyMaterial(long certificateChain, X509Certificate[] chain) {
return new OpenSslPrivateKeyMaterial(certificateChain); return new OpenSslPrivateKeyMaterial(certificateChain, chain);
} }
private final class OpenSslPrivateKeyMaterial implements OpenSslKeyMaterial { private final class OpenSslPrivateKeyMaterial implements OpenSslKeyMaterial {
private long certificateChain; private long certificateChain;
private final X509Certificate[] x509CertificateChain;
OpenSslPrivateKeyMaterial(long certificateChain) { OpenSslPrivateKeyMaterial(long certificateChain, X509Certificate[] x509CertificateChain) {
this.certificateChain = certificateChain; this.certificateChain = certificateChain;
this.x509CertificateChain = x509CertificateChain == null ?
EmptyArrays.EMPTY_X509_CERTIFICATES : x509CertificateChain;
}
@Override
public X509Certificate[] certificateChain() {
return x509CertificateChain.clone();
} }
@Override @Override

View File

@ -193,6 +193,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}; };
private volatile ClientAuth clientAuth = ClientAuth.NONE; private volatile ClientAuth clientAuth = ClientAuth.NONE;
private volatile Certificate[] localCertificateChain;
// Updated once a new handshake is started and so the SSLSession reused. // Updated once a new handshake is started and so the SSLSession reused.
private volatile long lastAccessed = -1; private volatile long lastAccessed = -1;
@ -216,7 +217,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
private final OpenSslEngineMap engineMap; private final OpenSslEngineMap engineMap;
private final OpenSslApplicationProtocolNegotiator apn; private final OpenSslApplicationProtocolNegotiator apn;
private final OpenSslSession session; private final OpenSslSession session;
private final Certificate[] localCerts;
private final ByteBuffer[] singleSrcBuffer = new ByteBuffer[1]; private final ByteBuffer[] singleSrcBuffer = new ByteBuffer[1];
private final ByteBuffer[] singleDstBuffer = new ByteBuffer[1]; private final ByteBuffer[] singleDstBuffer = new ByteBuffer[1];
private final boolean enableOcsp; private final boolean enableOcsp;
@ -323,8 +323,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
session = new DefaultOpenSslSession(context.sessionContext()); session = new DefaultOpenSslSession(context.sessionContext());
} }
engineMap = context.engineMap; engineMap = context.engineMap;
localCerts = context.keyCertChain;
enableOcsp = context.enableOcsp; enableOcsp = context.enableOcsp;
// context.keyCertChain will only be non-null if we do not use the KeyManagerFactory. In this case
// localCertificateChain will be set in setKeyMaterial(...).
localCertificateChain = context.keyCertChain;
this.jdkCompatibilityMode = jdkCompatibilityMode; this.jdkCompatibilityMode = jdkCompatibilityMode;
Lock readerLock = context.ctxLock.readLock(); Lock readerLock = context.ctxLock.readLock();
readerLock.lock(); readerLock.lock();
@ -379,6 +382,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
leak = leakDetection ? leakDetector.track(this) : null; leak = leakDetection ? leakDetector.track(this) : null;
} }
final void setKeyMaterial(OpenSslKeyMaterial keyMaterial) throws Exception {
SSL.setKeyMaterial(ssl, keyMaterial.certificateChainAddress(), keyMaterial.privateKeyAddress());
localCertificateChain = keyMaterial.certificateChain();
}
/** /**
* Sets the OCSP response. * Sets the OCSP response.
*/ */
@ -1930,6 +1938,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
// thread. // thread.
private X509Certificate[] x509PeerCerts; private X509Certificate[] x509PeerCerts;
private Certificate[] peerCerts; private Certificate[] peerCerts;
private Certificate[] localCerts;
private String protocol; private String protocol;
private String cipher; private String cipher;
private byte[] id; private byte[] id;
@ -2070,6 +2080,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
id = SSL.getSessionId(ssl); id = SSL.getSessionId(ssl);
cipher = toJavaCipherSuite(SSL.getCipherForSSL(ssl)); cipher = toJavaCipherSuite(SSL.getCipherForSSL(ssl));
protocol = SSL.getVersion(ssl); protocol = SSL.getVersion(ssl);
localCerts = localCertificateChain;
initPeerCerts(); initPeerCerts();
selectApplicationProtocol(); selectApplicationProtocol();

View File

@ -95,6 +95,20 @@ public class OpenSslEngineTest extends SSLEngineTest {
assertEquals("SSL error stack not correctly consumed", 0, SSL.getLastErrorNumber()); assertEquals("SSL error stack not correctly consumed", 0, SSL.getLastErrorNumber());
} }
@Override
@Test
public void testSessionAfterHandshakeKeyManagerFactory() throws Exception {
checkShouldUseKeyManagerFactory();
super.testSessionAfterHandshakeKeyManagerFactory();
}
@Override
@Test
public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception {
checkShouldUseKeyManagerFactory();
super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth();
}
@Override @Override
@Test @Test
public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth() throws Exception { public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth() throws Exception {

View File

@ -60,6 +60,7 @@ import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.Principal;
import java.security.Provider; import java.security.Provider;
import java.security.cert.Certificate; import java.security.cert.Certificate;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
@ -82,6 +83,7 @@ import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManager;
@ -97,6 +99,7 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -2726,6 +2729,183 @@ public abstract class SSLEngineTest {
} }
} }
@Test
public void testSessionAfterHandshake() throws Exception {
testSessionAfterHandshake0(false, false);
}
@Test
public void testSessionAfterHandshakeMutualAuth() throws Exception {
testSessionAfterHandshake0(false, true);
}
@Test
public void testSessionAfterHandshakeKeyManagerFactory() throws Exception {
testSessionAfterHandshake0(true, false);
}
@Test
public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception {
testSessionAfterHandshake0(true, true);
}
private void testSessionAfterHandshake0(boolean useKeyManagerFactory, boolean mutualAuth) throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate();
KeyManagerFactory kmf = useKeyManagerFactory ?
SslContext.buildKeyManagerFactory(
new java.security.cert.X509Certificate[] { ssc.cert()}, ssc.key(), null, null) : null;
SslContextBuilder clientContextBuilder = SslContextBuilder.forClient();
if (mutualAuth) {
if (kmf != null) {
clientContextBuilder.keyManager(kmf);
} else {
clientContextBuilder.keyManager(ssc.key(), ssc.cert());
}
}
clientSslCtx = clientContextBuilder
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(sslClientProvider())
.sslContextProvider(clientSslContextProvider())
.protocols(protocols())
.ciphers(ciphers())
.build();
SslContextBuilder serverContextBuilder = kmf != null ?
SslContextBuilder.forServer(kmf) :
SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey());
if (mutualAuth) {
serverContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
serverSslCtx = serverContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(sslServerProvider())
.sslContextProvider(serverSslContextProvider())
.protocols(protocols())
.ciphers(ciphers())
.build();
SSLEngine clientEngine = null;
SSLEngine serverEngine = null;
try {
clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
handshake(clientEngine, serverEngine);
SSLSession clientSession = clientEngine.getSession();
SSLSession serverSession = serverEngine.getSession();
assertNull(clientSession.getPeerHost());
assertNull(serverSession.getPeerHost());
assertEquals(-1, clientSession.getPeerPort());
assertEquals(-1, serverSession.getPeerPort());
assertTrue(clientSession.getCreationTime() > 0);
assertTrue(serverSession.getCreationTime() > 0);
assertTrue(clientSession.getLastAccessedTime() > 0);
assertTrue(serverSession.getLastAccessedTime() > 0);
assertEquals(protocolCipherCombo.protocol, clientSession.getProtocol());
assertEquals(protocolCipherCombo.protocol, serverSession.getProtocol());
assertEquals(protocolCipherCombo.cipher, clientSession.getCipherSuite());
assertEquals(protocolCipherCombo.cipher, serverSession.getCipherSuite());
assertNotNull(clientSession.getId());
assertNotNull(serverSession.getId());
assertTrue(clientSession.getApplicationBufferSize() > 0);
assertTrue(serverSession.getApplicationBufferSize() > 0);
assertTrue(clientSession.getPacketBufferSize() > 0);
assertTrue(serverSession.getPacketBufferSize() > 0);
assertNotNull(clientSession.getSessionContext());
assertNotNull(serverSession.getSessionContext());
Object value = new Object();
assertEquals(0, clientSession.getValueNames().length);
clientSession.putValue("test", value);
assertEquals("test", clientSession.getValueNames()[0]);
assertSame(value, clientSession.getValue("test"));
clientSession.removeValue("test");
assertEquals(0, clientSession.getValueNames().length);
assertEquals(0, serverSession.getValueNames().length);
serverSession.putValue("test", value);
assertEquals("test", serverSession.getValueNames()[0]);
assertSame(value, serverSession.getValue("test"));
serverSession.removeValue("test");
assertEquals(0, serverSession.getValueNames().length);
Certificate[] serverLocalCertificates = serverSession.getLocalCertificates();
assertEquals(1, serverLocalCertificates.length);
assertArrayEquals(ssc.cert().getEncoded(), serverLocalCertificates[0].getEncoded());
Principal serverLocalPrincipal = serverSession.getLocalPrincipal();
assertNotNull(serverLocalPrincipal);
if (mutualAuth) {
Certificate[] clientLocalCertificates = clientSession.getLocalCertificates();
assertEquals(1, clientLocalCertificates.length);
Certificate[] serverPeerCertificates = serverSession.getPeerCertificates();
assertEquals(1, serverPeerCertificates.length);
assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerCertificates[0].getEncoded());
X509Certificate[] serverPeerX509Certificates = serverSession.getPeerCertificateChain();
assertEquals(1, serverPeerX509Certificates.length);
assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerX509Certificates[0].getEncoded());
Principal clientLocalPrincipial = clientSession.getLocalPrincipal();
assertNotNull(clientLocalPrincipial);
Principal serverPeerPrincipal = serverSession.getPeerPrincipal();
assertEquals(clientLocalPrincipial, serverPeerPrincipal);
} else {
assertNull(clientSession.getLocalCertificates());
assertNull(clientSession.getLocalPrincipal());
try {
serverSession.getPeerCertificates();
fail();
} catch (SSLPeerUnverifiedException expected) {
// As we did not use mutual auth this is expected
}
try {
serverSession.getPeerCertificateChain();
fail();
} catch (SSLPeerUnverifiedException expected) {
// As we did not use mutual auth this is expected
}
try {
serverSession.getPeerPrincipal();
fail();
} catch (SSLPeerUnverifiedException expected) {
// As we did not use mutual auth this is expected
}
}
Certificate[] clientPeerCertificates = clientSession.getPeerCertificates();
assertEquals(1, clientPeerCertificates.length);
assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerCertificates[0].getEncoded());
X509Certificate[] clientPeerX509Certificates = clientSession.getPeerCertificateChain();
assertEquals(1, clientPeerX509Certificates.length);
assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerX509Certificates[0].getEncoded());
Principal clientPeerPrincipal = clientSession.getPeerPrincipal();
assertEquals(serverLocalPrincipal, clientPeerPrincipal);
} finally {
cleanupClientSslEngine(clientEngine);
cleanupServerSslEngine(serverEngine);
ssc.delete();
}
}
protected SSLEngine wrapEngine(SSLEngine engine) { protected SSLEngine wrapEngine(SSLEngine engine) {
return engine; return engine;
} }