From f8796c7eafcc7e55b4b3939ab608e3c7b6fb01a0 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 8 Jul 2021 16:19:22 +0200 Subject: [PATCH] Introduce OpenSslAsyncPrivateKeyMethod which allows to asynchronously sign / decrypt the private key (#11390) (#11460) Motivation: At the moment we only support signing / decrypting the private key in a synchronous fashion. This is quite limited as we may want to do a network call to do so on a remote system for example. Modifications: - Update to latest netty-tcnative which supports running tasks in an asynchronous fashion. - Add OpenSslAsyncPrivateKeyMethod interface - Adjust SslHandler to be able to handle asynchronous task execution - Adjust unit tests to test that asynchronous task execution works in all cases Result: Be able to asynchronous do key signing operations --- .../io/netty/handler/ssl/AsyncRunnable.java | 20 +++ .../ssl/OpenSslAsyncPrivateKeyMethod.java | 58 +++++++++ .../handler/ssl/OpenSslContextOption.java | 9 ++ .../ssl/ReferenceCountedOpenSslContext.java | 96 ++++++++++++++- .../ssl/ReferenceCountedOpenSslEngine.java | 55 +++++++-- .../java/io/netty/handler/ssl/SslHandler.java | 116 ++++++++++++++---- .../ssl/OpenSslPrivateKeyMethodTest.java | 111 +++++++++++++++-- .../ssl/ParameterizedSslHandlerTest.java | 23 +++- 8 files changed, 438 insertions(+), 50 deletions(-) create mode 100644 handler/src/main/java/io/netty/handler/ssl/AsyncRunnable.java create mode 100644 handler/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java diff --git a/handler/src/main/java/io/netty/handler/ssl/AsyncRunnable.java b/handler/src/main/java/io/netty/handler/ssl/AsyncRunnable.java new file mode 100644 index 0000000000..6fb620e227 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/AsyncRunnable.java @@ -0,0 +1,20 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +interface AsyncRunnable extends Runnable { + void run(Runnable completionCallback); +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java new file mode 100644 index 0000000000..27edaa629f --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSLPrivateKeyMethod; +import io.netty.util.concurrent.Future; + +import javax.net.ssl.SSLEngine; + +public interface OpenSslAsyncPrivateKeyMethod { + int SSL_SIGN_RSA_PKCS1_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA1; + int SSL_SIGN_RSA_PKCS1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256; + int SSL_SIGN_RSA_PKCS1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384; + int SSL_SIGN_RSA_PKCS1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512; + int SSL_SIGN_ECDSA_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SHA1; + int SSL_SIGN_ECDSA_SECP256R1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256; + int SSL_SIGN_ECDSA_SECP384R1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384; + int SSL_SIGN_ECDSA_SECP521R1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512; + int SSL_SIGN_RSA_PSS_RSAE_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256; + int SSL_SIGN_RSA_PSS_RSAE_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384; + int SSL_SIGN_RSA_PSS_RSAE_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512; + int SSL_SIGN_ED25519 = SSLPrivateKeyMethod.SSL_SIGN_ED25519; + int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_MD5_SHA1; + + /** + * Signs the input with the given key and notifies the returned {@link Future} with the signed bytes. + * + * @param engine the {@link SSLEngine} + * @param signatureAlgorithm the algorithm to use for signing + * @param input the digest itself + * @return the {@link Future} that will be notified with the signed data + * (must not be {@code null}) when the operation completes. + */ + Future sign(SSLEngine engine, int signatureAlgorithm, byte[] input); + + /** + * Decrypts the input with the given key and notifies the returned {@link Future} with the decrypted bytes. + * + * @param engine the {@link SSLEngine} + * @param input the input which should be decrypted + * @return the {@link Future} that will be notified with the decrypted data + * (must not be {@code null}) when the operation completes. + */ + Future decrypt(SSLEngine engine, byte[] input); +} diff --git a/handler/src/main/java/io/netty/handler/ssl/OpenSslContextOption.java b/handler/src/main/java/io/netty/handler/ssl/OpenSslContextOption.java index 5cfd7da3d0..cf7d11c09c 100644 --- a/handler/src/main/java/io/netty/handler/ssl/OpenSslContextOption.java +++ b/handler/src/main/java/io/netty/handler/ssl/OpenSslContextOption.java @@ -49,4 +49,13 @@ public final class OpenSslContextOption extends SslContextOption { */ public static final OpenSslContextOption PRIVATE_KEY_METHOD = new OpenSslContextOption("PRIVATE_KEY_METHOD"); + + /** + * Set the {@link OpenSslAsyncPrivateKeyMethod} to use. This allows to offload private-key operations + * if needed. + * + * This is currently only supported when {@code BoringSSL} is used. + */ + public static final OpenSslContextOption ASYNC_PRIVATE_KEY_METHOD = + new OpenSslContextOption("ASYNC_PRIVATE_KEY_METHOD"); } diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java index bb18f663c9..e9983b8003 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java @@ -18,7 +18,9 @@ package io.netty.handler.ssl; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.handler.ssl.util.LazyX509Certificate; +import io.netty.internal.tcnative.AsyncSSLPrivateKeyMethod; import io.netty.internal.tcnative.CertificateVerifier; +import io.netty.internal.tcnative.ResultCallback; import io.netty.internal.tcnative.SSL; import io.netty.internal.tcnative.SSLContext; import io.netty.internal.tcnative.SSLPrivateKeyMethod; @@ -27,6 +29,8 @@ import io.netty.util.ReferenceCounted; import io.netty.util.ResourceLeakDetector; import io.netty.util.ResourceLeakDetectorFactory; import io.netty.util.ResourceLeakTracker; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; import io.netty.util.internal.StringUtil; import io.netty.util.internal.SystemPropertyUtil; import io.netty.util.internal.UnstableApi; @@ -63,6 +67,7 @@ import javax.net.ssl.X509TrustManager; import static io.netty.handler.ssl.OpenSsl.DEFAULT_CIPHERS; import static io.netty.handler.ssl.OpenSsl.availableJavaCipherSuites; +import static io.netty.handler.ssl.OpenSsl.ensureAvailability; import static io.netty.util.internal.ObjectUtil.checkNonEmpty; import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static java.util.Objects.requireNonNull; @@ -217,6 +222,7 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen boolean tlsFalseStart = false; boolean useTasks = USE_TASKS; OpenSslPrivateKeyMethod privateKeyMethod = null; + OpenSslAsyncPrivateKeyMethod asyncPrivateKeyMethod = null; if (ctxOptions != null) { for (Map.Entry, Object> ctxOpt : ctxOptions) { @@ -228,12 +234,20 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen useTasks = (Boolean) ctxOpt.getValue(); } else if (option == OpenSslContextOption.PRIVATE_KEY_METHOD) { privateKeyMethod = (OpenSslPrivateKeyMethod) ctxOpt.getValue(); + } else if (option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD) { + asyncPrivateKeyMethod = (OpenSslAsyncPrivateKeyMethod) ctxOpt.getValue(); } else { logger.debug("Skipping unsupported " + SslContextOption.class.getSimpleName() + ": " + ctxOpt.getKey()); } } } + if (privateKeyMethod != null && asyncPrivateKeyMethod != null) { + throw new IllegalArgumentException("You can either only use " + + OpenSslAsyncPrivateKeyMethod.class.getSimpleName() + " or " + + OpenSslPrivateKeyMethod.class.getSimpleName()); + } + this.tlsFalseStart = tlsFalseStart; leak = leakDetection ? leakDetector.track(this) : null; @@ -362,6 +376,9 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen if (privateKeyMethod != null) { SSLContext.setPrivateKeyMethod(ctx, new PrivateKeyMethod(engineMap, privateKeyMethod)); } + if (asyncPrivateKeyMethod != null) { + SSLContext.setPrivateKeyMethod(ctx, new AsyncPrivateKeyMethod(engineMap, asyncPrivateKeyMethod)); + } success = true; } finally { if (!success) { @@ -966,12 +983,83 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen throw e; } } + } - private static byte[] verifyResult(byte[] result) throws SignatureException { - if (result == null) { - throw new SignatureException(); + private static final class AsyncPrivateKeyMethod implements AsyncSSLPrivateKeyMethod { + + private final OpenSslEngineMap engineMap; + private final OpenSslAsyncPrivateKeyMethod keyMethod; + + AsyncPrivateKeyMethod(OpenSslEngineMap engineMap, OpenSslAsyncPrivateKeyMethod keyMethod) { + this.engineMap = engineMap; + this.keyMethod = keyMethod; + } + + private ReferenceCountedOpenSslEngine retrieveEngine(long ssl) throws SSLException { + ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); + if (engine == null) { + throw new SSLException("Could not find a " + + StringUtil.simpleClassName(ReferenceCountedOpenSslEngine.class) + " for sslPointer " + ssl); + } + return engine; + } + + @Override + public void sign(long ssl, int signatureAlgorithm, byte[] bytes, ResultCallback resultCallback) { + try { + ReferenceCountedOpenSslEngine engine = retrieveEngine(ssl); + keyMethod.sign(engine, signatureAlgorithm, bytes) + .addListener(new ResultCallbackListener(engine, ssl, resultCallback)); + } catch (SSLException e) { + resultCallback.onError(ssl, e); + } + } + + @Override + public void decrypt(long ssl, byte[] bytes, ResultCallback resultCallback) { + try { + ReferenceCountedOpenSslEngine engine = retrieveEngine(ssl); + keyMethod.decrypt(engine, bytes) + .addListener(new ResultCallbackListener(engine, ssl, resultCallback)); + } catch (SSLException e) { + resultCallback.onError(ssl, e); + } + } + + private static final class ResultCallbackListener implements FutureListener { + private final ReferenceCountedOpenSslEngine engine; + private final long ssl; + private final ResultCallback resultCallback; + + ResultCallbackListener(ReferenceCountedOpenSslEngine engine, long ssl, + ResultCallback resultCallback) { + this.engine = engine; + this.ssl = ssl; + this.resultCallback = resultCallback; + } + + @Override + public void operationComplete(Future future) { + Throwable cause = future.cause(); + if (cause == null) { + try { + byte[] result = verifyResult(future.getNow()); + resultCallback.onSuccess(ssl, result); + return; + } catch (SignatureException e) { + cause = e; + engine.initHandshakeException(e); + } + } + resultCallback.onError(ssl, cause); } - return result; } } + + private static byte[] verifyResult(byte[] result) throws SignatureException { + if (result == null) { + throw new SignatureException(); + } + return result; + } } diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index 8d2e511238..48eff8542f 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.handler.ssl.util.LazyJavaxX509Certificate; import io.netty.handler.ssl.util.LazyX509Certificate; +import io.netty.internal.tcnative.AsyncTask; import io.netty.internal.tcnative.Buffer; import io.netty.internal.tcnative.SSL; import io.netty.util.AbstractReferenceCounted; @@ -1420,16 +1421,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - @Override - public final synchronized Runnable getDelegatedTask() { - if (isDestroyed()) { - return null; + private class TaskDecorator implements Runnable { + protected final R task; + TaskDecorator(R task) { + this.task = task; } - final Runnable task = SSL.getTask(ssl); - if (task == null) { - return null; - } - return () -> { + + @Override + public void run() { if (isDestroyed()) { // The engine was destroyed in the meantime, just return. return; @@ -1440,7 +1439,43 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc // The task was run, reset needTask to false so getHandshakeStatus() returns the correct value. needTask = false; } - }; + } + } + + private final class AsyncTaskDecorator extends TaskDecorator implements AsyncRunnable { + AsyncTaskDecorator(AsyncTask task) { + super(task); + } + + @Override + public void run(Runnable runnable) { + if (isDestroyed()) { + // The engine was destroyed in the meantime, just return. + runnable.run(); + return; + } + try { + task.runAsync(runnable); + } finally { + // The task was run, reset needTask to false so getHandshakeStatus() returns the correct value. + needTask = false; + } + } + } + + @Override + public final synchronized Runnable getDelegatedTask() { + if (isDestroyed()) { + return null; + } + final Runnable task = SSL.getTask(ssl); + if (task == null) { + return null; + } + if (task instanceof AsyncTask) { + return new AsyncTaskDecorator((AsyncTask) task); + } + return new TaskDecorator<>(task); } @Override diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 0ac50dd3f1..cf9824b28f 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -1388,7 +1388,8 @@ public class SslHandler extends ByteToMessageDecoder { } if (handshakeStatus == HandshakeStatus.NEED_TASK) { - if (!runDelegatedTasks(true)) { + boolean pending = runDelegatedTasks(true); + if (!pending) { // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will // resume once the task completes. // @@ -1494,16 +1495,6 @@ public class SslHandler extends ByteToMessageDecoder { return executor instanceof EventExecutor && ((EventExecutor) executor).inEventLoop(); } - private static void runAllDelegatedTasks(SSLEngine engine) { - for (;;) { - Runnable task = engine.getDelegatedTask(); - if (task == null) { - return; - } - task.run(); - } - } - /** * Will either run the delegated task directly calling {@link Runnable#run()} and return {@code true} or will * offload the delegated task using {@link Executor#execute(Runnable)} and return {@code false}. @@ -1515,16 +1506,50 @@ public class SslHandler extends ByteToMessageDecoder { if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE || inEventLoop(delegatedTaskExecutor)) { // We should run the task directly in the EventExecutor thread and not offload at all. As we are on the // EventLoop we can just run all tasks at once. - runAllDelegatedTasks(engine); - return true; + for (;;) { + Runnable task = engine.getDelegatedTask(); + if (task == null) { + return true; + } + setState(STATE_PROCESS_TASK); + if (task instanceof AsyncRunnable) { + // Let's set the task to processing task before we try to execute it. + boolean pending = false; + try { + AsyncRunnable asyncTask = (AsyncRunnable) task; + AsyncTaskCompletionHandler completionHandler = new AsyncTaskCompletionHandler(inUnwrap); + asyncTask.run(completionHandler); + pending = completionHandler.resumeLater(); + if (pending) { + return false; + } + } finally { + if (!pending) { + // The task has completed, lets clear the state. If it is not completed we will clear the + // state once it is. + clearState(STATE_PROCESS_TASK); + } + } + } else { + try { + task.run(); + } finally { + clearState(STATE_PROCESS_TASK); + } + } + } } else { executeDelegatedTask(inUnwrap); return false; } } + private SslTasksRunner getTaskRunner(boolean inUnwrap) { + return inUnwrap ? sslTaskRunnerForUnwrap : sslTaskRunner; + } + private void executeDelegatedTask(boolean inUnwrap) { - executeDelegatedTask(inUnwrap ? sslTaskRunnerForUnwrap : sslTaskRunner); + executeDelegatedTask(getTaskRunner(inUnwrap)); } private void executeDelegatedTask(SslTasksRunner task) { @@ -1537,12 +1562,44 @@ public class SslHandler extends ByteToMessageDecoder { } } + private final class AsyncTaskCompletionHandler implements Runnable { + private final boolean inUnwrap; + boolean didRun; + boolean resumeLater; + + AsyncTaskCompletionHandler(boolean inUnwrap) { + this.inUnwrap = inUnwrap; + } + + @Override + public void run() { + didRun = true; + if (resumeLater) { + getTaskRunner(inUnwrap).runComplete(); + } + } + + boolean resumeLater() { + if (!didRun) { + resumeLater = true; + return true; + } + return false; + } + } + /** * {@link Runnable} that will be scheduled on the {@code delegatedTaskExecutor} and will take care * of resume work on the {@link EventExecutor} once the task was executed. */ private final class SslTasksRunner implements Runnable { private final boolean inUnwrap; + private final Runnable runCompleteTask = new Runnable() { + @Override + public void run() { + runComplete(); + } + }; SslTasksRunner(boolean inUnwrap) { this.inUnwrap = inUnwrap; @@ -1684,6 +1741,21 @@ public class SslHandler extends ByteToMessageDecoder { } } + void runComplete() { + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + resumeOnEventExecutor(); + } else { + // Jump back on the EventExecutor. + executor.execute(new Runnable() { + @Override + public void run() { + resumeOnEventExecutor(); + } + }); + } + } + @Override public void run() { try { @@ -1692,13 +1764,12 @@ public class SslHandler extends ByteToMessageDecoder { // The task was processed in the meantime. Let's just return. return; } - task.run(); - - EventExecutor executor = ctx.executor(); - if (executor.inEventLoop()) { - resumeOnEventExecutor(); + if (task instanceof AsyncRunnable) { + AsyncRunnable asyncTask = (AsyncRunnable) task; + asyncTask.run(runCompleteTask); } else { - executor.execute(this::resumeOnEventExecutor); + task.run(); + runComplete(); } } catch (final Throwable cause) { handleException(cause); @@ -1706,12 +1777,13 @@ public class SslHandler extends ByteToMessageDecoder { } private void handleException(final Throwable cause) { - if (ctx.executor().inEventLoop()) { + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { clearState(STATE_PROCESS_TASK); safeExceptionCaught(cause); } else { try { - ctx.executor().execute(() -> { + executor.execute(() -> { clearState(STATE_PROCESS_TASK); safeExceptionCaught(cause); }); diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java index 2ca3372b7e..6e6f6d4681 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java @@ -35,9 +35,10 @@ import io.netty.channel.local.LocalServerChannel; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.Promise; import org.hamcrest.Matchers; -import org.junit.Assume; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.params.ParameterizedTest; @@ -60,15 +61,16 @@ import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; public class OpenSslPrivateKeyMethodTest { private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; @@ -78,8 +80,13 @@ public class OpenSslPrivateKeyMethodTest { static Collection parameters() { List dst = new ArrayList(); - dst.add(new Object[] { true }); - dst.add(new Object[] { false }); + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 2; b++) { + for (int c = 0; c < 2; c++) { + dst.add(new Object[] { a == 0, b == 0, c == 0 }); + } + } + } return dst; } @@ -87,7 +94,7 @@ public class OpenSslPrivateKeyMethodTest { public static void init() throws Exception { checkShouldUseKeyManagerFactory(); - Assume.assumeTrue(OpenSsl.isBoringSSL()); + assumeTrue(OpenSsl.isBoringSSL()); // Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API // implementations. assumeCipherAvailable(SslProvider.OPENSSL); @@ -120,7 +127,7 @@ public class OpenSslPrivateKeyMethodTest { } else { cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME); } - Assume.assumeTrue("Unsupported cipher: " + RFC_CIPHER_NAME, cipherSupported); + assumeTrue(cipherSupported, "Unsupported cipher: " + RFC_CIPHER_NAME); } private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) { @@ -158,20 +165,31 @@ public class OpenSslPrivateKeyMethodTest { private static Executor delegateExecutor(boolean delegate) { return delegate ? EXECUTOR : null; } + private SslContext buildServerContext(OpenSslAsyncPrivateKeyMethod method) throws Exception { + List ciphers = Collections.singletonList(RFC_CIPHER_NAME); + + final KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert()); + + return SslContextBuilder.forServer(kmf) + .sslProvider(SslProvider.OPENSSL) + .ciphers(ciphers) + // As this is not a TLSv1.3 cipher we should ensure we talk something else. + .protocols(SslProtocols.TLS_v1_2) + .option(OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD, method) + .build(); + } private static void assertThread(boolean delegate) { if (delegate && OpenSslContext.USE_TASKS) { assertEquals(DelegateThread.class, Thread.currentThread().getClass()); - } else { - assertNotEquals(DelegateThread.class, Thread.currentThread().getClass()); } } - @ParameterizedTest(name = "{index}: delegate = {0}") + @ParameterizedTest(name = "{index}: delegate = {0}, async = {1}, newThread={2}") @MethodSource("parameters") - public void testPrivateKeyMethod(final boolean delegate) throws Exception { + public void testPrivateKeyMethod(final boolean delegate, boolean async, boolean newThread) throws Exception { final AtomicBoolean signCalled = new AtomicBoolean(); - final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() { + OpenSslPrivateKeyMethod keyMethod = new OpenSslPrivateKeyMethod() { @Override public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception { signCalled.set(true); @@ -201,7 +219,10 @@ public class OpenSslPrivateKeyMethodTest { public byte[] decrypt(SSLEngine engine, byte[] input) { throw new UnsupportedOperationException(); } - }); + }; + + final SslContext sslServerContext = async ? buildServerContext( + new OpenSslPrivateKeyMethodAdapter(keyMethod, newThread)) : buildServerContext(keyMethod); final SslContext sslClientContext = buildClientContext(); try { @@ -386,4 +407,70 @@ public class OpenSslPrivateKeyMethodTest { super(target); } } + + private static final class OpenSslPrivateKeyMethodAdapter implements OpenSslAsyncPrivateKeyMethod { + private final OpenSslPrivateKeyMethod keyMethod; + private final boolean newThread; + + OpenSslPrivateKeyMethodAdapter(OpenSslPrivateKeyMethod keyMethod, boolean newThread) { + this.keyMethod = keyMethod; + this.newThread = newThread; + } + + @Override + public Future sign(final SSLEngine engine, final int signatureAlgorithm, final byte[] input) { + final Promise promise = ImmediateEventExecutor.INSTANCE.newPromise(); + try { + if (newThread) { + // Let's run these in an extra thread to ensure that this would also work if the promise is + // notified later. + new DelegateThread(new Runnable() { + @Override + public void run() { + try { + // Let's sleep for some time to ensure we would notify in an async fashion + Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500)); + promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }).start(); + } else { + promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input)); + } + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + + @Override + public Future decrypt(final SSLEngine engine, final byte[] input) { + final Promise promise = ImmediateEventExecutor.INSTANCE.newPromise(); + try { + if (newThread) { + // Let's run these in an extra thread to ensure that this would also work if the promise is + // notified later. + new DelegateThread(new Runnable() { + @Override + public void run() { + try { + // Let's sleep for some time to ensure we would notify in an async fashion + Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500)); + promise.setSuccess(keyMethod.decrypt(engine, input)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }).start(); + } else { + promise.setSuccess(keyMethod.decrypt(engine, input)); + } + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + } } diff --git a/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java index 3e89debd5c..883960ef2c 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java @@ -54,6 +54,10 @@ import javax.net.ssl.ManagerFactoryParameters; import javax.net.ssl.SSLException; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.KeyStore; @@ -669,8 +673,23 @@ public class ParameterizedSslHandlerTest { } private void appendError(Throwable cause) { - readQueue.append("failed to write '").append(toWrite).append("': ").append(cause); - doneLatch.countDown(); + readQueue.append("failed to write '").append(toWrite).append("': "); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + cause.printStackTrace(new PrintStream(out)); + readQueue.append(out.toString(CharsetUtil.US_ASCII.name())); + } catch (UnsupportedEncodingException ignore) { + // Let's just fallback to using toString(). + readQueue.append(cause); + } finally { + doneLatch.countDown(); + try { + out.close(); + } catch (IOException ignore) { + // ignore + } + } } } }