diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java index b22ff1d091..f931fcfdbb 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java @@ -120,7 +120,7 @@ class OpenSslKeyMaterialProvider { OpenSslKeyMaterial keyMaterial; if (key instanceof OpenSslPrivateKey) { - keyMaterial = ((OpenSslPrivateKey) key).toKeyMaterial(chain, certificates); + keyMaterial = ((OpenSslPrivateKey) key).newKeyMaterial(chain, certificates); } else { pkeyBio = toBIO(allocator, key); pkey = key == null ? 0 : SSL.parsePrivateKey(pkeyBio, password); diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java index 67639aae3c..4b3bbeb5d6 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java @@ -34,7 +34,7 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat @Override public String getAlgorithm() { - return "unkown"; + return "unknown"; } @Override @@ -48,10 +48,7 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat return null; } - /** - * Returns the pointer to the {@code EVP_PKEY}. - */ - long privateKeyAddress() { + private long privateKeyAddress() { if (refCnt() <= 0) { throw new IllegalReferenceCountException(); } @@ -110,21 +107,27 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat } /** - * Convert to a {@link OpenSslKeyMaterial}. Reference count of both is shared. + * Create a new {@link OpenSslKeyMaterial} which uses the private key that is held by {@link OpenSslPrivateKey}. + * + * When the material is created we increment the reference count of the enclosing {@link OpenSslPrivateKey} and + * decrement it again when the reference count of the {@link OpenSslKeyMaterial} reaches {@code 0}. */ - OpenSslKeyMaterial toKeyMaterial(long certificateChain, X509Certificate[] chain) { + OpenSslKeyMaterial newKeyMaterial(long certificateChain, X509Certificate[] chain) { return new OpenSslPrivateKeyMaterial(certificateChain, chain); } - private final class OpenSslPrivateKeyMaterial implements OpenSslKeyMaterial { + // Package-private for unit-test only + final class OpenSslPrivateKeyMaterial extends AbstractReferenceCounted implements OpenSslKeyMaterial { - private long certificateChain; + // Package-private for unit-test only + long certificateChain; private final X509Certificate[] x509CertificateChain; OpenSslPrivateKeyMaterial(long certificateChain, X509Certificate[] x509CertificateChain) { this.certificateChain = certificateChain; this.x509CertificateChain = x509CertificateChain == null ? EmptyArrays.EMPTY_X509_CERTIFICATES : x509CertificateChain; + OpenSslPrivateKey.this.retain(); } @Override @@ -142,18 +145,27 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat @Override public long privateKeyAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } return OpenSslPrivateKey.this.privateKeyAddress(); } + @Override + public OpenSslKeyMaterial touch(Object hint) { + OpenSslPrivateKey.this.touch(hint); + return this; + } + @Override public OpenSslKeyMaterial retain() { - OpenSslPrivateKey.this.retain(); + super.retain(); return this; } @Override public OpenSslKeyMaterial retain(int increment) { - OpenSslPrivateKey.this.retain(increment); + super.retain(increment); return this; } @@ -164,37 +176,14 @@ final class OpenSslPrivateKey extends AbstractReferenceCounted implements Privat } @Override - public OpenSslKeyMaterial touch(Object hint) { - OpenSslPrivateKey.this.touch(hint); - return this; - } - - @Override - public boolean release() { - if (OpenSslPrivateKey.this.release()) { - releaseChain(); - return true; - } - return false; - } - - @Override - public boolean release(int decrement) { - if (OpenSslPrivateKey.this.release(decrement)) { - releaseChain(); - return true; - } - return false; + protected void deallocate() { + releaseChain(); + OpenSslPrivateKey.this.release(); } private void releaseChain() { SSL.freeX509Chain(certificateChain); certificateChain = 0; } - - @Override - public int refCnt() { - return OpenSslPrivateKey.this.refCnt(); - } } } diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java index 5b793fe87b..800953a8ae 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java @@ -15,13 +15,21 @@ */ package io.netty.handler.ssl; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.internal.tcnative.SSL; +import io.netty.util.ReferenceCountUtil; import org.junit.BeforeClass; import org.junit.Test; import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.X509KeyManager; +import java.net.Socket; import java.security.KeyStore; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; import static org.junit.Assert.*; import static org.junit.Assume.assumeTrue; @@ -72,4 +80,94 @@ public class OpenSslKeyMaterialProviderTest { provider.destroy(); } + + /** + * Test class used by testChooseOpenSslPrivateKeyMaterial(). + */ + private static final class SingleKeyManager implements X509KeyManager { + private final String keyAlias; + private final PrivateKey pk; + private final X509Certificate[] certChain; + + SingleKeyManager(String keyAlias, PrivateKey pk, X509Certificate[] certChain) { + this.keyAlias = keyAlias; + this.pk = pk; + this.certChain = certChain; + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return new String[]{keyAlias}; + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return keyAlias; + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return new String[]{keyAlias}; + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return keyAlias; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return certChain; + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return pk; + } + } + + @Test + public void testChooseOpenSslPrivateKeyMaterial() throws Exception { + PrivateKey privateKey = SslContext.toPrivateKey( + getClass().getResourceAsStream("localhost_server.key"), + null); + assertNotNull(privateKey); + assertEquals("PKCS#8", privateKey.getFormat()); + final X509Certificate[] certChain = SslContext.toX509Certificates( + getClass().getResourceAsStream("localhost_server.pem")); + assertNotNull(certChain); + PemEncoded pemKey = null; + long pkeyBio = 0L; + OpenSslPrivateKey sslPrivateKey; + try { + pemKey = PemPrivateKey.toPEM(ByteBufAllocator.DEFAULT, true, privateKey); + pkeyBio = ReferenceCountedOpenSslContext.toBIO(ByteBufAllocator.DEFAULT, pemKey.retain()); + sslPrivateKey = new OpenSslPrivateKey(SSL.parsePrivateKey(pkeyBio, null)); + } finally { + ReferenceCountUtil.safeRelease(pemKey); + if (pkeyBio != 0L) { + SSL.freeBIO(pkeyBio); + } + } + final String keyAlias = "key"; + + OpenSslKeyMaterialProvider provider = new OpenSslKeyMaterialProvider( + new SingleKeyManager(keyAlias, sslPrivateKey, certChain), + null); + OpenSslKeyMaterial material = provider.chooseKeyMaterial(ByteBufAllocator.DEFAULT, keyAlias); + assertNotNull(material); + assertEquals(2, sslPrivateKey.refCnt()); + assertEquals(1, material.refCnt()); + assertTrue(material.release()); + assertEquals(1, sslPrivateKey.refCnt()); + // Can get material multiple times from the same key + material = provider.chooseKeyMaterial(ByteBufAllocator.DEFAULT, keyAlias); + assertNotNull(material); + assertEquals(2, sslPrivateKey.refCnt()); + assertTrue(material.release()); + assertTrue(sslPrivateKey.release()); + assertEquals(0, sslPrivateKey.refCnt()); + assertEquals(0, material.refCnt()); + assertEquals(0, ((OpenSslPrivateKey.OpenSslPrivateKeyMaterial) material).certificateChain); + } }