Introduce OpenSslAsyncPrivateKeyMethod which allows to asynchronously sign / decrypt the private key (#11390) (#11460)
Motivation: At the moment we only support signing / decrypting the private key in a synchronous fashion. This is quite limited as we may want to do a network call to do so on a remote system for example. Modifications: - Update to latest netty-tcnative which supports running tasks in an asynchronous fashion. - Add OpenSslAsyncPrivateKeyMethod interface - Adjust SslHandler to be able to handle asynchronous task execution - Adjust unit tests to test that asynchronous task execution works in all cases Result: Be able to asynchronous do key signing operations
This commit is contained in:
parent
fef761d03e
commit
40fb6026ef
@ -0,0 +1,20 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2021 The Netty Project
|
||||||
|
*
|
||||||
|
* The Netty Project licenses this file to you under the Apache License,
|
||||||
|
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at:
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package io.netty.handler.ssl;
|
||||||
|
|
||||||
|
interface AsyncRunnable extends Runnable {
|
||||||
|
void run(Runnable completionCallback);
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2021 The Netty Project
|
||||||
|
*
|
||||||
|
* The Netty Project licenses this file to you under the Apache License,
|
||||||
|
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at:
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package io.netty.handler.ssl;
|
||||||
|
|
||||||
|
import io.netty.internal.tcnative.SSLPrivateKeyMethod;
|
||||||
|
import io.netty.util.concurrent.Future;
|
||||||
|
|
||||||
|
import javax.net.ssl.SSLEngine;
|
||||||
|
|
||||||
|
public interface OpenSslAsyncPrivateKeyMethod {
|
||||||
|
int SSL_SIGN_RSA_PKCS1_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA1;
|
||||||
|
int SSL_SIGN_RSA_PKCS1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256;
|
||||||
|
int SSL_SIGN_RSA_PKCS1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384;
|
||||||
|
int SSL_SIGN_RSA_PKCS1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512;
|
||||||
|
int SSL_SIGN_ECDSA_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SHA1;
|
||||||
|
int SSL_SIGN_ECDSA_SECP256R1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256;
|
||||||
|
int SSL_SIGN_ECDSA_SECP384R1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384;
|
||||||
|
int SSL_SIGN_ECDSA_SECP521R1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512;
|
||||||
|
int SSL_SIGN_RSA_PSS_RSAE_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256;
|
||||||
|
int SSL_SIGN_RSA_PSS_RSAE_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384;
|
||||||
|
int SSL_SIGN_RSA_PSS_RSAE_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512;
|
||||||
|
int SSL_SIGN_ED25519 = SSLPrivateKeyMethod.SSL_SIGN_ED25519;
|
||||||
|
int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_MD5_SHA1;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Signs the input with the given key and notifies the returned {@link Future} with the signed bytes.
|
||||||
|
*
|
||||||
|
* @param engine the {@link SSLEngine}
|
||||||
|
* @param signatureAlgorithm the algorithm to use for signing
|
||||||
|
* @param input the digest itself
|
||||||
|
* @return the {@link Future} that will be notified with the signed data
|
||||||
|
* (must not be {@code null}) when the operation completes.
|
||||||
|
*/
|
||||||
|
Future<byte[]> sign(SSLEngine engine, int signatureAlgorithm, byte[] input);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decrypts the input with the given key and notifies the returned {@link Future} with the decrypted bytes.
|
||||||
|
*
|
||||||
|
* @param engine the {@link SSLEngine}
|
||||||
|
* @param input the input which should be decrypted
|
||||||
|
* @return the {@link Future} that will be notified with the decrypted data
|
||||||
|
* (must not be {@code null}) when the operation completes.
|
||||||
|
*/
|
||||||
|
Future<byte[]> decrypt(SSLEngine engine, byte[] input);
|
||||||
|
}
|
@ -49,4 +49,13 @@ public final class OpenSslContextOption<T> extends SslContextOption<T> {
|
|||||||
*/
|
*/
|
||||||
public static final OpenSslContextOption<OpenSslPrivateKeyMethod> PRIVATE_KEY_METHOD =
|
public static final OpenSslContextOption<OpenSslPrivateKeyMethod> PRIVATE_KEY_METHOD =
|
||||||
new OpenSslContextOption<OpenSslPrivateKeyMethod>("PRIVATE_KEY_METHOD");
|
new OpenSslContextOption<OpenSslPrivateKeyMethod>("PRIVATE_KEY_METHOD");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the {@link OpenSslAsyncPrivateKeyMethod} to use. This allows to offload private-key operations
|
||||||
|
* if needed.
|
||||||
|
*
|
||||||
|
* This is currently only supported when {@code BoringSSL} is used.
|
||||||
|
*/
|
||||||
|
public static final OpenSslContextOption<OpenSslAsyncPrivateKeyMethod> ASYNC_PRIVATE_KEY_METHOD =
|
||||||
|
new OpenSslContextOption<OpenSslAsyncPrivateKeyMethod>("ASYNC_PRIVATE_KEY_METHOD");
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,9 @@ package io.netty.handler.ssl;
|
|||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.handler.ssl.util.LazyX509Certificate;
|
import io.netty.handler.ssl.util.LazyX509Certificate;
|
||||||
|
import io.netty.internal.tcnative.AsyncSSLPrivateKeyMethod;
|
||||||
import io.netty.internal.tcnative.CertificateVerifier;
|
import io.netty.internal.tcnative.CertificateVerifier;
|
||||||
|
import io.netty.internal.tcnative.ResultCallback;
|
||||||
import io.netty.internal.tcnative.SSL;
|
import io.netty.internal.tcnative.SSL;
|
||||||
import io.netty.internal.tcnative.SSLContext;
|
import io.netty.internal.tcnative.SSLContext;
|
||||||
import io.netty.internal.tcnative.SSLPrivateKeyMethod;
|
import io.netty.internal.tcnative.SSLPrivateKeyMethod;
|
||||||
@ -27,6 +29,8 @@ import io.netty.util.ReferenceCounted;
|
|||||||
import io.netty.util.ResourceLeakDetector;
|
import io.netty.util.ResourceLeakDetector;
|
||||||
import io.netty.util.ResourceLeakDetectorFactory;
|
import io.netty.util.ResourceLeakDetectorFactory;
|
||||||
import io.netty.util.ResourceLeakTracker;
|
import io.netty.util.ResourceLeakTracker;
|
||||||
|
import io.netty.util.concurrent.Future;
|
||||||
|
import io.netty.util.concurrent.FutureListener;
|
||||||
import io.netty.util.internal.PlatformDependent;
|
import io.netty.util.internal.PlatformDependent;
|
||||||
import io.netty.util.internal.StringUtil;
|
import io.netty.util.internal.StringUtil;
|
||||||
import io.netty.util.internal.SuppressJava6Requirement;
|
import io.netty.util.internal.SuppressJava6Requirement;
|
||||||
@ -64,6 +68,7 @@ import javax.net.ssl.X509TrustManager;
|
|||||||
|
|
||||||
import static io.netty.handler.ssl.OpenSsl.DEFAULT_CIPHERS;
|
import static io.netty.handler.ssl.OpenSsl.DEFAULT_CIPHERS;
|
||||||
import static io.netty.handler.ssl.OpenSsl.availableJavaCipherSuites;
|
import static io.netty.handler.ssl.OpenSsl.availableJavaCipherSuites;
|
||||||
|
import static io.netty.handler.ssl.OpenSsl.ensureAvailability;
|
||||||
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
||||||
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
|
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
|
||||||
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
||||||
@ -218,6 +223,7 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
|
|||||||
boolean tlsFalseStart = false;
|
boolean tlsFalseStart = false;
|
||||||
boolean useTasks = USE_TASKS;
|
boolean useTasks = USE_TASKS;
|
||||||
OpenSslPrivateKeyMethod privateKeyMethod = null;
|
OpenSslPrivateKeyMethod privateKeyMethod = null;
|
||||||
|
OpenSslAsyncPrivateKeyMethod asyncPrivateKeyMethod = null;
|
||||||
|
|
||||||
if (ctxOptions != null) {
|
if (ctxOptions != null) {
|
||||||
for (Map.Entry<SslContextOption<?>, Object> ctxOpt : ctxOptions) {
|
for (Map.Entry<SslContextOption<?>, Object> ctxOpt : ctxOptions) {
|
||||||
@ -229,12 +235,20 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
|
|||||||
useTasks = (Boolean) ctxOpt.getValue();
|
useTasks = (Boolean) ctxOpt.getValue();
|
||||||
} else if (option == OpenSslContextOption.PRIVATE_KEY_METHOD) {
|
} else if (option == OpenSslContextOption.PRIVATE_KEY_METHOD) {
|
||||||
privateKeyMethod = (OpenSslPrivateKeyMethod) ctxOpt.getValue();
|
privateKeyMethod = (OpenSslPrivateKeyMethod) ctxOpt.getValue();
|
||||||
|
} else if (option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD) {
|
||||||
|
asyncPrivateKeyMethod = (OpenSslAsyncPrivateKeyMethod) ctxOpt.getValue();
|
||||||
} else {
|
} else {
|
||||||
logger.debug("Skipping unsupported " + SslContextOption.class.getSimpleName()
|
logger.debug("Skipping unsupported " + SslContextOption.class.getSimpleName()
|
||||||
+ ": " + ctxOpt.getKey());
|
+ ": " + ctxOpt.getKey());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (privateKeyMethod != null && asyncPrivateKeyMethod != null) {
|
||||||
|
throw new IllegalArgumentException("You can either only use "
|
||||||
|
+ OpenSslAsyncPrivateKeyMethod.class.getSimpleName() + " or "
|
||||||
|
+ OpenSslPrivateKeyMethod.class.getSimpleName());
|
||||||
|
}
|
||||||
|
|
||||||
this.tlsFalseStart = tlsFalseStart;
|
this.tlsFalseStart = tlsFalseStart;
|
||||||
|
|
||||||
leak = leakDetection ? leakDetector.track(this) : null;
|
leak = leakDetection ? leakDetector.track(this) : null;
|
||||||
@ -363,6 +377,9 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
|
|||||||
if (privateKeyMethod != null) {
|
if (privateKeyMethod != null) {
|
||||||
SSLContext.setPrivateKeyMethod(ctx, new PrivateKeyMethod(engineMap, privateKeyMethod));
|
SSLContext.setPrivateKeyMethod(ctx, new PrivateKeyMethod(engineMap, privateKeyMethod));
|
||||||
}
|
}
|
||||||
|
if (asyncPrivateKeyMethod != null) {
|
||||||
|
SSLContext.setPrivateKeyMethod(ctx, new AsyncPrivateKeyMethod(engineMap, asyncPrivateKeyMethod));
|
||||||
|
}
|
||||||
success = true;
|
success = true;
|
||||||
} finally {
|
} finally {
|
||||||
if (!success) {
|
if (!success) {
|
||||||
@ -979,6 +996,78 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
|
|||||||
throw e;
|
throw e;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class AsyncPrivateKeyMethod implements AsyncSSLPrivateKeyMethod {
|
||||||
|
|
||||||
|
private final OpenSslEngineMap engineMap;
|
||||||
|
private final OpenSslAsyncPrivateKeyMethod keyMethod;
|
||||||
|
|
||||||
|
AsyncPrivateKeyMethod(OpenSslEngineMap engineMap, OpenSslAsyncPrivateKeyMethod keyMethod) {
|
||||||
|
this.engineMap = engineMap;
|
||||||
|
this.keyMethod = keyMethod;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ReferenceCountedOpenSslEngine retrieveEngine(long ssl) throws SSLException {
|
||||||
|
ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
|
||||||
|
if (engine == null) {
|
||||||
|
throw new SSLException("Could not find a " +
|
||||||
|
StringUtil.simpleClassName(ReferenceCountedOpenSslEngine.class) + " for sslPointer " + ssl);
|
||||||
|
}
|
||||||
|
return engine;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void sign(long ssl, int signatureAlgorithm, byte[] bytes, ResultCallback<byte[]> resultCallback) {
|
||||||
|
try {
|
||||||
|
ReferenceCountedOpenSslEngine engine = retrieveEngine(ssl);
|
||||||
|
keyMethod.sign(engine, signatureAlgorithm, bytes)
|
||||||
|
.addListener(new ResultCallbackListener(engine, ssl, resultCallback));
|
||||||
|
} catch (SSLException e) {
|
||||||
|
resultCallback.onError(ssl, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void decrypt(long ssl, byte[] bytes, ResultCallback<byte[]> resultCallback) {
|
||||||
|
try {
|
||||||
|
ReferenceCountedOpenSslEngine engine = retrieveEngine(ssl);
|
||||||
|
keyMethod.decrypt(engine, bytes)
|
||||||
|
.addListener(new ResultCallbackListener(engine, ssl, resultCallback));
|
||||||
|
} catch (SSLException e) {
|
||||||
|
resultCallback.onError(ssl, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class ResultCallbackListener implements FutureListener<byte[]> {
|
||||||
|
private final ReferenceCountedOpenSslEngine engine;
|
||||||
|
private final long ssl;
|
||||||
|
private final ResultCallback<byte[]> resultCallback;
|
||||||
|
|
||||||
|
ResultCallbackListener(ReferenceCountedOpenSslEngine engine, long ssl,
|
||||||
|
ResultCallback<byte[]> resultCallback) {
|
||||||
|
this.engine = engine;
|
||||||
|
this.ssl = ssl;
|
||||||
|
this.resultCallback = resultCallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void operationComplete(Future<byte[]> future) {
|
||||||
|
Throwable cause = future.cause();
|
||||||
|
if (cause == null) {
|
||||||
|
try {
|
||||||
|
byte[] result = verifyResult(future.getNow());
|
||||||
|
resultCallback.onSuccess(ssl, result);
|
||||||
|
return;
|
||||||
|
} catch (SignatureException e) {
|
||||||
|
cause = e;
|
||||||
|
engine.initHandshakeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resultCallback.onError(ssl, cause);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static byte[] verifyResult(byte[] result) throws SignatureException {
|
private static byte[] verifyResult(byte[] result) throws SignatureException {
|
||||||
if (result == null) {
|
if (result == null) {
|
||||||
@ -986,5 +1075,4 @@ public abstract class ReferenceCountedOpenSslContext extends SslContext implemen
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf;
|
|||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.handler.ssl.util.LazyJavaxX509Certificate;
|
import io.netty.handler.ssl.util.LazyJavaxX509Certificate;
|
||||||
import io.netty.handler.ssl.util.LazyX509Certificate;
|
import io.netty.handler.ssl.util.LazyX509Certificate;
|
||||||
|
import io.netty.internal.tcnative.AsyncTask;
|
||||||
import io.netty.internal.tcnative.Buffer;
|
import io.netty.internal.tcnative.Buffer;
|
||||||
import io.netty.internal.tcnative.SSL;
|
import io.netty.internal.tcnative.SSL;
|
||||||
import io.netty.util.AbstractReferenceCounted;
|
import io.netty.util.AbstractReferenceCounted;
|
||||||
@ -1429,16 +1430,12 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private class TaskDecorator<R extends Runnable> implements Runnable {
|
||||||
public final synchronized Runnable getDelegatedTask() {
|
protected final R task;
|
||||||
if (isDestroyed()) {
|
TaskDecorator(R task) {
|
||||||
return null;
|
this.task = task;
|
||||||
}
|
}
|
||||||
final Runnable task = SSL.getTask(ssl);
|
|
||||||
if (task == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return new Runnable() {
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
if (isDestroyed()) {
|
if (isDestroyed()) {
|
||||||
@ -1452,7 +1449,42 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
|
|||||||
needTask = false;
|
needTask = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
|
private final class AsyncTaskDecorator extends TaskDecorator<AsyncTask> implements AsyncRunnable {
|
||||||
|
AsyncTaskDecorator(AsyncTask task) {
|
||||||
|
super(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(Runnable runnable) {
|
||||||
|
if (isDestroyed()) {
|
||||||
|
// The engine was destroyed in the meantime, just return.
|
||||||
|
runnable.run();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
task.runAsync(runnable);
|
||||||
|
} finally {
|
||||||
|
// The task was run, reset needTask to false so getHandshakeStatus() returns the correct value.
|
||||||
|
needTask = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final synchronized Runnable getDelegatedTask() {
|
||||||
|
if (isDestroyed()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
final Runnable task = SSL.getTask(ssl);
|
||||||
|
if (task == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (task instanceof AsyncTask) {
|
||||||
|
return new AsyncTaskDecorator((AsyncTask) task);
|
||||||
|
}
|
||||||
|
return new TaskDecorator<Runnable>(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -1393,7 +1393,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (handshakeStatus == HandshakeStatus.NEED_TASK) {
|
if (handshakeStatus == HandshakeStatus.NEED_TASK) {
|
||||||
if (!runDelegatedTasks(true)) {
|
boolean pending = runDelegatedTasks(true);
|
||||||
|
if (!pending) {
|
||||||
// We scheduled a task on the delegatingTaskExecutor, so stop processing as we will
|
// We scheduled a task on the delegatingTaskExecutor, so stop processing as we will
|
||||||
// resume once the task completes.
|
// resume once the task completes.
|
||||||
//
|
//
|
||||||
@ -1509,16 +1510,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
return executor instanceof EventExecutor && ((EventExecutor) executor).inEventLoop();
|
return executor instanceof EventExecutor && ((EventExecutor) executor).inEventLoop();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void runAllDelegatedTasks(SSLEngine engine) {
|
|
||||||
for (;;) {
|
|
||||||
Runnable task = engine.getDelegatedTask();
|
|
||||||
if (task == null) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
task.run();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Will either run the delegated task directly calling {@link Runnable#run()} and return {@code true} or will
|
* Will either run the delegated task directly calling {@link Runnable#run()} and return {@code true} or will
|
||||||
* offload the delegated task using {@link Executor#execute(Runnable)} and return {@code false}.
|
* offload the delegated task using {@link Executor#execute(Runnable)} and return {@code false}.
|
||||||
@ -1530,16 +1521,50 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE || inEventLoop(delegatedTaskExecutor)) {
|
if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE || inEventLoop(delegatedTaskExecutor)) {
|
||||||
// We should run the task directly in the EventExecutor thread and not offload at all. As we are on the
|
// We should run the task directly in the EventExecutor thread and not offload at all. As we are on the
|
||||||
// EventLoop we can just run all tasks at once.
|
// EventLoop we can just run all tasks at once.
|
||||||
runAllDelegatedTasks(engine);
|
for (;;) {
|
||||||
|
Runnable task = engine.getDelegatedTask();
|
||||||
|
if (task == null) {
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
setState(STATE_PROCESS_TASK);
|
||||||
|
if (task instanceof AsyncRunnable) {
|
||||||
|
// Let's set the task to processing task before we try to execute it.
|
||||||
|
boolean pending = false;
|
||||||
|
try {
|
||||||
|
AsyncRunnable asyncTask = (AsyncRunnable) task;
|
||||||
|
AsyncTaskCompletionHandler completionHandler = new AsyncTaskCompletionHandler(inUnwrap);
|
||||||
|
asyncTask.run(completionHandler);
|
||||||
|
pending = completionHandler.resumeLater();
|
||||||
|
if (pending) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (!pending) {
|
||||||
|
// The task has completed, lets clear the state. If it is not completed we will clear the
|
||||||
|
// state once it is.
|
||||||
|
clearState(STATE_PROCESS_TASK);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
task.run();
|
||||||
|
} finally {
|
||||||
|
clearState(STATE_PROCESS_TASK);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
executeDelegatedTask(inUnwrap);
|
executeDelegatedTask(inUnwrap);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private SslTasksRunner getTaskRunner(boolean inUnwrap) {
|
||||||
|
return inUnwrap ? sslTaskRunnerForUnwrap : sslTaskRunner;
|
||||||
|
}
|
||||||
|
|
||||||
private void executeDelegatedTask(boolean inUnwrap) {
|
private void executeDelegatedTask(boolean inUnwrap) {
|
||||||
executeDelegatedTask(inUnwrap ? sslTaskRunnerForUnwrap : sslTaskRunner);
|
executeDelegatedTask(getTaskRunner(inUnwrap));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeDelegatedTask(SslTasksRunner task) {
|
private void executeDelegatedTask(SslTasksRunner task) {
|
||||||
@ -1552,12 +1577,44 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private final class AsyncTaskCompletionHandler implements Runnable {
|
||||||
|
private final boolean inUnwrap;
|
||||||
|
boolean didRun;
|
||||||
|
boolean resumeLater;
|
||||||
|
|
||||||
|
AsyncTaskCompletionHandler(boolean inUnwrap) {
|
||||||
|
this.inUnwrap = inUnwrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
didRun = true;
|
||||||
|
if (resumeLater) {
|
||||||
|
getTaskRunner(inUnwrap).runComplete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean resumeLater() {
|
||||||
|
if (!didRun) {
|
||||||
|
resumeLater = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link Runnable} that will be scheduled on the {@code delegatedTaskExecutor} and will take care
|
* {@link Runnable} that will be scheduled on the {@code delegatedTaskExecutor} and will take care
|
||||||
* of resume work on the {@link EventExecutor} once the task was executed.
|
* of resume work on the {@link EventExecutor} once the task was executed.
|
||||||
*/
|
*/
|
||||||
private final class SslTasksRunner implements Runnable {
|
private final class SslTasksRunner implements Runnable {
|
||||||
private final boolean inUnwrap;
|
private final boolean inUnwrap;
|
||||||
|
private final Runnable runCompleteTask = new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
runComplete();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
SslTasksRunner(boolean inUnwrap) {
|
SslTasksRunner(boolean inUnwrap) {
|
||||||
this.inUnwrap = inUnwrap;
|
this.inUnwrap = inUnwrap;
|
||||||
@ -1699,16 +1756,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
void runComplete() {
|
||||||
public void run() {
|
|
||||||
try {
|
|
||||||
Runnable task = engine.getDelegatedTask();
|
|
||||||
if (task == null) {
|
|
||||||
// The task was processed in the meantime. Let's just return.
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
task.run();
|
|
||||||
|
|
||||||
EventExecutor executor = ctx.executor();
|
EventExecutor executor = ctx.executor();
|
||||||
if (executor.inEventLoop()) {
|
if (executor.inEventLoop()) {
|
||||||
resumeOnEventExecutor();
|
resumeOnEventExecutor();
|
||||||
@ -1721,18 +1769,36 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
try {
|
||||||
|
Runnable task = engine.getDelegatedTask();
|
||||||
|
if (task == null) {
|
||||||
|
// The task was processed in the meantime. Let's just return.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (task instanceof AsyncRunnable) {
|
||||||
|
AsyncRunnable asyncTask = (AsyncRunnable) task;
|
||||||
|
asyncTask.run(runCompleteTask);
|
||||||
|
} else {
|
||||||
|
task.run();
|
||||||
|
runComplete();
|
||||||
|
}
|
||||||
} catch (final Throwable cause) {
|
} catch (final Throwable cause) {
|
||||||
handleException(cause);
|
handleException(cause);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleException(final Throwable cause) {
|
private void handleException(final Throwable cause) {
|
||||||
if (ctx.executor().inEventLoop()) {
|
EventExecutor executor = ctx.executor();
|
||||||
|
if (executor.inEventLoop()) {
|
||||||
clearState(STATE_PROCESS_TASK);
|
clearState(STATE_PROCESS_TASK);
|
||||||
safeExceptionCaught(cause);
|
safeExceptionCaught(cause);
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
ctx.executor().execute(new Runnable() {
|
executor.execute(new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
clearState(STATE_PROCESS_TASK);
|
clearState(STATE_PROCESS_TASK);
|
||||||
|
@ -34,9 +34,11 @@ import io.netty.channel.local.LocalServerChannel;
|
|||||||
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
|
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
|
||||||
import io.netty.handler.ssl.util.SelfSignedCertificate;
|
import io.netty.handler.ssl.util.SelfSignedCertificate;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import io.netty.util.concurrent.Future;
|
||||||
|
import io.netty.util.concurrent.ImmediateEventExecutor;
|
||||||
import io.netty.util.concurrent.Promise;
|
import io.netty.util.concurrent.Promise;
|
||||||
|
import io.netty.util.internal.ThreadLocalRandom;
|
||||||
import org.hamcrest.Matchers;
|
import org.hamcrest.Matchers;
|
||||||
import org.junit.Assume;
|
|
||||||
import org.junit.jupiter.api.AfterAll;
|
import org.junit.jupiter.api.AfterAll;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
@ -66,9 +68,9 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
|||||||
import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory;
|
import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory;
|
||||||
import static org.hamcrest.MatcherAssert.assertThat;
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
|
|
||||||
public class OpenSslPrivateKeyMethodTest {
|
public class OpenSslPrivateKeyMethodTest {
|
||||||
private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
|
private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
|
||||||
@ -78,8 +80,13 @@ public class OpenSslPrivateKeyMethodTest {
|
|||||||
|
|
||||||
static Collection<Object[]> parameters() {
|
static Collection<Object[]> parameters() {
|
||||||
List<Object[]> dst = new ArrayList<Object[]>();
|
List<Object[]> dst = new ArrayList<Object[]>();
|
||||||
dst.add(new Object[] { true });
|
for (int a = 0; a < 2; a++) {
|
||||||
dst.add(new Object[] { false });
|
for (int b = 0; b < 2; b++) {
|
||||||
|
for (int c = 0; c < 2; c++) {
|
||||||
|
dst.add(new Object[] { a == 0, b == 0, c == 0 });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return dst;
|
return dst;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +94,7 @@ public class OpenSslPrivateKeyMethodTest {
|
|||||||
public static void init() throws Exception {
|
public static void init() throws Exception {
|
||||||
checkShouldUseKeyManagerFactory();
|
checkShouldUseKeyManagerFactory();
|
||||||
|
|
||||||
Assume.assumeTrue(OpenSsl.isBoringSSL());
|
assumeTrue(OpenSsl.isBoringSSL());
|
||||||
// Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API
|
// Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API
|
||||||
// implementations.
|
// implementations.
|
||||||
assumeCipherAvailable(SslProvider.OPENSSL);
|
assumeCipherAvailable(SslProvider.OPENSSL);
|
||||||
@ -125,7 +132,7 @@ public class OpenSslPrivateKeyMethodTest {
|
|||||||
} else {
|
} else {
|
||||||
cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME);
|
cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME);
|
||||||
}
|
}
|
||||||
Assume.assumeTrue("Unsupported cipher: " + RFC_CIPHER_NAME, cipherSupported);
|
assumeTrue(cipherSupported, "Unsupported cipher: " + RFC_CIPHER_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
|
private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
|
||||||
@ -163,20 +170,31 @@ public class OpenSslPrivateKeyMethodTest {
|
|||||||
private static Executor delegateExecutor(boolean delegate) {
|
private static Executor delegateExecutor(boolean delegate) {
|
||||||
return delegate ? EXECUTOR : null;
|
return delegate ? EXECUTOR : null;
|
||||||
}
|
}
|
||||||
|
private SslContext buildServerContext(OpenSslAsyncPrivateKeyMethod method) throws Exception {
|
||||||
|
List<String> ciphers = Collections.singletonList(RFC_CIPHER_NAME);
|
||||||
|
|
||||||
|
final KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert());
|
||||||
|
|
||||||
|
return SslContextBuilder.forServer(kmf)
|
||||||
|
.sslProvider(SslProvider.OPENSSL)
|
||||||
|
.ciphers(ciphers)
|
||||||
|
// As this is not a TLSv1.3 cipher we should ensure we talk something else.
|
||||||
|
.protocols(SslProtocols.TLS_v1_2)
|
||||||
|
.option(OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD, method)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
private static void assertThread(boolean delegate) {
|
private static void assertThread(boolean delegate) {
|
||||||
if (delegate && OpenSslContext.USE_TASKS) {
|
if (delegate && OpenSslContext.USE_TASKS) {
|
||||||
assertEquals(DelegateThread.class, Thread.currentThread().getClass());
|
assertEquals(DelegateThread.class, Thread.currentThread().getClass());
|
||||||
} else {
|
|
||||||
assertNotEquals(DelegateThread.class, Thread.currentThread().getClass());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest(name = "{index}: delegate = {0}")
|
@ParameterizedTest(name = "{index}: delegate = {0}, async = {1}, newThread={2}")
|
||||||
@MethodSource("parameters")
|
@MethodSource("parameters")
|
||||||
public void testPrivateKeyMethod(final boolean delegate) throws Exception {
|
public void testPrivateKeyMethod(final boolean delegate, boolean async, boolean newThread) throws Exception {
|
||||||
final AtomicBoolean signCalled = new AtomicBoolean();
|
final AtomicBoolean signCalled = new AtomicBoolean();
|
||||||
final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() {
|
OpenSslPrivateKeyMethod keyMethod = new OpenSslPrivateKeyMethod() {
|
||||||
@Override
|
@Override
|
||||||
public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
|
public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
|
||||||
signCalled.set(true);
|
signCalled.set(true);
|
||||||
@ -206,7 +224,10 @@ public class OpenSslPrivateKeyMethodTest {
|
|||||||
public byte[] decrypt(SSLEngine engine, byte[] input) {
|
public byte[] decrypt(SSLEngine engine, byte[] input) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
});
|
};
|
||||||
|
|
||||||
|
final SslContext sslServerContext = async ? buildServerContext(
|
||||||
|
new OpenSslPrivateKeyMethodAdapter(keyMethod, newThread)) : buildServerContext(keyMethod);
|
||||||
|
|
||||||
final SslContext sslClientContext = buildClientContext();
|
final SslContext sslClientContext = buildClientContext();
|
||||||
try {
|
try {
|
||||||
@ -391,4 +412,70 @@ public class OpenSslPrivateKeyMethodTest {
|
|||||||
super(target);
|
super(target);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final class OpenSslPrivateKeyMethodAdapter implements OpenSslAsyncPrivateKeyMethod {
|
||||||
|
private final OpenSslPrivateKeyMethod keyMethod;
|
||||||
|
private final boolean newThread;
|
||||||
|
|
||||||
|
OpenSslPrivateKeyMethodAdapter(OpenSslPrivateKeyMethod keyMethod, boolean newThread) {
|
||||||
|
this.keyMethod = keyMethod;
|
||||||
|
this.newThread = newThread;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Future<byte[]> sign(final SSLEngine engine, final int signatureAlgorithm, final byte[] input) {
|
||||||
|
final Promise<byte[]> promise = ImmediateEventExecutor.INSTANCE.newPromise();
|
||||||
|
try {
|
||||||
|
if (newThread) {
|
||||||
|
// Let's run these in an extra thread to ensure that this would also work if the promise is
|
||||||
|
// notified later.
|
||||||
|
new DelegateThread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
try {
|
||||||
|
// Let's sleep for some time to ensure we would notify in an async fashion
|
||||||
|
Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500));
|
||||||
|
promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input));
|
||||||
|
} catch (Throwable cause) {
|
||||||
|
promise.setFailure(cause);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
|
} else {
|
||||||
|
promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input));
|
||||||
|
}
|
||||||
|
} catch (Throwable cause) {
|
||||||
|
promise.setFailure(cause);
|
||||||
|
}
|
||||||
|
return promise;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Future<byte[]> decrypt(final SSLEngine engine, final byte[] input) {
|
||||||
|
final Promise<byte[]> promise = ImmediateEventExecutor.INSTANCE.newPromise();
|
||||||
|
try {
|
||||||
|
if (newThread) {
|
||||||
|
// Let's run these in an extra thread to ensure that this would also work if the promise is
|
||||||
|
// notified later.
|
||||||
|
new DelegateThread(new Runnable() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
try {
|
||||||
|
// Let's sleep for some time to ensure we would notify in an async fashion
|
||||||
|
Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500));
|
||||||
|
promise.setSuccess(keyMethod.decrypt(engine, input));
|
||||||
|
} catch (Throwable cause) {
|
||||||
|
promise.setFailure(cause);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
|
} else {
|
||||||
|
promise.setSuccess(keyMethod.decrypt(engine, input));
|
||||||
|
}
|
||||||
|
} catch (Throwable cause) {
|
||||||
|
promise.setFailure(cause);
|
||||||
|
}
|
||||||
|
return promise;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -55,6 +55,10 @@ import javax.net.ssl.ManagerFactoryParameters;
|
|||||||
import javax.net.ssl.SSLException;
|
import javax.net.ssl.SSLException;
|
||||||
import javax.net.ssl.TrustManager;
|
import javax.net.ssl.TrustManager;
|
||||||
import javax.net.ssl.X509TrustManager;
|
import javax.net.ssl.X509TrustManager;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.PrintStream;
|
||||||
|
import java.io.UnsupportedEncodingException;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.net.SocketAddress;
|
import java.net.SocketAddress;
|
||||||
import java.security.KeyStore;
|
import java.security.KeyStore;
|
||||||
@ -679,8 +683,23 @@ public class ParameterizedSslHandlerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void appendError(Throwable cause) {
|
private void appendError(Throwable cause) {
|
||||||
readQueue.append("failed to write '").append(toWrite).append("': ").append(cause);
|
readQueue.append("failed to write '").append(toWrite).append("': ");
|
||||||
|
|
||||||
|
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||||
|
try {
|
||||||
|
cause.printStackTrace(new PrintStream(out));
|
||||||
|
readQueue.append(out.toString(CharsetUtil.US_ASCII.name()));
|
||||||
|
} catch (UnsupportedEncodingException ignore) {
|
||||||
|
// Let's just fallback to using toString().
|
||||||
|
readQueue.append(cause);
|
||||||
|
} finally {
|
||||||
doneLatch.countDown();
|
doneLatch.countDown();
|
||||||
|
try {
|
||||||
|
out.close();
|
||||||
|
} catch (IOException ignore) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user