Protect G()->td_db() usage in AuthDataShared with guard

GitOrigin-RevId: d9e8a79c9db9b061de81f98fcd66f085d4efcaaa
This commit is contained in:
Arseny Smirnov 2018-03-13 16:40:02 +03:00
parent c29f5e9432
commit 600bbcf3ca
4 changed files with 14 additions and 6 deletions

View File

@ -19,8 +19,8 @@ namespace td {
class AuthDataSharedImpl : public AuthDataShared { class AuthDataSharedImpl : public AuthDataShared {
public: public:
AuthDataSharedImpl(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key) AuthDataSharedImpl(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key, std::shared_ptr<Guard> guard)
: dc_id_(dc_id), public_rsa_key_(std::move(public_rsa_key)) { : dc_id_(dc_id), public_rsa_key_(std::move(public_rsa_key)), guard_(std::move(guard)) {
log_auth_key(get_auth_key()); log_auth_key(get_auth_key());
} }
@ -89,6 +89,7 @@ class AuthDataSharedImpl : public AuthDataShared {
DcId dc_id_; DcId dc_id_;
std::vector<unique_ptr<Listener>> auth_key_listeners_; std::vector<unique_ptr<Listener>> auth_key_listeners_;
std::shared_ptr<PublicRsaKeyShared> public_rsa_key_; std::shared_ptr<PublicRsaKeyShared> public_rsa_key_;
std::shared_ptr<Guard> guard_;
RwMutex rw_mutex_; RwMutex rw_mutex_;
string auth_key_key() { string auth_key_key() {
@ -111,7 +112,8 @@ class AuthDataSharedImpl : public AuthDataShared {
} }
}; };
std::shared_ptr<AuthDataShared> AuthDataShared::create(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key) { std::shared_ptr<AuthDataShared> AuthDataShared::create(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key,
return std::make_shared<AuthDataSharedImpl>(dc_id, std::move(public_rsa_key)); std::shared_ptr<Guard> guard) {
return std::make_shared<AuthDataSharedImpl>(dc_id, std::move(public_rsa_key), std::move(guard));
} }
} // namespace td } // namespace td

View File

@ -13,6 +13,7 @@
#include "td/telegram/net/PublicRsaKeyShared.h" #include "td/telegram/net/PublicRsaKeyShared.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/ScopeGuard.h"
#include "td/utils/StringBuilder.h" #include "td/utils/StringBuilder.h"
#include <memory> #include <memory>
@ -70,7 +71,8 @@ class AuthDataShared {
return state; return state;
} }
static std::shared_ptr<AuthDataShared> create(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key); static std::shared_ptr<AuthDataShared> create(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key,
std::shared_ptr<Guard> guard);
}; };
}; // namespace td }; // namespace td

View File

@ -140,7 +140,7 @@ Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) {
send_closure_later(public_rsa_key_watchdog_, &PublicRsaKeyWatchdog::add_public_rsa_key, public_rsa_key); send_closure_later(public_rsa_key_watchdog_, &PublicRsaKeyWatchdog::add_public_rsa_key, public_rsa_key);
is_cdn = true; is_cdn = true;
} }
auto auth_data = AuthDataShared::create(dc_id, std::move(public_rsa_key)); auto auth_data = AuthDataShared::create(dc_id, std::move(public_rsa_key), td_guard_);
int32 session_count = get_session_count(); int32 session_count = get_session_count();
bool use_pfs = get_use_pfs(); bool use_pfs = get_use_pfs();
@ -184,6 +184,7 @@ void NetQueryDispatcher::dispatch_with_callback(NetQueryPtr net_query, ActorShar
void NetQueryDispatcher::stop() { void NetQueryDispatcher::stop() {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_); std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
td_guard_.reset();
stop_flag_ = true; stop_flag_ = true;
delayer_.hangup(); delayer_.hangup();
for (const auto &dc : dcs_) { for (const auto &dc : dcs_) {
@ -246,6 +247,8 @@ NetQueryDispatcher::NetQueryDispatcher(std::function<ActorShared<>()> create_ref
dc_auth_manager_ = create_actor<DcAuthManager>("DcAuthManager", create_reference()); dc_auth_manager_ = create_actor<DcAuthManager>("DcAuthManager", create_reference());
common_public_rsa_key_ = std::make_shared<PublicRsaKeyShared>(DcId::empty()); common_public_rsa_key_ = std::make_shared<PublicRsaKeyShared>(DcId::empty());
public_rsa_key_watchdog_ = create_actor<PublicRsaKeyWatchdog>("PublicRsaKeyWatchdog", create_reference()); public_rsa_key_watchdog_ = create_actor<PublicRsaKeyWatchdog>("PublicRsaKeyWatchdog", create_reference());
td_guard_ = create_shared_lamda_guard([actor = create_reference] {});
} }
NetQueryDispatcher::NetQueryDispatcher() = default; NetQueryDispatcher::NetQueryDispatcher() = default;

View File

@ -74,6 +74,7 @@ class NetQueryDispatcher {
std::shared_ptr<PublicRsaKeyShared> common_public_rsa_key_; std::shared_ptr<PublicRsaKeyShared> common_public_rsa_key_;
ActorOwn<PublicRsaKeyWatchdog> public_rsa_key_watchdog_; ActorOwn<PublicRsaKeyWatchdog> public_rsa_key_watchdog_;
std::mutex main_dc_id_mutex_; std::mutex main_dc_id_mutex_;
std::shared_ptr<Guard> td_guard_;
Status wait_dc_init(DcId dc_id, bool force); Status wait_dc_init(DcId dc_id, bool force);
bool is_dc_inited(int32 raw_dc_id); bool is_dc_inited(int32 raw_dc_id);