Store RsaKey in PublicRsaKeyShared.

This commit is contained in:
levlam 2021-07-05 22:19:59 +03:00
parent 65e96c561c
commit 5176710ada
2 changed files with 14 additions and 19 deletions

View File

@ -106,19 +106,18 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id)
void PublicRsaKeyShared::add_rsa(mtproto::RSA rsa) { void PublicRsaKeyShared::add_rsa(mtproto::RSA rsa) {
auto lock = rw_mutex_.lock_write(); auto lock = rw_mutex_.lock_write();
auto fingerprint = rsa.get_fingerprint(); auto fingerprint = rsa.get_fingerprint();
auto *has_rsa = get_rsa_unsafe(fingerprint); if (get_rsa_key_unsafe(fingerprint) != nullptr) {
if (has_rsa) {
return; return;
} }
options_.push_back(RsaOption{fingerprint, std::move(rsa)}); keys_.push_back(RsaKey{std::move(rsa), fingerprint});
} }
Result<mtproto::PublicRsaKeyInterface::RsaKey> PublicRsaKeyShared::get_rsa_key(const vector<int64> &fingerprints) { Result<mtproto::PublicRsaKeyInterface::RsaKey> PublicRsaKeyShared::get_rsa_key(const vector<int64> &fingerprints) {
auto lock = rw_mutex_.lock_read(); auto lock = rw_mutex_.lock_read();
for (auto fingerprint : fingerprints) { for (auto fingerprint : fingerprints) {
auto *rsa = get_rsa_unsafe(fingerprint); auto *rsa_key = get_rsa_key_unsafe(fingerprint);
if (rsa) { if (rsa_key != nullptr) {
return RsaKey{rsa->clone(), fingerprint}; return RsaKey{rsa_key->rsa.clone(), fingerprint};
} }
} }
return Status::Error(PSLICE() << "Unknown fingerprints " << format::as_array(fingerprints)); return Status::Error(PSLICE() << "Unknown fingerprints " << format::as_array(fingerprints));
@ -129,12 +128,12 @@ void PublicRsaKeyShared::drop_keys() {
return; return;
} }
auto lock = rw_mutex_.lock_write(); auto lock = rw_mutex_.lock_write();
options_.clear(); keys_.clear();
} }
bool PublicRsaKeyShared::has_keys() { bool PublicRsaKeyShared::has_keys() {
auto lock = rw_mutex_.lock_read(); auto lock = rw_mutex_.lock_read();
return !options_.empty(); return !keys_.empty();
} }
void PublicRsaKeyShared::add_listener(unique_ptr<Listener> listener) { void PublicRsaKeyShared::add_listener(unique_ptr<Listener> listener) {
@ -144,13 +143,13 @@ void PublicRsaKeyShared::add_listener(unique_ptr<Listener> listener) {
} }
} }
mtproto::RSA *PublicRsaKeyShared::get_rsa_unsafe(int64 fingerprint) { mtproto::PublicRsaKeyInterface::RsaKey *PublicRsaKeyShared::get_rsa_key_unsafe(int64 fingerprint) {
auto it = std::find_if(options_.begin(), options_.end(), auto it = std::find_if(keys_.begin(), keys_.end(),
[&](const auto &value) { return value.fingerprint == fingerprint; }); [fingerprint](const auto &value) { return value.fingerprint == fingerprint; });
if (it == options_.end()) { if (it == keys_.end()) {
return nullptr; return nullptr;
} }
return &it->rsa; return &*it;
} }
void PublicRsaKeyShared::notify() { void PublicRsaKeyShared::notify() {

View File

@ -44,15 +44,11 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
private: private:
DcId dc_id_; DcId dc_id_;
struct RsaOption { std::vector<RsaKey> keys_;
int64 fingerprint;
mtproto::RSA rsa;
};
std::vector<RsaOption> options_;
std::vector<unique_ptr<Listener>> listeners_; std::vector<unique_ptr<Listener>> listeners_;
RwMutex rw_mutex_; RwMutex rw_mutex_;
mtproto::RSA *get_rsa_unsafe(int64 fingerprint); RsaKey *get_rsa_key_unsafe(int64 fingerprint);
void notify(); void notify();
}; };