Add struct RsaKey.

This commit is contained in:
levlam 2021-07-05 22:09:45 +03:00
parent c60693cc7e
commit 65e96c561c
4 changed files with 20 additions and 18 deletions

View File

@ -79,13 +79,13 @@ Status AuthKeyHandshake::on_res_pq(Slice message, Callback *connection, PublicRs
server_nonce_ = res_pq->server_nonce_; server_nonce_ = res_pq->server_nonce_;
auto r_rsa = public_rsa_key->get_rsa(res_pq->server_public_key_fingerprints_); auto r_rsa_key = public_rsa_key->get_rsa_key(res_pq->server_public_key_fingerprints_);
if (r_rsa.is_error()) { if (r_rsa_key.is_error()) {
public_rsa_key->drop_keys(); public_rsa_key->drop_keys();
return r_rsa.move_as_error(); return r_rsa_key.move_as_error();
} }
int64 rsa_fingerprint = r_rsa.ok().second; int64 rsa_fingerprint = r_rsa_key.ok().fingerprint;
RSA rsa = std::move(r_rsa.ok_ref().first); RSA rsa = std::move(r_rsa_key.ok_ref().rsa);
string p, q; string p, q;
if (pq_factorize(res_pq->pq_, &p, &q) == -1) { if (pq_factorize(res_pq->pq_, &p, &q) == -1) {

View File

@ -11,8 +11,6 @@
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
#include <utility>
namespace td { namespace td {
namespace mtproto { namespace mtproto {
@ -36,7 +34,13 @@ class RSA {
class PublicRsaKeyInterface { class PublicRsaKeyInterface {
public: public:
virtual ~PublicRsaKeyInterface() = default; virtual ~PublicRsaKeyInterface() = default;
virtual Result<std::pair<RSA, int64>> get_rsa(const vector<int64> &fingerprints) = 0;
struct RsaKey {
RSA rsa;
int64 fingerprint;
};
virtual Result<RsaKey> get_rsa_key(const vector<int64> &fingerprints) = 0;
virtual void drop_keys() = 0; virtual void drop_keys() = 0;
}; };

View File

@ -26,7 +26,7 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id)
LOG_CHECK(r_rsa.is_ok()) << r_rsa.error() << " " << pem; LOG_CHECK(r_rsa.is_ok()) << r_rsa.error() << " " << pem;
if (r_rsa.is_ok()) { if (r_rsa.is_ok()) {
this->add_rsa(r_rsa.move_as_ok()); add_rsa(r_rsa.move_as_ok());
} }
}; };
@ -106,19 +106,19 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id)
void PublicRsaKeyShared::add_rsa(mtproto::RSA rsa) { void PublicRsaKeyShared::add_rsa(mtproto::RSA rsa) {
auto lock = rw_mutex_.lock_write(); auto lock = rw_mutex_.lock_write();
auto fingerprint = rsa.get_fingerprint(); auto fingerprint = rsa.get_fingerprint();
auto *has_rsa = get_rsa_locked(fingerprint); auto *has_rsa = get_rsa_unsafe(fingerprint);
if (has_rsa) { if (has_rsa) {
return; return;
} }
options_.push_back(RsaOption{fingerprint, std::move(rsa)}); options_.push_back(RsaOption{fingerprint, std::move(rsa)});
} }
Result<std::pair<mtproto::RSA, int64>> PublicRsaKeyShared::get_rsa(const vector<int64> &fingerprints) { Result<mtproto::PublicRsaKeyInterface::RsaKey> PublicRsaKeyShared::get_rsa_key(const vector<int64> &fingerprints) {
auto lock = rw_mutex_.lock_read(); auto lock = rw_mutex_.lock_read();
for (auto fingerprint : fingerprints) { for (auto fingerprint : fingerprints) {
auto *rsa = get_rsa_locked(fingerprint); auto *rsa = get_rsa_unsafe(fingerprint);
if (rsa) { if (rsa) {
return std::make_pair(rsa->clone(), fingerprint); return RsaKey{rsa->clone(), fingerprint};
} }
} }
return Status::Error(PSLICE() << "Unknown fingerprints " << format::as_array(fingerprints)); return Status::Error(PSLICE() << "Unknown fingerprints " << format::as_array(fingerprints));
@ -144,7 +144,7 @@ void PublicRsaKeyShared::add_listener(unique_ptr<Listener> listener) {
} }
} }
mtproto::RSA *PublicRsaKeyShared::get_rsa_locked(int64 fingerprint) { mtproto::RSA *PublicRsaKeyShared::get_rsa_unsafe(int64 fingerprint) {
auto it = std::find_if(options_.begin(), options_.end(), auto it = std::find_if(options_.begin(), options_.end(),
[&](const auto &value) { return value.fingerprint == fingerprint; }); [&](const auto &value) { return value.fingerprint == fingerprint; });
if (it == options_.end()) { if (it == options_.end()) {

View File

@ -14,8 +14,6 @@
#include "td/utils/port/RwMutex.h" #include "td/utils/port/RwMutex.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
#include <utility>
namespace td { namespace td {
class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface { class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
@ -34,7 +32,7 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
}; };
void add_rsa(mtproto::RSA rsa); void add_rsa(mtproto::RSA rsa);
Result<std::pair<mtproto::RSA, int64>> get_rsa(const vector<int64> &fingerprints) final; Result<RsaKey> get_rsa_key(const vector<int64> &fingerprints) final;
void drop_keys() final; void drop_keys() final;
bool has_keys(); bool has_keys();
@ -54,7 +52,7 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
std::vector<unique_ptr<Listener>> listeners_; std::vector<unique_ptr<Listener>> listeners_;
RwMutex rw_mutex_; RwMutex rw_mutex_;
mtproto::RSA *get_rsa_locked(int64 fingerprint); mtproto::RSA *get_rsa_unsafe(int64 fingerprint);
void notify(); void notify();
}; };