diff --git a/td/mtproto/Handshake.cpp b/td/mtproto/Handshake.cpp index 9b66daa59..5ec9eb652 100644 --- a/td/mtproto/Handshake.cpp +++ b/td/mtproto/Handshake.cpp @@ -79,13 +79,13 @@ Status AuthKeyHandshake::on_res_pq(Slice message, Callback *connection, PublicRs server_nonce_ = res_pq->server_nonce_; - auto r_rsa = public_rsa_key->get_rsa(res_pq->server_public_key_fingerprints_); - if (r_rsa.is_error()) { + auto r_rsa_key = public_rsa_key->get_rsa_key(res_pq->server_public_key_fingerprints_); + if (r_rsa_key.is_error()) { public_rsa_key->drop_keys(); - return r_rsa.move_as_error(); + return r_rsa_key.move_as_error(); } - int64 rsa_fingerprint = r_rsa.ok().second; - RSA rsa = std::move(r_rsa.ok_ref().first); + int64 rsa_fingerprint = r_rsa_key.ok().fingerprint; + RSA rsa = std::move(r_rsa_key.ok_ref().rsa); string p, q; if (pq_factorize(res_pq->pq_, &p, &q) == -1) { diff --git a/td/mtproto/RSA.h b/td/mtproto/RSA.h index 99105cbcb..33c905226 100644 --- a/td/mtproto/RSA.h +++ b/td/mtproto/RSA.h @@ -11,8 +11,6 @@ #include "td/utils/Slice.h" #include "td/utils/Status.h" -#include - namespace td { namespace mtproto { @@ -36,7 +34,13 @@ class RSA { class PublicRsaKeyInterface { public: virtual ~PublicRsaKeyInterface() = default; - virtual Result> get_rsa(const vector &fingerprints) = 0; + + struct RsaKey { + RSA rsa; + int64 fingerprint; + }; + virtual Result get_rsa_key(const vector &fingerprints) = 0; + virtual void drop_keys() = 0; }; diff --git a/td/telegram/net/PublicRsaKeyShared.cpp b/td/telegram/net/PublicRsaKeyShared.cpp index 821f53443..e81c2a615 100644 --- a/td/telegram/net/PublicRsaKeyShared.cpp +++ b/td/telegram/net/PublicRsaKeyShared.cpp @@ -26,7 +26,7 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id) LOG_CHECK(r_rsa.is_ok()) << r_rsa.error() << " " << pem; if (r_rsa.is_ok()) { - this->add_rsa(r_rsa.move_as_ok()); + add_rsa(r_rsa.move_as_ok()); } }; @@ -106,19 +106,19 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id) void PublicRsaKeyShared::add_rsa(mtproto::RSA rsa) { auto lock = rw_mutex_.lock_write(); auto fingerprint = rsa.get_fingerprint(); - auto *has_rsa = get_rsa_locked(fingerprint); + auto *has_rsa = get_rsa_unsafe(fingerprint); if (has_rsa) { return; } options_.push_back(RsaOption{fingerprint, std::move(rsa)}); } -Result> PublicRsaKeyShared::get_rsa(const vector &fingerprints) { +Result PublicRsaKeyShared::get_rsa_key(const vector &fingerprints) { auto lock = rw_mutex_.lock_read(); for (auto fingerprint : fingerprints) { - auto *rsa = get_rsa_locked(fingerprint); + auto *rsa = get_rsa_unsafe(fingerprint); if (rsa) { - return std::make_pair(rsa->clone(), fingerprint); + return RsaKey{rsa->clone(), fingerprint}; } } return Status::Error(PSLICE() << "Unknown fingerprints " << format::as_array(fingerprints)); @@ -144,7 +144,7 @@ void PublicRsaKeyShared::add_listener(unique_ptr listener) { } } -mtproto::RSA *PublicRsaKeyShared::get_rsa_locked(int64 fingerprint) { +mtproto::RSA *PublicRsaKeyShared::get_rsa_unsafe(int64 fingerprint) { auto it = std::find_if(options_.begin(), options_.end(), [&](const auto &value) { return value.fingerprint == fingerprint; }); if (it == options_.end()) { diff --git a/td/telegram/net/PublicRsaKeyShared.h b/td/telegram/net/PublicRsaKeyShared.h index 6de9b14a3..7208874d0 100644 --- a/td/telegram/net/PublicRsaKeyShared.h +++ b/td/telegram/net/PublicRsaKeyShared.h @@ -14,8 +14,6 @@ #include "td/utils/port/RwMutex.h" #include "td/utils/Status.h" -#include - namespace td { class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface { @@ -34,7 +32,7 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface { }; void add_rsa(mtproto::RSA rsa); - Result> get_rsa(const vector &fingerprints) final; + Result get_rsa_key(const vector &fingerprints) final; void drop_keys() final; bool has_keys(); @@ -54,7 +52,7 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface { std::vector> listeners_; RwMutex rw_mutex_; - mtproto::RSA *get_rsa_locked(int64 fingerprint); + mtproto::RSA *get_rsa_unsafe(int64 fingerprint); void notify(); };