From b07fc66b6960d13969981a6035fec0ca2ccb0cd2 Mon Sep 17 00:00:00 2001 From: levlam Date: Sun, 12 Aug 2018 15:44:24 +0300 Subject: [PATCH] Add cache of secure values. GitOrigin-RevId: 54fcf89a89f28086705e12869e9dc777c2a86233 --- td/telegram/DialogDb.cpp | 2 +- td/telegram/SecureManager.cpp | 113 ++++++++++++++++++++-------------- td/telegram/SecureManager.h | 14 +++-- 3 files changed, 78 insertions(+), 51 deletions(-) diff --git a/td/telegram/DialogDb.cpp b/td/telegram/DialogDb.cpp index 3071ebd0c..5ca7f6805 100644 --- a/td/telegram/DialogDb.cpp +++ b/td/telegram/DialogDb.cpp @@ -212,7 +212,7 @@ class DialogDbAsync : public DialogDbAsyncInterface { do_flush(); sync_db_safe_.reset(); sync_db_ = nullptr; - promise.set_result(Unit()); + promise.set_value(Unit()); stop(); } diff --git a/td/telegram/SecureManager.cpp b/td/telegram/SecureManager.cpp index b379803fd..b2bba5dfd 100644 --- a/td/telegram/SecureManager.cpp +++ b/td/telegram/SecureManager.cpp @@ -22,17 +22,15 @@ #include "td/utils/optional.h" #include "td/utils/Slice.h" -#include - namespace td { class GetSecureValue : public NetQueryCallback { public: - GetSecureValue(ActorShared<> parent, std::string password, SecureValueType type, + GetSecureValue(ActorShared parent, std::string password, SecureValueType type, Promise promise); private: - ActorShared<> parent_; + ActorShared parent_; string password_; SecureValueType type_; Promise promise_; @@ -49,10 +47,10 @@ class GetSecureValue : public NetQueryCallback { class GetAllSecureValues : public NetQueryCallback { public: - GetAllSecureValues(ActorShared<> parent, std::string password, Promise promise); + GetAllSecureValues(ActorShared parent, std::string password, Promise promise); private: - ActorShared<> parent_; + ActorShared parent_; string password_; Promise promise_; optional> encrypted_secure_values_; @@ -68,11 +66,11 @@ class GetAllSecureValues : public NetQueryCallback { class SetSecureValue : public NetQueryCallback { public: - SetSecureValue(ActorShared<> parent, string password, SecureValue secure_value, + SetSecureValue(ActorShared parent, string password, SecureValue secure_value, Promise promise); private: - ActorShared<> parent_; + ActorShared parent_; string password_; SecureValue secure_value_; Promise promise_; @@ -153,7 +151,7 @@ class SetSecureValueErrorsQuery : public Td::ResultHandler { } }; -GetSecureValue::GetSecureValue(ActorShared<> parent, std::string password, SecureValueType type, +GetSecureValue::GetSecureValue(ActorShared parent, std::string password, SecureValueType type, Promise promise) : parent_(std::move(parent)), password_(std::move(password)), type_(type), promise_(std::move(promise)) { } @@ -188,7 +186,10 @@ void GetSecureValue::loop() { if (r_secure_value.is_error()) { return on_error(r_secure_value.move_as_error()); } - promise_.set_result(r_secure_value.move_as_ok()); + + send_closure(parent_, &SecureManager::on_get_secure_value, r_secure_value.ok()); + + promise_.set_value(r_secure_value.move_as_ok()); stop(); } @@ -226,7 +227,8 @@ void GetSecureValue::on_result(NetQueryPtr query) { loop(); } -GetAllSecureValues::GetAllSecureValues(ActorShared<> parent, std::string password, Promise promise) +GetAllSecureValues::GetAllSecureValues(ActorShared parent, std::string password, + Promise promise) : parent_(std::move(parent)), password_(std::move(password)), promise_(std::move(promise)) { } @@ -260,9 +262,14 @@ void GetAllSecureValues::loop() { if (r_secure_values.is_error()) { return on_error(r_secure_values.move_as_error()); } + + for (auto &secure_value : r_secure_values.ok()) { + send_closure(parent_, &SecureManager::on_get_secure_value, secure_value); + } + auto secure_values = transform(r_secure_values.move_as_ok(), [](SecureValueWithCredentials &&value) { return std::move(value.value); }); - promise_.set_result(get_passport_elements_object(file_manager, std::move(secure_values))); + promise_.set_value(get_passport_elements_object(file_manager, std::move(secure_values))); stop(); } @@ -287,7 +294,7 @@ void GetAllSecureValues::on_result(NetQueryPtr query) { loop(); } -SetSecureValue::SetSecureValue(ActorShared<> parent, string password, SecureValue secure_value, +SetSecureValue::SetSecureValue(ActorShared parent, string password, SecureValue secure_value, Promise promise) : parent_(std::move(parent)) , password_(std::move(password)) @@ -561,7 +568,10 @@ void SetSecureValue::on_result(NetQueryPtr query) { if (r_secure_value.is_error()) { return on_error(r_secure_value.move_as_error()); } - promise_.set_result(r_secure_value.move_as_ok()); + + send_closure(parent_, &SecureManager::on_get_secure_value, r_secure_value.ok()); + + promise_.set_value(r_secure_value.move_as_ok()); stop(); } @@ -579,12 +589,12 @@ void SetSecureValue::merge(FileManager *file_manager, FileId file_id, EncryptedS class DeleteSecureValue : public NetQueryCallback { public: - DeleteSecureValue(ActorShared<> parent, SecureValueType type, Promise promise) + DeleteSecureValue(ActorShared parent, SecureValueType type, Promise promise) : parent_(std::move(parent)), type_(std::move(type)), promise_(std::move(promise)) { } private: - ActorShared<> parent_; + ActorShared parent_; SecureValueType type_; Promise promise_; @@ -609,8 +619,9 @@ class DeleteSecureValue : public NetQueryCallback { class GetPassportAuthorizationForm : public NetQueryCallback { public: - GetPassportAuthorizationForm(ActorShared<> parent, string password, int32 authorization_form_id, UserId bot_user_id, - string scope, string public_key, Promise promise) + GetPassportAuthorizationForm(ActorShared parent, string password, int32 authorization_form_id, + UserId bot_user_id, string scope, string public_key, + Promise promise) : parent_(std::move(parent)) , password_(std::move(password)) , authorization_form_id_(authorization_form_id) @@ -621,7 +632,7 @@ class GetPassportAuthorizationForm : public NetQueryCallback { } private: - ActorShared<> parent_; + ActorShared parent_; string password_; int32 authorization_form_id_; UserId bot_user_id_; @@ -844,7 +855,11 @@ void SecureManager::get_secure_value(std::string password, SecureValueType type, if (r_secure_value.is_error()) { return promise.set_error(r_secure_value.move_as_error()); } + auto *file_manager = G()->td().get_actor_unsafe()->file_manager_.get(); + if (file_manager == nullptr) { + return promise.set_value(nullptr); + } auto r_passport_element = get_passport_element_object(file_manager, r_secure_value.move_as_ok().value); if (r_passport_element.is_error()) { LOG(ERROR) << "Failed to get passport element object: " << r_passport_element.error(); @@ -852,19 +867,29 @@ void SecureManager::get_secure_value(std::string password, SecureValueType type, } promise.set_value(r_passport_element.move_as_ok()); }); - do_get_secure_value(std::move(password), type, std::move(new_promise)); + do_get_secure_value(std::move(password), type, false, std::move(new_promise)); } -void SecureManager::do_get_secure_value(std::string password, SecureValueType type, +void SecureManager::do_get_secure_value(std::string password, SecureValueType type, bool allow_from_cache, Promise promise) { + if (allow_from_cache && secure_value_cache_.count(type)) { + // TODO check password? + return promise.set_value(SecureValueWithCredentials(secure_value_cache_[type])); + } + refcnt_++; - create_actor("GetSecureValue", actor_shared(), std::move(password), type, std::move(promise)) + create_actor("GetSecureValue", actor_shared(this), std::move(password), type, std::move(promise)) .release(); } +void SecureManager::on_get_secure_value(SecureValueWithCredentials value) { + auto type = value.value.type; + secure_value_cache_[type] = std::move(value); +} + void SecureManager::get_all_secure_values(std::string password, Promise promise) { refcnt_++; - create_actor("GetAllSecureValues", actor_shared(), std::move(password), std::move(promise)) + create_actor("GetAllSecureValues", actor_shared(this), std::move(password), std::move(promise)) .release(); } @@ -884,8 +909,8 @@ void SecureManager::set_secure_value(string password, SecureValue secure_value, } promise.set_value(r_passport_element.move_as_ok()); }); - set_secure_value_queries_[type] = create_actor("SetSecureValue", actor_shared(), std::move(password), - std::move(secure_value), std::move(new_promise)); + set_secure_value_queries_[type] = create_actor( + "SetSecureValue", actor_shared(this), std::move(password), std::move(secure_value), std::move(new_promise)); } void SecureManager::delete_secure_value(SecureValueType type, Promise promise) { @@ -894,7 +919,7 @@ void SecureManager::delete_secure_value(SecureValueType type, Promise prom [actor_id = actor_id(this), type, promise = std::move(promise)](Result result) mutable { send_closure(actor_id, &SecureManager::on_delete_secure_value, type, std::move(promise), std::move(result)); }); - create_actor("DeleteSecureValue", actor_shared(), type, std::move(new_promise)).release(); + create_actor("DeleteSecureValue", actor_shared(this), type, std::move(new_promise)).release(); } void SecureManager::on_delete_secure_value(SecureValueType type, Promise promise, Result result) { @@ -902,6 +927,7 @@ void SecureManager::on_delete_secure_value(SecureValueType type, Promise p return promise.set_error(result.move_as_error()); } + secure_value_cache_.erase(type); promise.set_value(Unit()); } @@ -1005,7 +1031,7 @@ void SecureManager::get_passport_authorization_form(string password, UserId bot_ string public_key, string payload, Promise promise) { refcnt_++; - auto authorization_form_id = ++authorization_form_id_; + auto authorization_form_id = ++max_authorization_form_id_; authorization_forms_[authorization_form_id] = AuthorizationForm{bot_user_id, scope, public_key, payload, false}; auto new_promise = PromiseCreator::lambda([actor_id = actor_id(this), authorization_form_id, promise = std::move(promise)]( @@ -1013,7 +1039,7 @@ void SecureManager::get_passport_authorization_form(string password, UserId bot_ send_closure(actor_id, &SecureManager::on_get_passport_authorization_form, authorization_form_id, std::move(promise), std::move(r_authorization_form)); }); - create_actor("GetPassportAuthorizationForm", actor_shared(), std::move(password), + create_actor("GetPassportAuthorizationForm", actor_shared(this), std::move(password), authorization_form_id, bot_user_id, std::move(scope), std::move(public_key), std::move(new_promise)) .release(); @@ -1050,31 +1076,28 @@ void SecureManager::send_passport_authorization_form(string password, int32 auth } struct JoinPromise { - std::mutex mutex_; Promise> promise_; std::vector credentials_; int wait_cnt_{0}; }; auto join = std::make_shared(); - std::lock_guard guard(join->mutex_); for (auto type : types) { join->wait_cnt_++; - do_get_secure_value(password, type, - PromiseCreator::lambda([join](Result r_secure_value) { - std::lock_guard guard(join->mutex_); - if (!join->promise_) { - return; - } - if (r_secure_value.is_error()) { - return join->promise_.set_error(r_secure_value.move_as_error()); - } - join->credentials_.push_back(r_secure_value.move_as_ok().credentials); - join->wait_cnt_--; - if (join->wait_cnt_ == 0) { - join->promise_.set_value(std::move(join->credentials_)); - } - })); + send_closure_later(actor_id(this), &SecureManager::do_get_secure_value, password, type, true, + PromiseCreator::lambda([join](Result r_secure_value) { + if (!join->promise_) { + return; + } + if (r_secure_value.is_error()) { + return join->promise_.set_error(r_secure_value.move_as_error()); + } + join->credentials_.push_back(r_secure_value.move_as_ok().credentials); + join->wait_cnt_--; + if (join->wait_cnt_ == 0) { + join->promise_.set_value(std::move(join->credentials_)); + } + })); } join->promise_ = PromiseCreator::lambda([promise = std::move(promise), actor_id = actor_id(this), diff --git a/td/telegram/SecureManager.h b/td/telegram/SecureManager.h index 7502e4629..018761f4f 100644 --- a/td/telegram/SecureManager.h +++ b/td/telegram/SecureManager.h @@ -19,8 +19,8 @@ #include "td/utils/Container.h" #include "td/utils/Status.h" -#include #include +#include namespace td { @@ -41,6 +41,8 @@ class SecureManager : public NetQueryCallback { void set_secure_value_errors(Td *td, tl_object_ptr input_user, vector> errors, Promise promise); + void on_get_secure_value(SecureValueWithCredentials value); + void get_passport_authorization_form(string password, UserId bot_user_id, string scope, string public_key, string payload, Promise promise); void send_passport_authorization_form(string password, int32 authorization_form_id, @@ -49,7 +51,8 @@ class SecureManager : public NetQueryCallback { private: ActorShared<> parent_; int32 refcnt_{1}; - std::map> set_secure_value_queries_; + std::unordered_map> set_secure_value_queries_; + std::unordered_map secure_value_cache_; struct AuthorizationForm { UserId bot_user_id; @@ -59,13 +62,14 @@ class SecureManager : public NetQueryCallback { bool is_received; }; - std::map authorization_forms_; - int32 authorization_form_id_{0}; + std::unordered_map authorization_forms_; + int32 max_authorization_form_id_{0}; void hangup() override; void hangup_shared() override; void dec_refcnt(); - void do_get_secure_value(std::string password, SecureValueType type, Promise promise); + void do_get_secure_value(std::string password, SecureValueType type, bool allow_from_cache, + Promise promise); void on_delete_secure_value(SecureValueType type, Promise promise, Result result); void on_get_passport_authorization_form(int32 authorization_form_id, Promise promise, Result r_authorization_form);