Add struct RsaKey.

This commit is contained in:
levlam 2021-07-05 22:09:45 +03:00
parent c60693cc7e
commit 65e96c561c
4 changed files with 20 additions and 18 deletions

View File

@ -79,13 +79,13 @@ Status AuthKeyHandshake::on_res_pq(Slice message, Callback *connection, PublicRs
server_nonce_ = res_pq->server_nonce_;
auto r_rsa = public_rsa_key->get_rsa(res_pq->server_public_key_fingerprints_);
if (r_rsa.is_error()) {
auto r_rsa_key = public_rsa_key->get_rsa_key(res_pq->server_public_key_fingerprints_);
if (r_rsa_key.is_error()) {
public_rsa_key->drop_keys();
return r_rsa.move_as_error();
return r_rsa_key.move_as_error();
}
int64 rsa_fingerprint = r_rsa.ok().second;
RSA rsa = std::move(r_rsa.ok_ref().first);
int64 rsa_fingerprint = r_rsa_key.ok().fingerprint;
RSA rsa = std::move(r_rsa_key.ok_ref().rsa);
string p, q;
if (pq_factorize(res_pq->pq_, &p, &q) == -1) {

View File

@ -11,8 +11,6 @@
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
#include <utility>
namespace td {
namespace mtproto {
@ -36,7 +34,13 @@ class RSA {
class PublicRsaKeyInterface {
public:
virtual ~PublicRsaKeyInterface() = default;
virtual Result<std::pair<RSA, int64>> get_rsa(const vector<int64> &fingerprints) = 0;
struct RsaKey {
RSA rsa;
int64 fingerprint;
};
virtual Result<RsaKey> get_rsa_key(const vector<int64> &fingerprints) = 0;
virtual void drop_keys() = 0;
};

View File

@ -26,7 +26,7 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id)
LOG_CHECK(r_rsa.is_ok()) << r_rsa.error() << " " << pem;
if (r_rsa.is_ok()) {
this->add_rsa(r_rsa.move_as_ok());
add_rsa(r_rsa.move_as_ok());
}
};
@ -106,19 +106,19 @@ PublicRsaKeyShared::PublicRsaKeyShared(DcId dc_id, bool is_test) : dc_id_(dc_id)
void PublicRsaKeyShared::add_rsa(mtproto::RSA rsa) {
auto lock = rw_mutex_.lock_write();
auto fingerprint = rsa.get_fingerprint();
auto *has_rsa = get_rsa_locked(fingerprint);
auto *has_rsa = get_rsa_unsafe(fingerprint);
if (has_rsa) {
return;
}
options_.push_back(RsaOption{fingerprint, std::move(rsa)});
}
Result<std::pair<mtproto::RSA, int64>> PublicRsaKeyShared::get_rsa(const vector<int64> &fingerprints) {
Result<mtproto::PublicRsaKeyInterface::RsaKey> PublicRsaKeyShared::get_rsa_key(const vector<int64> &fingerprints) {
auto lock = rw_mutex_.lock_read();
for (auto fingerprint : fingerprints) {
auto *rsa = get_rsa_locked(fingerprint);
auto *rsa = get_rsa_unsafe(fingerprint);
if (rsa) {
return std::make_pair(rsa->clone(), fingerprint);
return RsaKey{rsa->clone(), fingerprint};
}
}
return Status::Error(PSLICE() << "Unknown fingerprints " << format::as_array(fingerprints));
@ -144,7 +144,7 @@ void PublicRsaKeyShared::add_listener(unique_ptr<Listener> listener) {
}
}
mtproto::RSA *PublicRsaKeyShared::get_rsa_locked(int64 fingerprint) {
mtproto::RSA *PublicRsaKeyShared::get_rsa_unsafe(int64 fingerprint) {
auto it = std::find_if(options_.begin(), options_.end(),
[&](const auto &value) { return value.fingerprint == fingerprint; });
if (it == options_.end()) {

View File

@ -14,8 +14,6 @@
#include "td/utils/port/RwMutex.h"
#include "td/utils/Status.h"
#include <utility>
namespace td {
class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
@ -34,7 +32,7 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
};
void add_rsa(mtproto::RSA rsa);
Result<std::pair<mtproto::RSA, int64>> get_rsa(const vector<int64> &fingerprints) final;
Result<RsaKey> get_rsa_key(const vector<int64> &fingerprints) final;
void drop_keys() final;
bool has_keys();
@ -54,7 +52,7 @@ class PublicRsaKeyShared final : public mtproto::PublicRsaKeyInterface {
std::vector<unique_ptr<Listener>> listeners_;
RwMutex rw_mutex_;
mtproto::RSA *get_rsa_locked(int64 fingerprint);
mtproto::RSA *get_rsa_unsafe(int64 fingerprint);
void notify();
};