Use persistend AuthKey in ConfigRecoverer

GitOrigin-RevId: 6c4ce6dc49d226de585c7c43d41471271c9fcca1
This commit is contained in:
Arseny Smirnov 2018-05-18 14:38:49 +03:00
parent 0c1d797753
commit 45a20f6929
2 changed files with 50 additions and 28 deletions

View File

@ -177,7 +177,7 @@ ActorOwn<> get_simple_config_google_dns(Promise<SimpleConfig> promise, bool is_t
#endif #endif
} }
ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) { ActorOwn<> get_full_config(DcId dc_id, IPAddress ip_address, Promise<FullConfig> promise) {
class SessionCallback : public Session::Callback { class SessionCallback : public Session::Callback {
public: public:
SessionCallback(ActorShared<> parent, IPAddress address) SessionCallback(ActorShared<> parent, IPAddress address)
@ -190,7 +190,7 @@ ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) {
void request_raw_connection(Promise<std::unique_ptr<mtproto::RawConnection>> promise) final { void request_raw_connection(Promise<std::unique_ptr<mtproto::RawConnection>> promise) final {
request_raw_connection_cnt_++; request_raw_connection_cnt_++;
VLOG(config_recoverer) << "Request full config from " << address_ << ", try = " << request_raw_connection_cnt_; VLOG(config_recoverer) << "Request full config from " << address_ << ", try = " << request_raw_connection_cnt_;
if (request_raw_connection_cnt_ <= 1) { if (request_raw_connection_cnt_ <= 2) {
send_closure(G()->connection_creator(), &ConnectionCreator::request_raw_connection_by_ip, address_, send_closure(G()->connection_creator(), &ConnectionCreator::request_raw_connection_by_ip, address_,
std::move(promise)); std::move(promise));
} else { } else {
@ -211,14 +211,22 @@ ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) {
class SimpleAuthData : public AuthDataShared { class SimpleAuthData : public AuthDataShared {
public: public:
explicit SimpleAuthData(DcId dc_id) : dc_id_(dc_id) {
}
DcId dc_id() const override { DcId dc_id() const override {
return DcId::empty(); return dc_id_;
} }
const std::shared_ptr<PublicRsaKeyShared> &public_rsa_key() override { const std::shared_ptr<PublicRsaKeyShared> &public_rsa_key() override {
return public_rsa_key_; return public_rsa_key_;
} }
mtproto::AuthKey get_auth_key() override { mtproto::AuthKey get_auth_key() override {
return auth_key_; string dc_key = G()->td_db()->get_binlog_pmc()->get(auth_key_key());
mtproto::AuthKey res;
if (!dc_key.empty()) {
unserialize(res, dc_key).ensure();
}
return res;
} }
std::pair<AuthState, bool> get_auth_state() override { std::pair<AuthState, bool> get_auth_state() override {
auto auth_key = get_auth_key(); auto auth_key = get_auth_key();
@ -226,16 +234,16 @@ ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) {
return std::make_pair(state, auth_key.was_auth_flag()); return std::make_pair(state, auth_key.was_auth_flag());
} }
void set_auth_key(const mtproto::AuthKey &auth_key) override { void set_auth_key(const mtproto::AuthKey &auth_key) override {
auth_key_ = auth_key; G()->td_db()->get_binlog_pmc()->set(auth_key_key(), serialize(auth_key));
//notify();
} }
void update_server_time_difference(double diff) override { void update_server_time_difference(double diff) override {
if (!has_server_time_difference_ || server_time_difference_ < diff) { G()->update_server_time_difference(diff);
has_server_time_difference_ = true;
server_time_difference_ = diff;
}
} }
double get_server_time_difference() override { double get_server_time_difference() override {
return server_time_difference_; return G()->get_server_time_difference();
//return server_time_difference_;
} }
void add_auth_key_listener(unique_ptr<Listener> listener) override { void add_auth_key_listener(unique_ptr<Listener> listener) override {
if (listener->notify()) { if (listener->notify()) {
@ -244,41 +252,52 @@ ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) {
} }
void set_future_salts(const std::vector<mtproto::ServerSalt> &future_salts) override { void set_future_salts(const std::vector<mtproto::ServerSalt> &future_salts) override {
future_salts_ = future_salts; G()->td_db()->get_binlog_pmc()->set(future_salts_key(), serialize(future_salts));
} }
std::vector<mtproto::ServerSalt> get_future_salts() override { std::vector<mtproto::ServerSalt> get_future_salts() override {
return future_salts_; string future_salts = G()->td_db()->get_binlog_pmc()->get(future_salts_key());
std::vector<mtproto::ServerSalt> res;
if (!future_salts.empty()) {
unserialize(res, future_salts).ensure();
}
return res;
} }
private: private:
DcId dc_id_;
std::shared_ptr<PublicRsaKeyShared> public_rsa_key_ = std::make_shared<PublicRsaKeyShared>(DcId::empty()); std::shared_ptr<PublicRsaKeyShared> public_rsa_key_ = std::make_shared<PublicRsaKeyShared>(DcId::empty());
mtproto::AuthKey auth_key_;
bool has_server_time_difference_ = false; bool has_server_time_difference_ = false;
double server_time_difference_ = 0; double server_time_difference_ = 0;
std::vector<mtproto::ServerSalt> future_salts_;
std::vector<std::unique_ptr<Listener>> auth_key_listeners_; std::vector<std::unique_ptr<Listener>> auth_key_listeners_;
void notify() { void notify() {
auto it = std::remove_if(auth_key_listeners_.begin(), auth_key_listeners_.end(), auto it = std::remove_if(auth_key_listeners_.begin(), auth_key_listeners_.end(),
[&](auto &listener) { return !listener->notify(); }); [&](auto &listener) { return !listener->notify(); });
auth_key_listeners_.erase(it, auth_key_listeners_.end()); auth_key_listeners_.erase(it, auth_key_listeners_.end());
} }
string auth_key_key() {
return PSTRING() << "config_recovery_auth" << dc_id().get_raw_id();
}
string future_salts_key() {
return PSTRING() << "config_recovery_salt" << dc_id().get_raw_id();
}
}; };
class GetConfigActor : public NetQueryCallback { class GetConfigActor : public NetQueryCallback {
public: public:
GetConfigActor(IPAddress ip_address, Promise<FullConfig> promise) GetConfigActor(DcId dc_id, IPAddress ip_address, Promise<FullConfig> promise)
: ip_address_(std::move(ip_address)), promise_(std::move(promise)) { : dc_id_(dc_id), ip_address_(std::move(ip_address)), promise_(std::move(promise)) {
} }
private: private:
void start_up() override { void start_up() override {
auto session_callback = std::make_unique<SessionCallback>(actor_shared(this, 1), std::move(ip_address_)); auto session_callback = std::make_unique<SessionCallback>(actor_shared(this, 1), std::move(ip_address_));
auto auth_data = std::make_shared<SimpleAuthData>(); auto auth_data = std::make_shared<SimpleAuthData>(dc_id_);
session_ = create_actor<Session>("ConfigSession", std::move(session_callback), std::move(auth_data), session_ = create_actor<Session>("ConfigSession", std::move(session_callback), std::move(auth_data),
false /*is_main*/, false /*use_pfs*/, true /*is_cdn*/, mtproto::AuthKey()); false /*is_main*/, true /*use_pfs*/, false /*is_cdn*/, mtproto::AuthKey());
auto query = G()->net_query_creator().create(create_storer(telegram_api::help_getConfig()), DcId::empty(), auto query = G()->net_query_creator().create(create_storer(telegram_api::help_getConfig()), DcId::empty(),
NetQuery::Type::Common, NetQuery::AuthFlag::Off, NetQuery::Type::Common, NetQuery::AuthFlag::Off,
NetQuery::GzipFlag::On, 60 * 60 * 24); NetQuery::GzipFlag::On, 60 * 60 * 24);
@ -293,7 +312,9 @@ ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) {
} }
void hangup_shared() override { void hangup_shared() override {
if (get_link_token() == 1) { if (get_link_token() == 1) {
promise_.set_error(Status::Error("Failed")); if (promise_) {
promise_.set_error(Status::Error("Failed"));
}
stop(); stop();
} }
} }
@ -302,15 +323,16 @@ ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise) {
} }
void timeout_expired() override { void timeout_expired() override {
promise_.set_error(Status::Error("Timeout expired")); promise_.set_error(Status::Error("Timeout expired"));
stop(); session_.reset();
} }
DcId dc_id_;
IPAddress ip_address_; IPAddress ip_address_;
ActorOwn<Session> session_; ActorOwn<Session> session_;
Promise<FullConfig> promise_; Promise<FullConfig> promise_;
}; };
return ActorOwn<>(create_actor<GetConfigActor>("GetConfigActor", std::move(ip_address), std::move(promise))); return ActorOwn<>(create_actor<GetConfigActor>("GetConfigActor", dc_id, std::move(ip_address), std::move(promise)));
} }
class ConfigRecoverer : public Actor { class ConfigRecoverer : public Actor {
@ -508,11 +530,11 @@ class ConfigRecoverer : public Actor {
if (need_full_config) { if (need_full_config) {
ref_cnt_++; ref_cnt_++;
VLOG(config_recoverer) << "ASK FULL CONFIG"; VLOG(config_recoverer) << "ASK FULL CONFIG";
full_config_query_ = full_config_query_ = get_full_config(
get_full_config(dc_options_.dc_options[dc_options_i_].get_ip_address(), dc_options_.dc_options[dc_options_i_].get_dc_id(), dc_options_.dc_options[dc_options_i_].get_ip_address(),
PromiseCreator::lambda([actor_id = actor_shared(this)](Result<FullConfig> r_full_config) { PromiseCreator::lambda([actor_id = actor_shared(this)](Result<FullConfig> r_full_config) {
send_closure(actor_id, &ConfigRecoverer::on_full_config, std::move(r_full_config), false); send_closure(actor_id, &ConfigRecoverer::on_full_config, std::move(r_full_config), false);
})); }));
dc_options_i_ = (dc_options_i_ + 1) % dc_options_.dc_options.size(); dc_options_i_ = (dc_options_i_ + 1) % dc_options_.dc_options.size();
} }

View File

@ -32,7 +32,7 @@ ActorOwn<> get_simple_config_google_dns(Promise<SimpleConfig> promise, bool is_t
using FullConfig = tl_object_ptr<telegram_api::config>; using FullConfig = tl_object_ptr<telegram_api::config>;
ActorOwn<> get_full_config(IPAddress ip_address, Promise<FullConfig> promise); ActorOwn<> get_full_config(DcId dc_id, IPAddress ip_address, Promise<FullConfig> promise);
class ConfigRecoverer; class ConfigRecoverer;
class ConfigManager : public NetQueryCallback { class ConfigManager : public NetQueryCallback {