We need to ensure we always drain the error stack when a callback throws (#10920)

Motivation:

We need to ensure we always drain the error stack when a callback throws as otherwise we may pick up the error on a different SSL instance which uses the same thread.

Modifications:

- Correctly drain the error stack if native method throws
- Add a unit test which failed before the change

Result:

Always drain the error stack
This commit is contained in:
Norman Maurer 2021-01-11 20:56:52 +01:00
parent e8f3e88526
commit dc632e378f
3 changed files with 91 additions and 7 deletions

View File

@ -553,7 +553,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
/** /**
* Write encrypted data to the OpenSSL network BIO. * Write encrypted data to the OpenSSL network BIO.
*/ */
private ByteBuf writeEncryptedData(final ByteBuffer src, int len) { private ByteBuf writeEncryptedData(final ByteBuffer src, int len) throws SSLException {
final int pos = src.position(); final int pos = src.position();
if (src.isDirect()) { if (src.isDirect()) {
SSL.bioSetByteBuffer(networkBIO, bufferAddress(src) + pos, len, false); SSL.bioSetByteBuffer(networkBIO, bufferAddress(src) + pos, len, false);
@ -580,7 +580,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
/** /**
* Read plaintext data from the OpenSSL internal BIO * Read plaintext data from the OpenSSL internal BIO
*/ */
private int readPlaintextData(final ByteBuffer dst) { private int readPlaintextData(final ByteBuffer dst) throws SSLException {
final int sslRead; final int sslRead;
final int pos = dst.position(); final int pos = dst.position();
if (dst.isDirect()) { if (dst.isDirect()) {
@ -1028,6 +1028,16 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
return exception; return exception;
} }
private SSLEngineResult handleUnwrapException(int bytesConsumed, int bytesProduced, SSLException e)
throws SSLException {
int lastError = SSL.getLastErrorNumber();
if (lastError != 0) {
return sslReadErrorResult(SSL.SSL_ERROR_SSL, lastError, bytesConsumed,
bytesProduced);
}
throw e;
}
public final SSLEngineResult unwrap( public final SSLEngineResult unwrap(
final ByteBuffer[] srcs, int srcsOffset, final int srcsLength, final ByteBuffer[] srcs, int srcsOffset, final int srcsLength,
final ByteBuffer[] dsts, int dstsOffset, final int dstsLength) throws SSLException { final ByteBuffer[] dsts, int dstsOffset, final int dstsLength) throws SSLException {
@ -1178,7 +1188,12 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
// Write more encrypted data into the BIO. Ensure we only read one packet at a time as // Write more encrypted data into the BIO. Ensure we only read one packet at a time as
// stated in the SSLEngine javadocs. // stated in the SSLEngine javadocs.
pendingEncryptedBytes = min(packetLength, remaining); pendingEncryptedBytes = min(packetLength, remaining);
bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes); try {
bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes);
} catch (SSLException e) {
// Ensure we correctly handle the error stack.
return handleUnwrapException(bytesConsumed, bytesProduced, e);
}
} }
try { try {
for (;;) { for (;;) {
@ -1191,7 +1206,13 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
continue; continue;
} }
int bytesRead = readPlaintextData(dst); int bytesRead;
try {
bytesRead = readPlaintextData(dst);
} catch (SSLException e) {
// Ensure we correctly handle the error stack.
return handleUnwrapException(bytesConsumed, bytesProduced, e);
}
// We are directly using the ByteBuffer memory for the write, and so we only know what has // We are directly using the ByteBuffer memory for the write, and so we only know what has
// been consumed after we let SSL decrypt the data. At this point we should update the // been consumed after we let SSL decrypt the data. At this point we should update the
// number of bytes consumed, update the ByteBuffer position, and release temp ByteBuf. // number of bytes consumed, update the ByteBuffer position, and release temp ByteBuf.
@ -1211,7 +1232,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
return sslPending > 0 ? return sslPending > 0 ?
newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced) : newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced) :
newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status, newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status,
bytesConsumed, bytesProduced); bytesConsumed, bytesProduced);
} }
} else if (packetLength == 0 || jdkCompatibilityMode) { } else if (packetLength == 0 || jdkCompatibilityMode) {
// We either consumed all data or we are in jdkCompatibilityMode and have consumed // We either consumed all data or we are in jdkCompatibilityMode and have consumed
@ -1238,7 +1259,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc
NEED_TASK, bytesConsumed, bytesProduced); NEED_TASK, bytesConsumed, bytesProduced);
} else { } else {
return sslReadErrorResult(sslError, SSL.getLastErrorNumber(), bytesConsumed, return sslReadErrorResult(sslError, SSL.getLastErrorNumber(), bytesConsumed,
bytesProduced); bytesProduced);
} }
} }
} }

View File

@ -39,11 +39,16 @@ import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLParameters;
import javax.net.ssl.X509ExtendedKeyManager;
import java.net.Socket;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.AlgorithmConstraints; import java.security.AlgorithmConstraints;
import java.security.AlgorithmParameters; import java.security.AlgorithmParameters;
import java.security.CryptoPrimitive; import java.security.CryptoPrimitive;
import java.security.Key; import java.security.Key;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -1311,6 +1316,63 @@ public class OpenSslEngineTest extends SSLEngineTest {
} }
} }
@Test(expected = SSLException.class)
public void testNoKeyFound() throws Exception {
clientSslCtx = wrapContext(SslContextBuilder
.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(sslClientProvider())
.protocols(protocols())
.ciphers(ciphers())
.build());
SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
serverSslCtx = wrapContext(SslContextBuilder
.forServer(new X509ExtendedKeyManager() {
@Override
public String[] getClientAliases(String keyType, Principal[] issuers) {
return new String[0];
}
@Override
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
return null;
}
@Override
public String[] getServerAliases(String keyType, Principal[] issuers) {
return new String[0];
}
@Override
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
return null;
}
@Override
public X509Certificate[] getCertificateChain(String alias) {
return new X509Certificate[0];
}
@Override
public PrivateKey getPrivateKey(String alias) {
return null;
}
})
.sslProvider(sslServerProvider())
.protocols(protocols())
.ciphers(ciphers())
.build());
SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
try {
handshake(client, server);
} finally {
cleanupClientSslEngine(client);
cleanupServerSslEngine(server);
}
}
@Override @Override
@Test @Test
public void testSessionLocalWhenNonMutualWithKeyManager() throws Exception { public void testSessionLocalWhenNonMutualWithKeyManager() throws Exception {

View File

@ -435,6 +435,7 @@ final class OpenSslErrorStackAssertSSLEngine extends JdkSslEngine implements Ref
} }
private static void assertErrorStackEmpty() { private static void assertErrorStackEmpty() {
Assert.assertEquals("SSL error stack non-empty", 0, SSL.getLastErrorNumber()); long error = SSL.getLastErrorNumber();
Assert.assertEquals("SSL error stack non-empty: " + SSL.getErrorString(error), 0, error);
} }
} }