We need to ensure we always drain the error stack when a callback throws (#10920)

Motivation:

We need to ensure we always drain the error stack when a callback throws as otherwise we may pick up the error on a different SSL instance which uses the same thread.

Modifications:

- Correctly drain the error stack if native method throws
- Add a unit test which failed before the change

Result:

Always drain the error stack
This commit is contained in:
Norman Maurer 2021-01-11 20:56:52 +01:00 committed by GitHub
parent 3f04385919
commit cbebad8417
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 91 additions and 7 deletions

View File

@ -562,7 +562,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
/**
* Write encrypted data to the OpenSSL network BIO.
*/
private ByteBuf writeEncryptedData(final ByteBuffer src, int len) {
private ByteBuf writeEncryptedData(final ByteBuffer src, int len) throws SSLException {
final int pos = src.position();
if (src.isDirect()) {
SSL.bioSetByteBuffer(networkBIO, bufferAddress(src) + pos, len, false);
@ -589,7 +589,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
/**
* Read plaintext data from the OpenSSL internal BIO
*/
private int readPlaintextData(final ByteBuffer dst) {
private int readPlaintextData(final ByteBuffer dst) throws SSLException {
final int sslRead;
final int pos = dst.position();
if (dst.isDirect()) {
@ -1037,6 +1037,16 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
return exception;
}
private SSLEngineResult handleUnwrapException(int bytesConsumed, int bytesProduced, SSLException e)
throws SSLException {
int lastError = SSL.getLastErrorNumber();
if (lastError != 0) {
return sslReadErrorResult(SSL.SSL_ERROR_SSL, lastError, bytesConsumed,
bytesProduced);
}
throw e;
}
public final SSLEngineResult unwrap(
final ByteBuffer[] srcs, int srcsOffset, final int srcsLength,
final ByteBuffer[] dsts, int dstsOffset, final int dstsLength) throws SSLException {
@ -1187,7 +1197,12 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
// Write more encrypted data into the BIO. Ensure we only read one packet at a time as
// stated in the SSLEngine javadocs.
pendingEncryptedBytes = min(packetLength, remaining);
bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes);
try {
bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes);
} catch (SSLException e) {
// Ensure we correctly handle the error stack.
return handleUnwrapException(bytesConsumed, bytesProduced, e);
}
}
try {
for (;;) {
@ -1200,7 +1215,13 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
continue;
}
int bytesRead = readPlaintextData(dst);
int bytesRead;
try {
bytesRead = readPlaintextData(dst);
} catch (SSLException e) {
// Ensure we correctly handle the error stack.
return handleUnwrapException(bytesConsumed, bytesProduced, e);
}
// We are directly using the ByteBuffer memory for the write, and so we only know what has
// been consumed after we let SSL decrypt the data. At this point we should update the
// number of bytes consumed, update the ByteBuffer position, and release temp ByteBuf.
@ -1220,7 +1241,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
return sslPending > 0 ?
newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced) :
newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status,
bytesConsumed, bytesProduced);
bytesConsumed, bytesProduced);
}
} else if (packetLength == 0 || jdkCompatibilityMode) {
// We either consumed all data or we are in jdkCompatibilityMode and have consumed
@ -1247,7 +1268,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
NEED_TASK, bytesConsumed, bytesProduced);
} else {
return sslReadErrorResult(sslError, SSL.getLastErrorNumber(), bytesConsumed,
bytesProduced);
bytesProduced);
}
}
}

View File

@ -39,11 +39,16 @@ import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.X509ExtendedKeyManager;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.security.AlgorithmConstraints;
import java.security.AlgorithmParameters;
import java.security.CryptoPrimitive;
import java.security.Key;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@ -1312,6 +1317,63 @@ public class OpenSslEngineTest extends SSLEngineTest {
}
}
@Test(expected = SSLException.class)
public void testNoKeyFound() throws Exception {
clientSslCtx = wrapContext(SslContextBuilder
.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(sslClientProvider())
.protocols(protocols())
.ciphers(ciphers())
.build());
SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
serverSslCtx = wrapContext(SslContextBuilder
.forServer(new X509ExtendedKeyManager() {
@Override
public String[] getClientAliases(String keyType, Principal[] issuers) {
return new String[0];
}
@Override
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
return null;
}
@Override
public String[] getServerAliases(String keyType, Principal[] issuers) {
return new String[0];
}
@Override
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
return null;
}
@Override
public X509Certificate[] getCertificateChain(String alias) {
return new X509Certificate[0];
}
@Override
public PrivateKey getPrivateKey(String alias) {
return null;
}
})
.sslProvider(sslServerProvider())
.protocols(protocols())
.ciphers(ciphers())
.build());
SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
try {
handshake(client, server);
} finally {
cleanupClientSslEngine(client);
cleanupServerSslEngine(server);
}
}
@Override
@Test
public void testSessionLocalWhenNonMutualWithKeyManager() throws Exception {

View File

@ -435,6 +435,7 @@ final class OpenSslErrorStackAssertSSLEngine extends JdkSslEngine implements Ref
}
private static void assertErrorStackEmpty() {
Assert.assertEquals("SSL error stack non-empty", 0, SSL.getLastErrorNumber());
long error = SSL.getLastErrorNumber();
Assert.assertEquals("SSL error stack non-empty: " + SSL.getErrorString(error), 0, error);
}
}