Add cache of secure values.

GitOrigin-RevId: 54fcf89a89f28086705e12869e9dc777c2a86233
This commit is contained in:
levlam 2018-08-12 15:44:24 +03:00
parent 3728c89f53
commit b07fc66b69
3 changed files with 78 additions and 51 deletions

View File

@ -212,7 +212,7 @@ class DialogDbAsync : public DialogDbAsyncInterface {
do_flush(); do_flush();
sync_db_safe_.reset(); sync_db_safe_.reset();
sync_db_ = nullptr; sync_db_ = nullptr;
promise.set_result(Unit()); promise.set_value(Unit());
stop(); stop();
} }

View File

@ -22,17 +22,15 @@
#include "td/utils/optional.h" #include "td/utils/optional.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include <mutex>
namespace td { namespace td {
class GetSecureValue : public NetQueryCallback { class GetSecureValue : public NetQueryCallback {
public: public:
GetSecureValue(ActorShared<> parent, std::string password, SecureValueType type, GetSecureValue(ActorShared<SecureManager> parent, std::string password, SecureValueType type,
Promise<SecureValueWithCredentials> promise); Promise<SecureValueWithCredentials> promise);
private: private:
ActorShared<> parent_; ActorShared<SecureManager> parent_;
string password_; string password_;
SecureValueType type_; SecureValueType type_;
Promise<SecureValueWithCredentials> promise_; Promise<SecureValueWithCredentials> promise_;
@ -49,10 +47,10 @@ class GetSecureValue : public NetQueryCallback {
class GetAllSecureValues : public NetQueryCallback { class GetAllSecureValues : public NetQueryCallback {
public: public:
GetAllSecureValues(ActorShared<> parent, std::string password, Promise<TdApiSecureValues> promise); GetAllSecureValues(ActorShared<SecureManager> parent, std::string password, Promise<TdApiSecureValues> promise);
private: private:
ActorShared<> parent_; ActorShared<SecureManager> parent_;
string password_; string password_;
Promise<TdApiSecureValues> promise_; Promise<TdApiSecureValues> promise_;
optional<vector<EncryptedSecureValue>> encrypted_secure_values_; optional<vector<EncryptedSecureValue>> encrypted_secure_values_;
@ -68,11 +66,11 @@ class GetAllSecureValues : public NetQueryCallback {
class SetSecureValue : public NetQueryCallback { class SetSecureValue : public NetQueryCallback {
public: public:
SetSecureValue(ActorShared<> parent, string password, SecureValue secure_value, SetSecureValue(ActorShared<SecureManager> parent, string password, SecureValue secure_value,
Promise<SecureValueWithCredentials> promise); Promise<SecureValueWithCredentials> promise);
private: private:
ActorShared<> parent_; ActorShared<SecureManager> parent_;
string password_; string password_;
SecureValue secure_value_; SecureValue secure_value_;
Promise<SecureValueWithCredentials> promise_; Promise<SecureValueWithCredentials> promise_;
@ -153,7 +151,7 @@ class SetSecureValueErrorsQuery : public Td::ResultHandler {
} }
}; };
GetSecureValue::GetSecureValue(ActorShared<> parent, std::string password, SecureValueType type, GetSecureValue::GetSecureValue(ActorShared<SecureManager> parent, std::string password, SecureValueType type,
Promise<SecureValueWithCredentials> promise) Promise<SecureValueWithCredentials> promise)
: parent_(std::move(parent)), password_(std::move(password)), type_(type), promise_(std::move(promise)) { : parent_(std::move(parent)), password_(std::move(password)), type_(type), promise_(std::move(promise)) {
} }
@ -188,7 +186,10 @@ void GetSecureValue::loop() {
if (r_secure_value.is_error()) { if (r_secure_value.is_error()) {
return on_error(r_secure_value.move_as_error()); return on_error(r_secure_value.move_as_error());
} }
promise_.set_result(r_secure_value.move_as_ok());
send_closure(parent_, &SecureManager::on_get_secure_value, r_secure_value.ok());
promise_.set_value(r_secure_value.move_as_ok());
stop(); stop();
} }
@ -226,7 +227,8 @@ void GetSecureValue::on_result(NetQueryPtr query) {
loop(); loop();
} }
GetAllSecureValues::GetAllSecureValues(ActorShared<> parent, std::string password, Promise<TdApiSecureValues> promise) GetAllSecureValues::GetAllSecureValues(ActorShared<SecureManager> parent, std::string password,
Promise<TdApiSecureValues> promise)
: parent_(std::move(parent)), password_(std::move(password)), promise_(std::move(promise)) { : parent_(std::move(parent)), password_(std::move(password)), promise_(std::move(promise)) {
} }
@ -260,9 +262,14 @@ void GetAllSecureValues::loop() {
if (r_secure_values.is_error()) { if (r_secure_values.is_error()) {
return on_error(r_secure_values.move_as_error()); return on_error(r_secure_values.move_as_error());
} }
for (auto &secure_value : r_secure_values.ok()) {
send_closure(parent_, &SecureManager::on_get_secure_value, secure_value);
}
auto secure_values = transform(r_secure_values.move_as_ok(), auto secure_values = transform(r_secure_values.move_as_ok(),
[](SecureValueWithCredentials &&value) { return std::move(value.value); }); [](SecureValueWithCredentials &&value) { return std::move(value.value); });
promise_.set_result(get_passport_elements_object(file_manager, std::move(secure_values))); promise_.set_value(get_passport_elements_object(file_manager, std::move(secure_values)));
stop(); stop();
} }
@ -287,7 +294,7 @@ void GetAllSecureValues::on_result(NetQueryPtr query) {
loop(); loop();
} }
SetSecureValue::SetSecureValue(ActorShared<> parent, string password, SecureValue secure_value, SetSecureValue::SetSecureValue(ActorShared<SecureManager> parent, string password, SecureValue secure_value,
Promise<SecureValueWithCredentials> promise) Promise<SecureValueWithCredentials> promise)
: parent_(std::move(parent)) : parent_(std::move(parent))
, password_(std::move(password)) , password_(std::move(password))
@ -561,7 +568,10 @@ void SetSecureValue::on_result(NetQueryPtr query) {
if (r_secure_value.is_error()) { if (r_secure_value.is_error()) {
return on_error(r_secure_value.move_as_error()); return on_error(r_secure_value.move_as_error());
} }
promise_.set_result(r_secure_value.move_as_ok());
send_closure(parent_, &SecureManager::on_get_secure_value, r_secure_value.ok());
promise_.set_value(r_secure_value.move_as_ok());
stop(); stop();
} }
@ -579,12 +589,12 @@ void SetSecureValue::merge(FileManager *file_manager, FileId file_id, EncryptedS
class DeleteSecureValue : public NetQueryCallback { class DeleteSecureValue : public NetQueryCallback {
public: public:
DeleteSecureValue(ActorShared<> parent, SecureValueType type, Promise<Unit> promise) DeleteSecureValue(ActorShared<SecureManager> parent, SecureValueType type, Promise<Unit> promise)
: parent_(std::move(parent)), type_(std::move(type)), promise_(std::move(promise)) { : parent_(std::move(parent)), type_(std::move(type)), promise_(std::move(promise)) {
} }
private: private:
ActorShared<> parent_; ActorShared<SecureManager> parent_;
SecureValueType type_; SecureValueType type_;
Promise<Unit> promise_; Promise<Unit> promise_;
@ -609,8 +619,9 @@ class DeleteSecureValue : public NetQueryCallback {
class GetPassportAuthorizationForm : public NetQueryCallback { class GetPassportAuthorizationForm : public NetQueryCallback {
public: public:
GetPassportAuthorizationForm(ActorShared<> parent, string password, int32 authorization_form_id, UserId bot_user_id, GetPassportAuthorizationForm(ActorShared<SecureManager> parent, string password, int32 authorization_form_id,
string scope, string public_key, Promise<TdApiAuthorizationForm> promise) UserId bot_user_id, string scope, string public_key,
Promise<TdApiAuthorizationForm> promise)
: parent_(std::move(parent)) : parent_(std::move(parent))
, password_(std::move(password)) , password_(std::move(password))
, authorization_form_id_(authorization_form_id) , authorization_form_id_(authorization_form_id)
@ -621,7 +632,7 @@ class GetPassportAuthorizationForm : public NetQueryCallback {
} }
private: private:
ActorShared<> parent_; ActorShared<SecureManager> parent_;
string password_; string password_;
int32 authorization_form_id_; int32 authorization_form_id_;
UserId bot_user_id_; UserId bot_user_id_;
@ -844,7 +855,11 @@ void SecureManager::get_secure_value(std::string password, SecureValueType type,
if (r_secure_value.is_error()) { if (r_secure_value.is_error()) {
return promise.set_error(r_secure_value.move_as_error()); return promise.set_error(r_secure_value.move_as_error());
} }
auto *file_manager = G()->td().get_actor_unsafe()->file_manager_.get(); auto *file_manager = G()->td().get_actor_unsafe()->file_manager_.get();
if (file_manager == nullptr) {
return promise.set_value(nullptr);
}
auto r_passport_element = get_passport_element_object(file_manager, r_secure_value.move_as_ok().value); auto r_passport_element = get_passport_element_object(file_manager, r_secure_value.move_as_ok().value);
if (r_passport_element.is_error()) { if (r_passport_element.is_error()) {
LOG(ERROR) << "Failed to get passport element object: " << r_passport_element.error(); LOG(ERROR) << "Failed to get passport element object: " << r_passport_element.error();
@ -852,19 +867,29 @@ void SecureManager::get_secure_value(std::string password, SecureValueType type,
} }
promise.set_value(r_passport_element.move_as_ok()); promise.set_value(r_passport_element.move_as_ok());
}); });
do_get_secure_value(std::move(password), type, std::move(new_promise)); do_get_secure_value(std::move(password), type, false, std::move(new_promise));
} }
void SecureManager::do_get_secure_value(std::string password, SecureValueType type, void SecureManager::do_get_secure_value(std::string password, SecureValueType type, bool allow_from_cache,
Promise<SecureValueWithCredentials> promise) { Promise<SecureValueWithCredentials> promise) {
if (allow_from_cache && secure_value_cache_.count(type)) {
// TODO check password?
return promise.set_value(SecureValueWithCredentials(secure_value_cache_[type]));
}
refcnt_++; refcnt_++;
create_actor<GetSecureValue>("GetSecureValue", actor_shared(), std::move(password), type, std::move(promise)) create_actor<GetSecureValue>("GetSecureValue", actor_shared(this), std::move(password), type, std::move(promise))
.release(); .release();
} }
void SecureManager::on_get_secure_value(SecureValueWithCredentials value) {
auto type = value.value.type;
secure_value_cache_[type] = std::move(value);
}
void SecureManager::get_all_secure_values(std::string password, Promise<TdApiSecureValues> promise) { void SecureManager::get_all_secure_values(std::string password, Promise<TdApiSecureValues> promise) {
refcnt_++; refcnt_++;
create_actor<GetAllSecureValues>("GetAllSecureValues", actor_shared(), std::move(password), std::move(promise)) create_actor<GetAllSecureValues>("GetAllSecureValues", actor_shared(this), std::move(password), std::move(promise))
.release(); .release();
} }
@ -884,8 +909,8 @@ void SecureManager::set_secure_value(string password, SecureValue secure_value,
} }
promise.set_value(r_passport_element.move_as_ok()); promise.set_value(r_passport_element.move_as_ok());
}); });
set_secure_value_queries_[type] = create_actor<SetSecureValue>("SetSecureValue", actor_shared(), std::move(password), set_secure_value_queries_[type] = create_actor<SetSecureValue>(
std::move(secure_value), std::move(new_promise)); "SetSecureValue", actor_shared(this), std::move(password), std::move(secure_value), std::move(new_promise));
} }
void SecureManager::delete_secure_value(SecureValueType type, Promise<Unit> promise) { void SecureManager::delete_secure_value(SecureValueType type, Promise<Unit> promise) {
@ -894,7 +919,7 @@ void SecureManager::delete_secure_value(SecureValueType type, Promise<Unit> prom
[actor_id = actor_id(this), type, promise = std::move(promise)](Result<Unit> result) mutable { [actor_id = actor_id(this), type, promise = std::move(promise)](Result<Unit> result) mutable {
send_closure(actor_id, &SecureManager::on_delete_secure_value, type, std::move(promise), std::move(result)); send_closure(actor_id, &SecureManager::on_delete_secure_value, type, std::move(promise), std::move(result));
}); });
create_actor<DeleteSecureValue>("DeleteSecureValue", actor_shared(), type, std::move(new_promise)).release(); create_actor<DeleteSecureValue>("DeleteSecureValue", actor_shared(this), type, std::move(new_promise)).release();
} }
void SecureManager::on_delete_secure_value(SecureValueType type, Promise<Unit> promise, Result<Unit> result) { void SecureManager::on_delete_secure_value(SecureValueType type, Promise<Unit> promise, Result<Unit> result) {
@ -902,6 +927,7 @@ void SecureManager::on_delete_secure_value(SecureValueType type, Promise<Unit> p
return promise.set_error(result.move_as_error()); return promise.set_error(result.move_as_error());
} }
secure_value_cache_.erase(type);
promise.set_value(Unit()); promise.set_value(Unit());
} }
@ -1005,7 +1031,7 @@ void SecureManager::get_passport_authorization_form(string password, UserId bot_
string public_key, string payload, string public_key, string payload,
Promise<TdApiAuthorizationForm> promise) { Promise<TdApiAuthorizationForm> promise) {
refcnt_++; refcnt_++;
auto authorization_form_id = ++authorization_form_id_; auto authorization_form_id = ++max_authorization_form_id_;
authorization_forms_[authorization_form_id] = AuthorizationForm{bot_user_id, scope, public_key, payload, false}; authorization_forms_[authorization_form_id] = AuthorizationForm{bot_user_id, scope, public_key, payload, false};
auto new_promise = auto new_promise =
PromiseCreator::lambda([actor_id = actor_id(this), authorization_form_id, promise = std::move(promise)]( PromiseCreator::lambda([actor_id = actor_id(this), authorization_form_id, promise = std::move(promise)](
@ -1013,7 +1039,7 @@ void SecureManager::get_passport_authorization_form(string password, UserId bot_
send_closure(actor_id, &SecureManager::on_get_passport_authorization_form, authorization_form_id, send_closure(actor_id, &SecureManager::on_get_passport_authorization_form, authorization_form_id,
std::move(promise), std::move(r_authorization_form)); std::move(promise), std::move(r_authorization_form));
}); });
create_actor<GetPassportAuthorizationForm>("GetPassportAuthorizationForm", actor_shared(), std::move(password), create_actor<GetPassportAuthorizationForm>("GetPassportAuthorizationForm", actor_shared(this), std::move(password),
authorization_form_id, bot_user_id, std::move(scope), authorization_form_id, bot_user_id, std::move(scope),
std::move(public_key), std::move(new_promise)) std::move(public_key), std::move(new_promise))
.release(); .release();
@ -1050,31 +1076,28 @@ void SecureManager::send_passport_authorization_form(string password, int32 auth
} }
struct JoinPromise { struct JoinPromise {
std::mutex mutex_;
Promise<std::vector<SecureValueCredentials>> promise_; Promise<std::vector<SecureValueCredentials>> promise_;
std::vector<SecureValueCredentials> credentials_; std::vector<SecureValueCredentials> credentials_;
int wait_cnt_{0}; int wait_cnt_{0};
}; };
auto join = std::make_shared<JoinPromise>(); auto join = std::make_shared<JoinPromise>();
std::lock_guard<std::mutex> guard(join->mutex_);
for (auto type : types) { for (auto type : types) {
join->wait_cnt_++; join->wait_cnt_++;
do_get_secure_value(password, type, send_closure_later(actor_id(this), &SecureManager::do_get_secure_value, password, type, true,
PromiseCreator::lambda([join](Result<SecureValueWithCredentials> r_secure_value) { PromiseCreator::lambda([join](Result<SecureValueWithCredentials> r_secure_value) {
std::lock_guard<std::mutex> guard(join->mutex_); if (!join->promise_) {
if (!join->promise_) { return;
return; }
} if (r_secure_value.is_error()) {
if (r_secure_value.is_error()) { return join->promise_.set_error(r_secure_value.move_as_error());
return join->promise_.set_error(r_secure_value.move_as_error()); }
} join->credentials_.push_back(r_secure_value.move_as_ok().credentials);
join->credentials_.push_back(r_secure_value.move_as_ok().credentials); join->wait_cnt_--;
join->wait_cnt_--; if (join->wait_cnt_ == 0) {
if (join->wait_cnt_ == 0) { join->promise_.set_value(std::move(join->credentials_));
join->promise_.set_value(std::move(join->credentials_)); }
} }));
}));
} }
join->promise_ = join->promise_ =
PromiseCreator::lambda([promise = std::move(promise), actor_id = actor_id(this), PromiseCreator::lambda([promise = std::move(promise), actor_id = actor_id(this),

View File

@ -19,8 +19,8 @@
#include "td/utils/Container.h" #include "td/utils/Container.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
#include <map>
#include <memory> #include <memory>
#include <unordered_map>
namespace td { namespace td {
@ -41,6 +41,8 @@ class SecureManager : public NetQueryCallback {
void set_secure_value_errors(Td *td, tl_object_ptr<telegram_api::InputUser> input_user, void set_secure_value_errors(Td *td, tl_object_ptr<telegram_api::InputUser> input_user,
vector<tl_object_ptr<td_api::inputPassportElementError>> errors, Promise<Unit> promise); vector<tl_object_ptr<td_api::inputPassportElementError>> errors, Promise<Unit> promise);
void on_get_secure_value(SecureValueWithCredentials value);
void get_passport_authorization_form(string password, UserId bot_user_id, string scope, string public_key, void get_passport_authorization_form(string password, UserId bot_user_id, string scope, string public_key,
string payload, Promise<TdApiAuthorizationForm> promise); string payload, Promise<TdApiAuthorizationForm> promise);
void send_passport_authorization_form(string password, int32 authorization_form_id, void send_passport_authorization_form(string password, int32 authorization_form_id,
@ -49,7 +51,8 @@ class SecureManager : public NetQueryCallback {
private: private:
ActorShared<> parent_; ActorShared<> parent_;
int32 refcnt_{1}; int32 refcnt_{1};
std::map<SecureValueType, ActorOwn<>> set_secure_value_queries_; std::unordered_map<SecureValueType, ActorOwn<>> set_secure_value_queries_;
std::unordered_map<SecureValueType, SecureValueWithCredentials> secure_value_cache_;
struct AuthorizationForm { struct AuthorizationForm {
UserId bot_user_id; UserId bot_user_id;
@ -59,13 +62,14 @@ class SecureManager : public NetQueryCallback {
bool is_received; bool is_received;
}; };
std::map<int32, AuthorizationForm> authorization_forms_; std::unordered_map<int32, AuthorizationForm> authorization_forms_;
int32 authorization_form_id_{0}; int32 max_authorization_form_id_{0};
void hangup() override; void hangup() override;
void hangup_shared() override; void hangup_shared() override;
void dec_refcnt(); void dec_refcnt();
void do_get_secure_value(std::string password, SecureValueType type, Promise<SecureValueWithCredentials> promise); void do_get_secure_value(std::string password, SecureValueType type, bool allow_from_cache,
Promise<SecureValueWithCredentials> promise);
void on_delete_secure_value(SecureValueType type, Promise<Unit> promise, Result<Unit> result); void on_delete_secure_value(SecureValueType type, Promise<Unit> promise, Result<Unit> result);
void on_get_passport_authorization_form(int32 authorization_form_id, Promise<TdApiAuthorizationForm> promise, void on_get_passport_authorization_form(int32 authorization_form_id, Promise<TdApiAuthorizationForm> promise,
Result<TdApiAuthorizationForm> r_authorization_form); Result<TdApiAuthorizationForm> r_authorization_form);