From 8d4d76d2163fe65ab35a5c206ed4c1348e86d671 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 16 Nov 2018 07:38:32 +0100 Subject: [PATCH] ReferenceCountedOpenSslEngine SSLSession.getLocalCertificates() / getLocalPrincipial() did not work when KeyManagerFactory was used. (#8560) Motivation: The SSLSession.getLocalCertificates() / getLocalPrincipial() methods did not correctly return the local configured certificate / principal if a KeyManagerFactory was used when configure the SslContext. Modifications: - Correctly update the local certificates / principial when the key material is selected. - Add test case that verifies the SSLSession after the handshake to ensure we correctly return all values. Result: SSLSession returns correct values also when KeyManagerFactory is used with the OpenSSL provider. --- .../ssl/DefaultOpenSslKeyMaterial.java | 11 +- .../netty/handler/ssl/OpenSslKeyMaterial.java | 7 + .../ssl/OpenSslKeyMaterialManager.java | 3 +- .../ssl/OpenSslKeyMaterialProvider.java | 4 +- .../netty/handler/ssl/OpenSslPrivateKey.java | 16 +- .../ssl/ReferenceCountedOpenSslEngine.java | 15 +- .../netty/handler/ssl/OpenSslEngineTest.java | 14 ++ .../io/netty/handler/ssl/SSLEngineTest.java | 180 ++++++++++++++++++ 8 files changed, 240 insertions(+), 10 deletions(-) diff --git a/handler/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java b/handler/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java index 4bba0e741f..fcea5266f2 100644 --- a/handler/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java +++ b/handler/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java @@ -22,20 +22,29 @@ import io.netty.util.ResourceLeakDetector; import io.netty.util.ResourceLeakDetectorFactory; import io.netty.util.ResourceLeakTracker; +import java.security.cert.X509Certificate; + final class DefaultOpenSslKeyMaterial extends AbstractReferenceCounted implements OpenSslKeyMaterial { private static final ResourceLeakDetector leakDetector = ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DefaultOpenSslKeyMaterial.class); private final ResourceLeakTracker leak; + private final X509Certificate[] x509CertificateChain; private long chain; private long privateKey; - DefaultOpenSslKeyMaterial(long chain, long privateKey) { + DefaultOpenSslKeyMaterial(long chain, long privateKey, X509Certificate[] x509CertificateChain) { this.chain = chain; this.privateKey = privateKey; + this.x509CertificateChain = x509CertificateChain; leak = leakDetector.track(this); } + @Override + public X509Certificate[] certificateChain() { + return x509CertificateChain.clone(); + } + @Override public long certificateChainAddress() { if (refCnt() <= 0) { diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java index 29099e5fa1..68fc85a3ed 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java @@ -17,11 +17,18 @@ package io.netty.handler.ssl; import io.netty.util.ReferenceCounted; +import java.security.cert.X509Certificate; + /** * Holds references to the native key-material that is used by OpenSSL. */ interface OpenSslKeyMaterial extends ReferenceCounted { + /** + * Returns the configured {@link X509Certificate}s. + */ + X509Certificate[] certificateChain(); + /** * Returns the pointer to the {@code STACK_OF(X509)} which holds the certificate chain. */ diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java index 94398d5cf2..9f0e0199ef 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java @@ -96,8 +96,7 @@ final class OpenSslKeyMaterialManager { try { keyMaterial = provider.chooseKeyMaterial(engine.alloc, alias); if (keyMaterial != null) { - SSL.setKeyMaterial(engine.sslPointer(), - keyMaterial.certificateChainAddress(), keyMaterial.privateKeyAddress()); + engine.setKeyMaterial(keyMaterial); } } catch (SSLException e) { throw e; diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java index 7430b77f7d..72cd2e0c8b 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java @@ -66,11 +66,11 @@ class OpenSslKeyMaterialProvider { OpenSslKeyMaterial keyMaterial; if (key instanceof OpenSslPrivateKey) { - keyMaterial = ((OpenSslPrivateKey) key).toKeyMaterial(chain); + keyMaterial = ((OpenSslPrivateKey) key).toKeyMaterial(chain, certificates); } else { pkeyBio = toBIO(allocator, key); pkey = key == null ? 0 : SSL.parsePrivateKey(pkeyBio, password); - keyMaterial = new DefaultOpenSslKeyMaterial(chain, pkey); + keyMaterial = new DefaultOpenSslKeyMaterial(chain, pkey, certificates); } // See the chain and pkey to 0 so we will not release it as the ownership was diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java index de1ff04daf..67639aae3c 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java @@ -18,9 +18,11 @@ package io.netty.handler.ssl; import io.netty.internal.tcnative.SSL; import io.netty.util.AbstractReferenceCounted; import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.EmptyArrays; import javax.security.auth.Destroyable; import java.security.PrivateKey; +import java.security.cert.X509Certificate; final class OpenSslPrivateKey extends AbstractReferenceCounted implements PrivateKey { @@ -110,16 +112,24 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat /** * Convert to a {@link OpenSslKeyMaterial}. Reference count of both is shared. */ - OpenSslKeyMaterial toKeyMaterial(long certificateChain) { - return new OpenSslPrivateKeyMaterial(certificateChain); + OpenSslKeyMaterial toKeyMaterial(long certificateChain, X509Certificate[] chain) { + return new OpenSslPrivateKeyMaterial(certificateChain, chain); } private final class OpenSslPrivateKeyMaterial implements OpenSslKeyMaterial { private long certificateChain; + private final X509Certificate[] x509CertificateChain; - OpenSslPrivateKeyMaterial(long certificateChain) { + OpenSslPrivateKeyMaterial(long certificateChain, X509Certificate[] x509CertificateChain) { this.certificateChain = certificateChain; + this.x509CertificateChain = x509CertificateChain == null ? + EmptyArrays.EMPTY_X509_CERTIFICATES : x509CertificateChain; + } + + @Override + public X509Certificate[] certificateChain() { + return x509CertificateChain.clone(); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index cf98743379..4537ee75f8 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -193,6 +193,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc }; private volatile ClientAuth clientAuth = ClientAuth.NONE; + private volatile Certificate[] localCertificateChain; // Updated once a new handshake is started and so the SSLSession reused. private volatile long lastAccessed = -1; @@ -216,7 +217,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc private final OpenSslEngineMap engineMap; private final OpenSslApplicationProtocolNegotiator apn; private final OpenSslSession session; - private final Certificate[] localCerts; private final ByteBuffer[] singleSrcBuffer = new ByteBuffer[1]; private final ByteBuffer[] singleDstBuffer = new ByteBuffer[1]; private final boolean enableOcsp; @@ -323,8 +323,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc session = new DefaultOpenSslSession(context.sessionContext()); } engineMap = context.engineMap; - localCerts = context.keyCertChain; enableOcsp = context.enableOcsp; + // context.keyCertChain will only be non-null if we do not use the KeyManagerFactory. In this case + // localCertificateChain will be set in setKeyMaterial(...). + localCertificateChain = context.keyCertChain; + this.jdkCompatibilityMode = jdkCompatibilityMode; Lock readerLock = context.ctxLock.readLock(); readerLock.lock(); @@ -379,6 +382,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc leak = leakDetection ? leakDetector.track(this) : null; } + final void setKeyMaterial(OpenSslKeyMaterial keyMaterial) throws Exception { + SSL.setKeyMaterial(ssl, keyMaterial.certificateChainAddress(), keyMaterial.privateKeyAddress()); + localCertificateChain = keyMaterial.certificateChain(); + } + /** * Sets the OCSP response. */ @@ -1930,6 +1938,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc // thread. private X509Certificate[] x509PeerCerts; private Certificate[] peerCerts; + private Certificate[] localCerts; + private String protocol; private String cipher; private byte[] id; @@ -2070,6 +2080,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc id = SSL.getSessionId(ssl); cipher = toJavaCipherSuite(SSL.getCipherForSSL(ssl)); protocol = SSL.getVersion(ssl); + localCerts = localCertificateChain; initPeerCerts(); selectApplicationProtocol(); diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 1ae88ebdfb..32be76781a 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -95,6 +95,20 @@ public class OpenSslEngineTest extends SSLEngineTest { assertEquals("SSL error stack not correctly consumed", 0, SSL.getLastErrorNumber()); } + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactory() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactory(); + } + + @Override + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(); + } + @Override @Test public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth() throws Exception { diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index 6e8a231c82..c23800162a 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -60,6 +60,7 @@ import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.security.KeyStore; +import java.security.Principal; import java.security.Provider; import java.security.cert.Certificate; import java.security.cert.CertificateException; @@ -82,6 +83,7 @@ import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; @@ -97,6 +99,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.verify; @@ -2726,6 +2729,183 @@ public abstract class SSLEngineTest { } } + @Test + public void testSessionAfterHandshake() throws Exception { + testSessionAfterHandshake0(false, false); + } + + @Test + public void testSessionAfterHandshakeMutualAuth() throws Exception { + testSessionAfterHandshake0(false, true); + } + + @Test + public void testSessionAfterHandshakeKeyManagerFactory() throws Exception { + testSessionAfterHandshake0(true, false); + } + + @Test + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth() throws Exception { + testSessionAfterHandshake0(true, true); + } + + private void testSessionAfterHandshake0(boolean useKeyManagerFactory, boolean mutualAuth) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + KeyManagerFactory kmf = useKeyManagerFactory ? + SslContext.buildKeyManagerFactory( + new java.security.cert.X509Certificate[] { ssc.cert()}, ssc.key(), null, null) : null; + + SslContextBuilder clientContextBuilder = SslContextBuilder.forClient(); + if (mutualAuth) { + if (kmf != null) { + clientContextBuilder.keyManager(kmf); + } else { + clientContextBuilder.keyManager(ssc.key(), ssc.cert()); + } + } + clientSslCtx = clientContextBuilder + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + + SslContextBuilder serverContextBuilder = kmf != null ? + SslContextBuilder.forServer(kmf) : + SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()); + if (mutualAuth) { + serverContextBuilder.clientAuth(ClientAuth.REQUIRE); + } + serverSslCtx = serverContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(protocols()) + .ciphers(ciphers()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + handshake(clientEngine, serverEngine); + + SSLSession clientSession = clientEngine.getSession(); + SSLSession serverSession = serverEngine.getSession(); + + assertNull(clientSession.getPeerHost()); + assertNull(serverSession.getPeerHost()); + assertEquals(-1, clientSession.getPeerPort()); + assertEquals(-1, serverSession.getPeerPort()); + + assertTrue(clientSession.getCreationTime() > 0); + assertTrue(serverSession.getCreationTime() > 0); + + assertTrue(clientSession.getLastAccessedTime() > 0); + assertTrue(serverSession.getLastAccessedTime() > 0); + + assertEquals(protocolCipherCombo.protocol, clientSession.getProtocol()); + assertEquals(protocolCipherCombo.protocol, serverSession.getProtocol()); + + assertEquals(protocolCipherCombo.cipher, clientSession.getCipherSuite()); + assertEquals(protocolCipherCombo.cipher, serverSession.getCipherSuite()); + + assertNotNull(clientSession.getId()); + assertNotNull(serverSession.getId()); + + assertTrue(clientSession.getApplicationBufferSize() > 0); + assertTrue(serverSession.getApplicationBufferSize() > 0); + + assertTrue(clientSession.getPacketBufferSize() > 0); + assertTrue(serverSession.getPacketBufferSize() > 0); + + assertNotNull(clientSession.getSessionContext()); + assertNotNull(serverSession.getSessionContext()); + + Object value = new Object(); + + assertEquals(0, clientSession.getValueNames().length); + clientSession.putValue("test", value); + assertEquals("test", clientSession.getValueNames()[0]); + assertSame(value, clientSession.getValue("test")); + clientSession.removeValue("test"); + assertEquals(0, clientSession.getValueNames().length); + + assertEquals(0, serverSession.getValueNames().length); + serverSession.putValue("test", value); + assertEquals("test", serverSession.getValueNames()[0]); + assertSame(value, serverSession.getValue("test")); + serverSession.removeValue("test"); + assertEquals(0, serverSession.getValueNames().length); + + Certificate[] serverLocalCertificates = serverSession.getLocalCertificates(); + assertEquals(1, serverLocalCertificates.length); + assertArrayEquals(ssc.cert().getEncoded(), serverLocalCertificates[0].getEncoded()); + + Principal serverLocalPrincipal = serverSession.getLocalPrincipal(); + assertNotNull(serverLocalPrincipal); + + if (mutualAuth) { + Certificate[] clientLocalCertificates = clientSession.getLocalCertificates(); + assertEquals(1, clientLocalCertificates.length); + + Certificate[] serverPeerCertificates = serverSession.getPeerCertificates(); + assertEquals(1, serverPeerCertificates.length); + assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerCertificates[0].getEncoded()); + + X509Certificate[] serverPeerX509Certificates = serverSession.getPeerCertificateChain(); + assertEquals(1, serverPeerX509Certificates.length); + assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerX509Certificates[0].getEncoded()); + + Principal clientLocalPrincipial = clientSession.getLocalPrincipal(); + assertNotNull(clientLocalPrincipial); + + Principal serverPeerPrincipal = serverSession.getPeerPrincipal(); + assertEquals(clientLocalPrincipial, serverPeerPrincipal); + } else { + assertNull(clientSession.getLocalCertificates()); + assertNull(clientSession.getLocalPrincipal()); + + try { + serverSession.getPeerCertificates(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + + try { + serverSession.getPeerCertificateChain(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + + try { + serverSession.getPeerPrincipal(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + } + + Certificate[] clientPeerCertificates = clientSession.getPeerCertificates(); + assertEquals(1, clientPeerCertificates.length); + assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerCertificates[0].getEncoded()); + + X509Certificate[] clientPeerX509Certificates = clientSession.getPeerCertificateChain(); + assertEquals(1, clientPeerX509Certificates.length); + assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerX509Certificates[0].getEncoded()); + + Principal clientPeerPrincipal = clientSession.getPeerPrincipal(); + assertEquals(serverLocalPrincipal, clientPeerPrincipal); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + protected SSLEngine wrapEngine(SSLEngine engine) { return engine; }