diff --git a/tdnet/td/net/SslStream.cpp b/tdnet/td/net/SslStream.cpp index feb8eed26..32eaec151 100644 --- a/tdnet/td/net/SslStream.cpp +++ b/tdnet/td/net/SslStream.cpp @@ -151,23 +151,165 @@ void do_ssl_shutdown(SSL *ssl_handle) { clear_openssl_errors("After SSL_shutdown"); } +struct SslCtxDeleter { + void operator()(SSL_CTX *ssl_ctx) { + if (!ssl_ctx) { + return; + } + SSL_CTX_free(ssl_ctx); + } +}; + +using SslCtx = std::unique_ptr; + +struct SslHandleDeleter { + void operator()(SSL *ssl_handle) { + if (!ssl_handle) { + return; + } + do_ssl_shutdown(ssl_handle); + SSL_free(ssl_handle); + } +}; + +using SslHandle = std::unique_ptr; + +static constexpr int VERIFY_DEPTH = 10; + +td::Result do_create_ssl_ctx(CSlice cert_file, SslStream::VerifyPeer verify_peer) { + using VerifyPeer = SslStream::VerifyPeer; + + auto ssl_method = +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + TLS_client_method(); +#else + SSLv23_client_method(); +#endif + if (ssl_method == nullptr) { + return create_openssl_error(-6, "Failed to create an SSL client method"); + } + auto ssl_ctx_ptr = SslCtx(SSL_CTX_new(ssl_method)); + if (!ssl_ctx_ptr) { + return create_openssl_error(-7, "Failed to create an SSL context"); + } + auto ssl_ctx = ssl_ctx_ptr.get(); + long options = 0; +#ifdef SSL_OP_NO_SSLv2 + options |= SSL_OP_NO_SSLv2; +#endif +#ifdef SSL_OP_NO_SSLv3 + options |= SSL_OP_NO_SSLv3; +#endif + SSL_CTX_set_options(ssl_ctx, options); +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + SSL_CTX_set_min_proto_version(ssl_ctx, TLS1_VERSION); +#endif + SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + + if (cert_file.empty()) { +#if TD_PORT_WINDOWS + // TODO thread-local SSL_CTX cache + LOG(DEBUG) << "Begin to load system store"; + auto flags = CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG | CERT_SYSTEM_STORE_CURRENT_USER; + HCERTSTORE system_store = + CertOpenStore(CERT_STORE_PROV_SYSTEM_W, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, HCRYPTPROV_LEGACY(), flags, + static_cast(to_wstring("ROOT").ok().c_str())); + + if (system_store) { + X509_STORE *store = X509_STORE_new(); + + for (PCCERT_CONTEXT cert_context = CertEnumCertificatesInStore(system_store, nullptr); cert_context != nullptr; + cert_context = CertEnumCertificatesInStore(system_store, cert_context)) { + const unsigned char *in = cert_context->pbCertEncoded; + X509 *x509 = d2i_X509(nullptr, &in, static_cast(cert_context->cbCertEncoded)); + if (x509 != nullptr) { + if (X509_STORE_add_cert(store, x509) != 1) { + auto error_code = ERR_peek_error(); + auto error = create_openssl_error(-20, "Failed to add certificate"); + if (ERR_GET_REASON(error_code) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { + LOG(ERROR) << error; + } else { + LOG(INFO) << error; + } + } + + X509_free(x509); + } else { + LOG(ERROR) << create_openssl_error(-21, "Failed to load X509 certificate"); + } + } + + CertCloseStore(system_store, 0); + + SSL_CTX_set_cert_store(ssl_ctx, store); + LOG(DEBUG) << "End to load system store"; + } else { + LOG(ERROR) << create_openssl_error(-22, "Failed to open system certificate store"); + } +#else + if (SSL_CTX_set_default_verify_paths(ssl_ctx) == 0) { + auto error = create_openssl_error(-8, "Failed to load default verify paths"); + if (verify_peer == VerifyPeer::On) { + return error; + } else { + LOG(ERROR) << error; + } + } +#endif + } else { + if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) { + return create_openssl_error(-8, "Failed to set custom certificate file"); + } + } + + if (verify_peer == VerifyPeer::On) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback); + + if (VERIFY_DEPTH != -1) { + SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH); + } + } else { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr); + } + + // TODO(now): cipher list + string cipher_list; + if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) { + return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"'); + } + + return ssl_ctx_ptr; +} + +td::Result clone(const SslCtx &ctx_ptr) { + auto ctx = ctx_ptr.get(); + if (!SSL_CTX_up_ref(ctx)) { + return create_openssl_error(-23, "Failed to increase reference counter in ssl context"); + } + return SslCtx(ctx); +} + +td::Result get_default_ssl_ctx() { + static auto ctx = do_create_ssl_ctx("", SslStream::VerifyPeer::On); + if (ctx.is_error()) { + return ctx.error().clone(); + } + + return clone(ctx.ok()); +} + +td::Result create_ssl_ctx(CSlice cert_file, SslStream::VerifyPeer verify_peer) { + if (cert_file.empty() && verify_peer == SslStream::VerifyPeer::On) { + return get_default_ssl_ctx(); + } + return do_create_ssl_ctx(cert_file, verify_peer); +} + } // namespace class SslStreamImpl { public: using VerifyPeer = SslStream::VerifyPeer; - ~SslStreamImpl() { - if (!ssl_handle_) { - CHECK(!ssl_ctx_ && !bio_); - return; - } - CHECK(ssl_handle_ && ssl_ctx_ && bio_); - do_ssl_shutdown(ssl_handle_); - SSL_free(ssl_handle_); - ssl_handle_ = nullptr; - SSL_CTX_free(ssl_ctx_); - ssl_ctx_ = nullptr; - } Status init(CSlice host, CSlice cert_file, VerifyPeer verify_peer) { static bool init_openssl = [] { #if OPENSSL_VERSION_NUMBER >= 0x10100000L @@ -182,121 +324,17 @@ class SslStreamImpl { clear_openssl_errors("Before SslFd::init"); - auto ssl_method = -#if OPENSSL_VERSION_NUMBER >= 0x10100000L - TLS_client_method(); -#else - SSLv23_client_method(); -#endif - if (ssl_method == nullptr) { - return create_openssl_error(-6, "Failed to create an SSL client method"); - } + TRY_RESULT(ssl_ctx, create_ssl_ctx(cert_file, verify_peer)); - auto ssl_ctx = SSL_CTX_new(ssl_method); - if (ssl_ctx == nullptr) { - return create_openssl_error(-7, "Failed to create an SSL context"); - } - auto ssl_ctx_guard = ScopeExit() + [&] { - SSL_CTX_free(ssl_ctx); - }; - long options = 0; -#ifdef SSL_OP_NO_SSLv2 - options |= SSL_OP_NO_SSLv2; -#endif -#ifdef SSL_OP_NO_SSLv3 - options |= SSL_OP_NO_SSLv3; -#endif - SSL_CTX_set_options(ssl_ctx, options); -#if OPENSSL_VERSION_NUMBER >= 0x10100000L - SSL_CTX_set_min_proto_version(ssl_ctx, TLS1_VERSION); -#endif - SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); - - if (cert_file.empty()) { -#if TD_PORT_WINDOWS - // TODO thread-local SSL_CTX cache - LOG(DEBUG) << "Begin to load system store"; - auto flags = CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG | CERT_SYSTEM_STORE_CURRENT_USER; - HCERTSTORE system_store = - CertOpenStore(CERT_STORE_PROV_SYSTEM_W, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, HCRYPTPROV_LEGACY(), flags, - static_cast(to_wstring("ROOT").ok().c_str())); - - if (system_store) { - X509_STORE *store = X509_STORE_new(); - - for (PCCERT_CONTEXT cert_context = CertEnumCertificatesInStore(system_store, nullptr); cert_context != nullptr; - cert_context = CertEnumCertificatesInStore(system_store, cert_context)) { - const unsigned char *in = cert_context->pbCertEncoded; - X509 *x509 = d2i_X509(nullptr, &in, static_cast(cert_context->cbCertEncoded)); - if (x509 != nullptr) { - if (X509_STORE_add_cert(store, x509) != 1) { - auto error_code = ERR_peek_error(); - auto error = create_openssl_error(-20, "Failed to add certificate"); - if (ERR_GET_REASON(error_code) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { - LOG(ERROR) << error; - } else { - LOG(INFO) << error; - } - } - - X509_free(x509); - } else { - LOG(ERROR) << create_openssl_error(-21, "Failed to load X509 certificate"); - } - } - - CertCloseStore(system_store, 0); - - SSL_CTX_set_cert_store(ssl_ctx, store); - LOG(DEBUG) << "End to load system store"; - } else { - LOG(ERROR) << create_openssl_error(-22, "Failed to open system certificate store"); - } -#else - if (SSL_CTX_set_default_verify_paths(ssl_ctx) == 0) { - auto error = create_openssl_error(-8, "Failed to load default verify paths"); - if (verify_peer == VerifyPeer::On) { - return error; - } else { - LOG(ERROR) << error; - } - } -#endif - } else { - if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) { - return create_openssl_error(-8, "Failed to set custom certificate file"); - } - } - - if (verify_peer == VerifyPeer::On) { - SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback); - - if (VERIFY_DEPTH != -1) { - SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH); - } - } else { - SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr); - } - - // TODO(now): cipher list - string cipher_list; - if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) { - return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"'); - } - - auto ssl_handle = SSL_new(ssl_ctx); - if (ssl_handle == nullptr) { + auto ssl_handle = SslHandle(SSL_new(ssl_ctx.get())); + if (!ssl_handle) { return create_openssl_error(-13, "Failed to create an SSL handle"); } - auto ssl_handle_guard = ScopeExit() + [&] { - do_ssl_shutdown(ssl_handle); - SSL_free(ssl_handle); - }; auto r_ip_address = IPAddress::get_ip_address(host); #if OPENSSL_VERSION_NUMBER >= 0x10002000L - X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle); + X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle.get()); X509_VERIFY_PARAM_set_hostflags(param, 0); if (r_ip_address.is_ok()) { LOG(DEBUG) << "Set verification IP address to " << r_ip_address.ok().get_ip_str(); @@ -311,23 +349,18 @@ class SslStreamImpl { auto *bio = BIO_new(BIO_s_sslstream()); BIO_set_data(bio, static_cast(this)); - SSL_set_bio(ssl_handle, bio, bio); + SSL_set_bio(ssl_handle.get(), bio, bio); #if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT) if (r_ip_address.is_error()) { // IP address must not be send as SNI LOG(DEBUG) << "Set SNI host name to " << host; auto host_str = host.str(); - SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin()); + SSL_set_tlsext_host_name(ssl_handle.get(), MutableCSlice(host_str).begin()); } #endif - SSL_set_connect_state(ssl_handle); + SSL_set_connect_state(ssl_handle.get()); - ssl_ctx_guard.dismiss(); - ssl_handle_guard.dismiss(); - - ssl_handle_ = ssl_handle; - ssl_ctx_ = ssl_ctx; - bio_ = bio; + ssl_handle_ = std::move(ssl_handle); return Status::OK(); } @@ -346,18 +379,14 @@ class SslStreamImpl { } private: - static constexpr int VERIFY_DEPTH = 10; - - SSL *ssl_handle_ = nullptr; - SSL_CTX *ssl_ctx_ = nullptr; - BIO *bio_ = nullptr; + SslHandle ssl_handle_; friend class SslReadByteFlow; friend class SslWriteByteFlow; Result write(Slice slice) { clear_openssl_errors("Before SslFd::write"); - auto size = SSL_write(ssl_handle_, slice.data(), static_cast(slice.size())); + auto size = SSL_write(ssl_handle_.get(), slice.data(), static_cast(slice.size())); if (size <= 0) { return process_ssl_error(size); } @@ -366,7 +395,7 @@ class SslStreamImpl { Result read(MutableSlice slice) { clear_openssl_errors("Before SslFd::read"); - auto size = SSL_read(ssl_handle_, slice.data(), static_cast(slice.size())); + auto size = SSL_read(ssl_handle_.get(), slice.data(), static_cast(slice.size())); if (size <= 0) { return process_ssl_error(size); } @@ -444,7 +473,7 @@ class SslStreamImpl { Result process_ssl_error(int ret) { auto os_error = OS_ERROR("SSL_ERROR_SYSCALL"); - int error = SSL_get_error(ssl_handle_, ret); + int error = SSL_get_error(ssl_handle_.get(), ret); switch (error) { case SSL_ERROR_NONE: LOG(ERROR) << "SSL_get_error returned no error";