Revert "Introduce OpenSslAsyncPrivateKeyMethod which allows to asynchronously sign / decrypt the private key (#11390)"

This reverts commit 7c57c4be17.
This commit is contained in:
Norman Maurer 2021-07-07 08:26:27 +02:00
parent ac91eaaae8
commit 0b2e955aff
8 changed files with 52 additions and 367 deletions

View File

@ -1,20 +0,0 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
interface AsyncRunnable extends Runnable {
void run(Runnable completionCallback);
}

View File

@ -1,58 +0,0 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
import io.netty.internal.tcnative.SSLPrivateKeyMethod;
import io.netty.util.concurrent.Future;
import javax.net.ssl.SSLEngine;
public interface OpenSslAsyncPrivateKeyMethod {
int SSL_SIGN_RSA_PKCS1_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA1;
int SSL_SIGN_RSA_PKCS1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256;
int SSL_SIGN_RSA_PKCS1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384;
int SSL_SIGN_RSA_PKCS1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512;
int SSL_SIGN_ECDSA_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SHA1;
int SSL_SIGN_ECDSA_SECP256R1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256;
int SSL_SIGN_ECDSA_SECP384R1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384;
int SSL_SIGN_ECDSA_SECP521R1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512;
int SSL_SIGN_RSA_PSS_RSAE_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256;
int SSL_SIGN_RSA_PSS_RSAE_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384;
int SSL_SIGN_RSA_PSS_RSAE_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512;
int SSL_SIGN_ED25519 = SSLPrivateKeyMethod.SSL_SIGN_ED25519;
int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_MD5_SHA1;
/**
* Signs the input with the given key and notifies the returned {@link Future} with the signed bytes.
*
* @param engine the {@link SSLEngine}
* @param signatureAlgorithm the algorithm to use for signing
* @param input the digest itself
* @return the {@link Future} that will be notified with the signed data
* (must not be {@code null}) when the operation completes.
*/
Future<byte[]> sign(SSLEngine engine, int signatureAlgorithm, byte[] input);
/**
* Decrypts the input with the given key and notifies the returned {@link Future} with the decrypted bytes.
*
* @param engine the {@link SSLEngine}
* @param input the input which should be decrypted
* @return the {@link Future} that will be notified with the decrypted data
* (must not be {@code null}) when the operation completes.
*/
Future<byte[]> decrypt(SSLEngine engine, byte[] input);
}

View File

@ -49,13 +49,4 @@ public final class OpenSslContextOption<T> extends SslContextOption<T> {
*/
public static final OpenSslContextOption<OpenSslPrivateKeyMethod> PRIVATE_KEY_METHOD =
new OpenSslContextOption<OpenSslPrivateKeyMethod>("PRIVATE_KEY_METHOD");
/**
* Set the {@link OpenSslAsyncPrivateKeyMethod} to use. This allows to offload private-key operations
* if needed.
*
* This is currently only supported when {@code BoringSSL} is used.
*/
public static final OpenSslContextOption<OpenSslAsyncPrivateKeyMethod> ASYNC_PRIVATE_KEY_METHOD =
new OpenSslContextOption<OpenSslAsyncPrivateKeyMethod>("ASYNC_PRIVATE_KEY_METHOD");
}

View File

@ -18,9 +18,7 @@ package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.ssl.util.LazyX509Certificate;
import io.netty.internal.tcnative.AsyncSSLPrivateKeyMethod;
import io.netty.internal.tcnative.CertificateVerifier;
import io.netty.internal.tcnative.ResultCallback;
import io.netty.internal.tcnative.SSL;
import io.netty.internal.tcnative.SSLContext;
import io.netty.internal.tcnative.SSLPrivateKeyMethod;
@ -29,8 +27,6 @@ import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetectorFactory;
import io.netty.util.ResourceLeakTracker;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.SuppressJava6Requirement;
@ -68,7 +64,6 @@ import javax.net.ssl.X509TrustManager;
import static io.netty.handler.ssl.OpenSsl.DEFAULT_CIPHERS;
import static io.netty.handler.ssl.OpenSsl.availableJavaCipherSuites;
import static io.netty.handler.ssl.OpenSsl.ensureAvailability;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
@ -223,7 +218,6 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
boolean tlsFalseStart = false;
boolean useTasks = USE_TASKS;
OpenSslPrivateKeyMethod privateKeyMethod = null;
OpenSslAsyncPrivateKeyMethod asyncPrivateKeyMethod = null;
if (ctxOptions != null) {
for (Map.Entry<SslContextOption<?>, Object> ctxOpt : ctxOptions) {
@ -235,20 +229,12 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
useTasks = (Boolean) ctxOpt.getValue();
} else if (option == OpenSslContextOption.PRIVATE_KEY_METHOD) {
privateKeyMethod = (OpenSslPrivateKeyMethod) ctxOpt.getValue();
} else if (option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD) {
asyncPrivateKeyMethod = (OpenSslAsyncPrivateKeyMethod) ctxOpt.getValue();
} else {
logger.debug("Skipping unsupported " + SslContextOption.class.getSimpleName()
+ ": " + ctxOpt.getKey());
}
}
}
if (privateKeyMethod != null && asyncPrivateKeyMethod != null) {
throw new IllegalArgumentException("You can either only use "
+ OpenSslAsyncPrivateKeyMethod.class.getSimpleName() + " or "
+ OpenSslPrivateKeyMethod.class.getSimpleName());
}
this.tlsFalseStart = tlsFalseStart;
leak = leakDetection ? leakDetector.track(this) : null;
@ -377,9 +363,6 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
if (privateKeyMethod != null) {
SSLContext.setPrivateKeyMethod(ctx, new PrivateKeyMethod(engineMap, privateKeyMethod));
}
if (asyncPrivateKeyMethod != null) {
SSLContext.setPrivateKeyMethod(ctx, new AsyncPrivateKeyMethod(engineMap, asyncPrivateKeyMethod));
}
success = true;
} finally {
if (!success) {
@ -996,83 +979,12 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
throw e;
}
}
}
private static final class AsyncPrivateKeyMethod implements AsyncSSLPrivateKeyMethod {
private final OpenSslEngineMap engineMap;
private final OpenSslAsyncPrivateKeyMethod keyMethod;
AsyncPrivateKeyMethod(OpenSslEngineMap engineMap, OpenSslAsyncPrivateKeyMethod keyMethod) {
this.engineMap = engineMap;
this.keyMethod = keyMethod;
}
private ReferenceCountedOpenSslEngine retrieveEngine(long ssl) throws SSLException {
ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
if (engine == null) {
throw new SSLException("Could not find a " +
StringUtil.simpleClassName(ReferenceCountedOpenSslEngine.class) + " for sslPointer " + ssl);
private static byte[] verifyResult(byte[] result) throws SignatureException {
if (result == null) {
throw new SignatureException();
}
return engine;
return result;
}
@Override
public void sign(long ssl, int signatureAlgorithm, byte[] bytes, ResultCallback<byte[]> resultCallback) {
try {
ReferenceCountedOpenSslEngine engine = retrieveEngine(ssl);
keyMethod.sign(engine, signatureAlgorithm, bytes)
.addListener(new ResultCallbackListener(engine, ssl, resultCallback));
} catch (SSLException e) {
resultCallback.onError(ssl, e);
}
}
@Override
public void decrypt(long ssl, byte[] bytes, ResultCallback<byte[]> resultCallback) {
try {
ReferenceCountedOpenSslEngine engine = retrieveEngine(ssl);
keyMethod.decrypt(engine, bytes)
.addListener(new ResultCallbackListener(engine, ssl, resultCallback));
} catch (SSLException e) {
resultCallback.onError(ssl, e);
}
}
private static final class ResultCallbackListener implements FutureListener<byte[]> {
private final ReferenceCountedOpenSslEngine engine;
private final long ssl;
private final ResultCallback<byte[]> resultCallback;
ResultCallbackListener(ReferenceCountedOpenSslEngine engine, long ssl,
ResultCallback<byte[]> resultCallback) {
this.engine = engine;
this.ssl = ssl;
this.resultCallback = resultCallback;
}
@Override
public void operationComplete(Future<byte[]> future) {
Throwable cause = future.cause();
if (cause == null) {
try {
byte[] result = verifyResult(future.getNow());
resultCallback.onSuccess(ssl, result);
return;
} catch (SignatureException e) {
cause = e;
engine.initHandshakeException(e);
}
}
resultCallback.onError(ssl, cause);
}
}
}
private static byte[] verifyResult(byte[] result) throws SignatureException {
if (result == null) {
throw new SignatureException();
}
return result;
}
}

View File

@ -19,7 +19,6 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.ssl.util.LazyJavaxX509Certificate;
import io.netty.handler.ssl.util.LazyX509Certificate;
import io.netty.internal.tcnative.AsyncTask;
import io.netty.internal.tcnative.Buffer;
import io.netty.internal.tcnative.SSL;
import io.netty.util.AbstractReferenceCounted;
@ -1437,48 +1436,6 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
}
}
private class TaskDecorator<R extends Runnable> implements Runnable {
protected final R task;
TaskDecorator(R task) {
this.task = task;
}
@Override
public void run() {
if (isDestroyed()) {
// The engine was destroyed in the meantime, just return.
return;
}
try {
task.run();
} finally {
// The task was run, reset needTask to false so getHandshakeStatus() returns the correct value.
needTask = false;
}
}
}
private final class AsyncTaskDecorator extends TaskDecorator<AsyncTask> implements AsyncRunnable {
AsyncTaskDecorator(AsyncTask task) {
super(task);
}
@Override
public void run(Runnable runnable) {
if (isDestroyed()) {
// The engine was destroyed in the meantime, just return.
runnable.run();
return;
}
try {
task.runAsync(runnable);
} finally {
// The task was run, reset needTask to false so getHandshakeStatus() returns the correct value.
needTask = false;
}
}
}
@Override
public final synchronized Runnable getDelegatedTask() {
if (isDestroyed()) {
@ -1488,10 +1445,21 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
if (task == null) {
return null;
}
if (task instanceof AsyncTask) {
return new AsyncTaskDecorator((AsyncTask) task);
}
return new TaskDecorator<Runnable>(task);
return new Runnable() {
@Override
public void run() {
if (isDestroyed()) {
// The engine was destroyed in the meantime, just return.
return;
}
try {
task.run();
} finally {
// The task was run, reset needTask to false so getHandshakeStatus() returns the correct value.
needTask = false;
}
}
};
}
@Override

View File

@ -1506,18 +1506,13 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
return executor instanceof EventExecutor && ((EventExecutor) executor).inEventLoop();
}
private static void runAllDelegatedTasks(SSLEngine engine, Runnable completeTask) {
private static void runAllDelegatedTasks(SSLEngine engine) {
for (;;) {
Runnable task = engine.getDelegatedTask();
if (task == null) {
return;
}
if (task instanceof AsyncRunnable) {
((AsyncRunnable) task).run(completeTask);
} else {
task.run();
completeTask.run();
}
task.run();
}
}
@ -1529,8 +1524,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* more tasks to process.
*/
private boolean runDelegatedTasks(boolean inUnwrap) {
executeDelegatedTasks(inUnwrap);
return false;
if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE || inEventLoop(delegatedTaskExecutor)) {
// We should run the task directly in the EventExecutor thread and not offload at all.
runAllDelegatedTasks(engine);
return true;
} else {
executeDelegatedTasks(inUnwrap);
return false;
}
}
private void executeDelegatedTasks(boolean inUnwrap) {
@ -1692,20 +1693,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override
public void run() {
try {
runAllDelegatedTasks(engine, new Runnable() {
runAllDelegatedTasks(engine);
// All tasks were processed.
assert engine.getHandshakeStatus() != HandshakeStatus.NEED_TASK;
// Jump back on the EventExecutor.
ctx.executor().execute(new Runnable() {
@Override
public void run() {
// All tasks were processed.
// Jump back on the EventExecutor.
ctx.executor().execute(new Runnable() {
@Override
public void run() {
resumeOnEventExecutor();
}
});
resumeOnEventExecutor();
}
});
} catch (final Throwable cause) {
handleException(cause);
}

View File

@ -34,11 +34,9 @@ import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ThreadLocalRandom;
import org.hamcrest.Matchers;
import org.junit.Assume;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
@ -68,9 +66,9 @@ import java.util.concurrent.atomic.AtomicBoolean;
import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
public class OpenSslPrivateKeyMethodTest {
private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
@ -80,13 +78,8 @@ public class OpenSslPrivateKeyMethodTest {
static Collection<Object[]> parameters() {
List<Object[]> dst = new ArrayList<Object[]>();
for (int a = 0; a < 2; a++) {
for (int b = 0; b < 2; b++) {
for (int c = 0; c < 2; c++) {
dst.add(new Object[] { a == 0, b == 0, c == 0 });
}
}
}
dst.add(new Object[] { true });
dst.add(new Object[] { false });
return dst;
}
@ -94,7 +87,7 @@ public class OpenSslPrivateKeyMethodTest {
public static void init() throws Exception {
checkShouldUseKeyManagerFactory();
assumeTrue(OpenSsl.isBoringSSL());
Assume.assumeTrue(OpenSsl.isBoringSSL());
// Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API
// implementations.
assumeCipherAvailable(SslProvider.OPENSSL);
@ -132,7 +125,7 @@ public class OpenSslPrivateKeyMethodTest {
} else {
cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME);
}
assumeTrue(cipherSupported, "Unsupported cipher: " + RFC_CIPHER_NAME);
Assume.assumeTrue("Unsupported cipher: " + RFC_CIPHER_NAME, cipherSupported);
}
private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
@ -167,35 +160,23 @@ public class OpenSslPrivateKeyMethodTest {
.build();
}
private SslContext buildServerContext(OpenSslAsyncPrivateKeyMethod method) throws Exception {
List<String> ciphers = Collections.singletonList(RFC_CIPHER_NAME);
final KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert());
return SslContextBuilder.forServer(kmf)
.sslProvider(SslProvider.OPENSSL)
.ciphers(ciphers)
// As this is not a TLSv1.3 cipher we should ensure we talk something else.
.protocols(SslUtils.PROTOCOL_TLS_V1_2)
.option(OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD, method)
.build();
}
private Executor delegateExecutor(boolean delegate) {
private static Executor delegateExecutor(boolean delegate) {
return delegate ? EXECUTOR : null;
}
private static void assertThread(boolean delegate) {
if (delegate && OpenSslContext.USE_TASKS) {
assertEquals(DelegateThread.class, Thread.currentThread().getClass());
} else {
assertNotEquals(DelegateThread.class, Thread.currentThread().getClass());
}
}
@ParameterizedTest(name = "{index}: delegate = {0}, async = {1}, newThread={2}")
@ParameterizedTest(name = "{index}: delegate = {0}")
@MethodSource("parameters")
public void testPrivateKeyMethod(final boolean delegate, boolean async, boolean newThread) throws Exception {
public void testPrivateKeyMethod(final boolean delegate) throws Exception {
final AtomicBoolean signCalled = new AtomicBoolean();
OpenSslPrivateKeyMethod keyMethod = new OpenSslPrivateKeyMethod() {
final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() {
@Override
public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
signCalled.set(true);
@ -225,10 +206,7 @@ public class OpenSslPrivateKeyMethodTest {
public byte[] decrypt(SSLEngine engine, byte[] input) {
throw new UnsupportedOperationException();
}
};
final SslContext sslServerContext = async ? buildServerContext(
new OpenSslPrivateKeyMethodAdapter(keyMethod, newThread)) : buildServerContext(keyMethod);
});
final SslContext sslClientContext = buildClientContext();
try {
@ -413,70 +391,4 @@ public class OpenSslPrivateKeyMethodTest {
super(target);
}
}
private static final class OpenSslPrivateKeyMethodAdapter implements OpenSslAsyncPrivateKeyMethod {
private final OpenSslPrivateKeyMethod keyMethod;
private final boolean newThread;
OpenSslPrivateKeyMethodAdapter(OpenSslPrivateKeyMethod keyMethod, boolean newThread) {
this.keyMethod = keyMethod;
this.newThread = newThread;
}
@Override
public Future<byte[]> sign(final SSLEngine engine, final int signatureAlgorithm, final byte[] input) {
final Promise<byte[]> promise = ImmediateEventExecutor.INSTANCE.newPromise();
try {
if (newThread) {
// Let's run these in an extra thread to ensure that this would also work if the promise is
// notified later.
new DelegateThread(new Runnable() {
@Override
public void run() {
try {
// Let's sleep for some time to ensure we would notify in an async fashion
Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500));
promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input));
} catch (Throwable cause) {
promise.setFailure(cause);
}
}
}).start();
} else {
promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input));
}
} catch (Throwable cause) {
promise.setFailure(cause);
}
return promise;
}
@Override
public Future<byte[]> decrypt(final SSLEngine engine, final byte[] input) {
final Promise<byte[]> promise = ImmediateEventExecutor.INSTANCE.newPromise();
try {
if (newThread) {
// Let's run these in an extra thread to ensure that this would also work if the promise is
// notified later.
new DelegateThread(new Runnable() {
@Override
public void run() {
try {
// Let's sleep for some time to ensure we would notify in an async fashion
Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500));
promise.setSuccess(keyMethod.decrypt(engine, input));
} catch (Throwable cause) {
promise.setFailure(cause);
}
}
}).start();
} else {
promise.setSuccess(keyMethod.decrypt(engine, input));
}
} catch (Throwable cause) {
promise.setFailure(cause);
}
return promise;
}
}
}

View File

@ -55,10 +55,6 @@ import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.KeyStore;
@ -683,23 +679,8 @@ public class ParameterizedSslHandlerTest {
}
private void appendError(Throwable cause) {
readQueue.append("failed to write '").append(toWrite).append("': ");
ByteArrayOutputStream out = new ByteArrayOutputStream();
try {
cause.printStackTrace(new PrintStream(out));
readQueue.append(out.toString(CharsetUtil.US_ASCII.name()));
} catch (UnsupportedEncodingException ignore) {
// Let's just fallback to using toString().
readQueue.append(cause);
} finally {
doneLatch.countDown();
try {
out.close();
} catch (IOException ignore) {
// ignore
}
}
readQueue.append("failed to write '").append(toWrite).append("': ").append(cause);
doneLatch.countDown();
}
}
}