SslStream: reuse SslCtx in a simple case

GitOrigin-RevId: 4ac372b23d57a305d69f2d7ec3032b239b43ca7c
This commit is contained in:
Arseny Smirnov 2020-07-02 18:01:23 +03:00
parent 06bd4fa734
commit 4ed1713553

View File

@ -151,36 +151,33 @@ void do_ssl_shutdown(SSL *ssl_handle) {
clear_openssl_errors("After SSL_shutdown"); clear_openssl_errors("After SSL_shutdown");
} }
} // namespace struct SslCtxDeleter {
void operator()(SSL_CTX *ssl_ctx) {
class SslStreamImpl { if (!ssl_ctx) {
public:
using VerifyPeer = SslStream::VerifyPeer;
~SslStreamImpl() {
if (!ssl_handle_) {
CHECK(!ssl_ctx_ && !bio_);
return; return;
} }
CHECK(ssl_handle_ && ssl_ctx_ && bio_); SSL_CTX_free(ssl_ctx);
do_ssl_shutdown(ssl_handle_);
SSL_free(ssl_handle_);
ssl_handle_ = nullptr;
SSL_CTX_free(ssl_ctx_);
ssl_ctx_ = nullptr;
} }
Status init(CSlice host, CSlice cert_file, VerifyPeer verify_peer) { };
static bool init_openssl = [] {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
return OPENSSL_init_ssl(0, nullptr) != 0;
#else
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
return OpenSSL_add_ssl_algorithms() != 0;
#endif
}();
CHECK(init_openssl);
clear_openssl_errors("Before SslFd::init"); using SslCtx = std::unique_ptr<SSL_CTX, SslCtxDeleter>;
struct SslHandleDeleter {
void operator()(SSL *ssl_handle) {
if (!ssl_handle) {
return;
}
do_ssl_shutdown(ssl_handle);
SSL_free(ssl_handle);
}
};
using SslHandle = std::unique_ptr<SSL, SslHandleDeleter>;
static constexpr int VERIFY_DEPTH = 10;
td::Result<SslCtx> do_create_ssl_ctx(CSlice cert_file, SslStream::VerifyPeer verify_peer) {
using VerifyPeer = SslStream::VerifyPeer;
auto ssl_method = auto ssl_method =
#if OPENSSL_VERSION_NUMBER >= 0x10100000L #if OPENSSL_VERSION_NUMBER >= 0x10100000L
@ -191,14 +188,11 @@ class SslStreamImpl {
if (ssl_method == nullptr) { if (ssl_method == nullptr) {
return create_openssl_error(-6, "Failed to create an SSL client method"); return create_openssl_error(-6, "Failed to create an SSL client method");
} }
auto ssl_ctx_ptr = SslCtx(SSL_CTX_new(ssl_method));
auto ssl_ctx = SSL_CTX_new(ssl_method); if (!ssl_ctx_ptr) {
if (ssl_ctx == nullptr) {
return create_openssl_error(-7, "Failed to create an SSL context"); return create_openssl_error(-7, "Failed to create an SSL context");
} }
auto ssl_ctx_guard = ScopeExit() + [&] { auto ssl_ctx = ssl_ctx_ptr.get();
SSL_CTX_free(ssl_ctx);
};
long options = 0; long options = 0;
#ifdef SSL_OP_NO_SSLv2 #ifdef SSL_OP_NO_SSLv2
options |= SSL_OP_NO_SSLv2; options |= SSL_OP_NO_SSLv2;
@ -284,19 +278,63 @@ class SslStreamImpl {
return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"'); return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"');
} }
auto ssl_handle = SSL_new(ssl_ctx); return ssl_ctx_ptr;
if (ssl_handle == nullptr) { }
td::Result<SslCtx> clone(const SslCtx &ctx_ptr) {
auto ctx = ctx_ptr.get();
if (!SSL_CTX_up_ref(ctx)) {
return create_openssl_error(-23, "Failed to increase reference counter in ssl context");
}
return SslCtx(ctx);
}
td::Result<SslCtx> get_default_ssl_ctx() {
static auto ctx = do_create_ssl_ctx("", SslStream::VerifyPeer::On);
if (ctx.is_error()) {
return ctx.error().clone();
}
return clone(ctx.ok());
}
td::Result<SslCtx> create_ssl_ctx(CSlice cert_file, SslStream::VerifyPeer verify_peer) {
if (cert_file.empty() && verify_peer == SslStream::VerifyPeer::On) {
return get_default_ssl_ctx();
}
return do_create_ssl_ctx(cert_file, verify_peer);
}
} // namespace
class SslStreamImpl {
public:
using VerifyPeer = SslStream::VerifyPeer;
Status init(CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
static bool init_openssl = [] {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
return OPENSSL_init_ssl(0, nullptr) != 0;
#else
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
return OpenSSL_add_ssl_algorithms() != 0;
#endif
}();
CHECK(init_openssl);
clear_openssl_errors("Before SslFd::init");
TRY_RESULT(ssl_ctx, create_ssl_ctx(cert_file, verify_peer));
auto ssl_handle = SslHandle(SSL_new(ssl_ctx.get()));
if (!ssl_handle) {
return create_openssl_error(-13, "Failed to create an SSL handle"); return create_openssl_error(-13, "Failed to create an SSL handle");
} }
auto ssl_handle_guard = ScopeExit() + [&] {
do_ssl_shutdown(ssl_handle);
SSL_free(ssl_handle);
};
auto r_ip_address = IPAddress::get_ip_address(host); auto r_ip_address = IPAddress::get_ip_address(host);
#if OPENSSL_VERSION_NUMBER >= 0x10002000L #if OPENSSL_VERSION_NUMBER >= 0x10002000L
X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle); X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle.get());
X509_VERIFY_PARAM_set_hostflags(param, 0); X509_VERIFY_PARAM_set_hostflags(param, 0);
if (r_ip_address.is_ok()) { if (r_ip_address.is_ok()) {
LOG(DEBUG) << "Set verification IP address to " << r_ip_address.ok().get_ip_str(); LOG(DEBUG) << "Set verification IP address to " << r_ip_address.ok().get_ip_str();
@ -311,23 +349,18 @@ class SslStreamImpl {
auto *bio = BIO_new(BIO_s_sslstream()); auto *bio = BIO_new(BIO_s_sslstream());
BIO_set_data(bio, static_cast<void *>(this)); BIO_set_data(bio, static_cast<void *>(this));
SSL_set_bio(ssl_handle, bio, bio); SSL_set_bio(ssl_handle.get(), bio, bio);
#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT) #if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
if (r_ip_address.is_error()) { // IP address must not be send as SNI if (r_ip_address.is_error()) { // IP address must not be send as SNI
LOG(DEBUG) << "Set SNI host name to " << host; LOG(DEBUG) << "Set SNI host name to " << host;
auto host_str = host.str(); auto host_str = host.str();
SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin()); SSL_set_tlsext_host_name(ssl_handle.get(), MutableCSlice(host_str).begin());
} }
#endif #endif
SSL_set_connect_state(ssl_handle); SSL_set_connect_state(ssl_handle.get());
ssl_ctx_guard.dismiss(); ssl_handle_ = std::move(ssl_handle);
ssl_handle_guard.dismiss();
ssl_handle_ = ssl_handle;
ssl_ctx_ = ssl_ctx;
bio_ = bio;
return Status::OK(); return Status::OK();
} }
@ -346,18 +379,14 @@ class SslStreamImpl {
} }
private: private:
static constexpr int VERIFY_DEPTH = 10; SslHandle ssl_handle_;
SSL *ssl_handle_ = nullptr;
SSL_CTX *ssl_ctx_ = nullptr;
BIO *bio_ = nullptr;
friend class SslReadByteFlow; friend class SslReadByteFlow;
friend class SslWriteByteFlow; friend class SslWriteByteFlow;
Result<size_t> write(Slice slice) { Result<size_t> write(Slice slice) {
clear_openssl_errors("Before SslFd::write"); clear_openssl_errors("Before SslFd::write");
auto size = SSL_write(ssl_handle_, slice.data(), static_cast<int>(slice.size())); auto size = SSL_write(ssl_handle_.get(), slice.data(), static_cast<int>(slice.size()));
if (size <= 0) { if (size <= 0) {
return process_ssl_error(size); return process_ssl_error(size);
} }
@ -366,7 +395,7 @@ class SslStreamImpl {
Result<size_t> read(MutableSlice slice) { Result<size_t> read(MutableSlice slice) {
clear_openssl_errors("Before SslFd::read"); clear_openssl_errors("Before SslFd::read");
auto size = SSL_read(ssl_handle_, slice.data(), static_cast<int>(slice.size())); auto size = SSL_read(ssl_handle_.get(), slice.data(), static_cast<int>(slice.size()));
if (size <= 0) { if (size <= 0) {
return process_ssl_error(size); return process_ssl_error(size);
} }
@ -444,7 +473,7 @@ class SslStreamImpl {
Result<size_t> process_ssl_error(int ret) { Result<size_t> process_ssl_error(int ret) {
auto os_error = OS_ERROR("SSL_ERROR_SYSCALL"); auto os_error = OS_ERROR("SSL_ERROR_SYSCALL");
int error = SSL_get_error(ssl_handle_, ret); int error = SSL_get_error(ssl_handle_.get(), ret);
switch (error) { switch (error) {
case SSL_ERROR_NONE: case SSL_ERROR_NONE:
LOG(ERROR) << "SSL_get_error returned no error"; LOG(ERROR) << "SSL_get_error returned no error";