diff --git a/td/generate/scheme/mtproto_api.tl b/td/generate/scheme/mtproto_api.tl index cee9bfc97..fdfe78d8f 100644 --- a/td/generate/scheme/mtproto_api.tl +++ b/td/generate/scheme/mtproto_api.tl @@ -66,6 +66,10 @@ msg_new_detailed_info#809db6df answer_msg_id:long bytes:int status:int = MsgDeta rsa_public_key n:string e:string = RSAPublicKey; +destroy_auth_key_ok#f660e1d4 = DestroyAuthKeyRes; +destroy_auth_key_none#0a9f2259 = DestroyAuthKeyRes; +destroy_auth_key_fail#ea109b13 = DestroyAuthKeyRes; + ---functions--- req_pq_multi#be7e8ef1 nonce:int128 = ResPQ; @@ -82,6 +86,8 @@ destroy_session#e7512126 session_id:long = DestroySessionRes; http_wait#9299359f max_delay:int wait_after:int max_wait:int = HttpWait; +destroy_auth_key#d1435160 = DestroyAuthKeyRes; + //test.useGzipPacked = GzipPacked; //test.useServerDhInnerData = Server_DH_inner_data; //test.useNewSessionCreated = NewSession; diff --git a/td/generate/scheme/mtproto_api.tlo b/td/generate/scheme/mtproto_api.tlo index fff565817..2e00490e5 100644 Binary files a/td/generate/scheme/mtproto_api.tlo and b/td/generate/scheme/mtproto_api.tlo differ diff --git a/td/mtproto/AuthData.h b/td/mtproto/AuthData.h index 6e6af1fb6..53370d965 100644 --- a/td/mtproto/AuthData.h +++ b/td/mtproto/AuthData.h @@ -136,7 +136,7 @@ class AuthData { void set_auth_flag(bool auth_flag) { main_auth_key_.set_auth_flag(auth_flag); if (!auth_flag) { - tmp_auth_key_.set_auth_flag(auth_flag); + drop_tmp_auth_key(); } } diff --git a/td/mtproto/AuthKey.h b/td/mtproto/AuthKey.h index d71afc88e..e3809eb61 100644 --- a/td/mtproto/AuthKey.h +++ b/td/mtproto/AuthKey.h @@ -33,11 +33,7 @@ class AuthKey { return was_auth_flag_; } void set_auth_flag(bool new_auth_flag) { - if (new_auth_flag == false) { - clear(); - } else { - was_auth_flag_ = true; - } + was_auth_flag_ |= new_auth_flag; auth_flag_ = new_auth_flag; } diff --git a/td/mtproto/CryptoStorer.h b/td/mtproto/CryptoStorer.h index 580103631..e06e3ab58 100644 --- a/td/mtproto/CryptoStorer.h +++ b/td/mtproto/CryptoStorer.h @@ -17,6 +17,12 @@ #include "td/utils/Time.h" namespace td { +namespace mtproto_api { +class msg_container { + public: + static const int32 ID = 0x73f1f8dc; +}; +} // namespace mtproto_api namespace mtproto { template @@ -65,6 +71,7 @@ using GetFutureSaltsImpl = ObjectImpl>; using CancelImpl = ObjectImpl>; using GetInfoImpl = ObjectImpl>; +using DestroyAuthKeyImpl = ObjectImpl>; class CancelVectorImpl { public: @@ -182,8 +189,8 @@ class CryptoImpl { public: CryptoImpl(const vector &to_send, Slice header, vector &&to_ack, int64 ping_id, int ping_timeout, int max_delay, int max_after, int max_wait, int future_salt_n, vector get_info, - vector resend, vector cancel, AuthData *auth_data, uint64 *container_id, uint64 *get_info_id, - uint64 *resend_id, uint64 *ping_message_id, uint64 *parent_message_id) + vector resend, vector cancel, bool destroy_key, AuthData *auth_data, uint64 *container_id, + uint64 *get_info_id, uint64 *resend_id, uint64 *ping_message_id, uint64 *parent_message_id) : query_storer_(to_send, header) , ack_empty_(to_ack.empty()) , ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data) @@ -197,16 +204,18 @@ class CryptoImpl { , cancel_not_empty_(!cancel.empty()) , cancel_cnt_(static_cast(cancel.size())) , cancel_storer_(cancel_not_empty_, std::move(cancel), auth_data, true) + , destroy_key_storer_(destroy_key, mtproto_api::destroy_auth_key(), auth_data, true) , tmp_storer_(query_storer_, ack_storer_) , tmp2_storer_(tmp_storer_, http_wait_storer_) , tmp3_storer_(tmp2_storer_, get_future_salts_storer_) , tmp4_storer_(tmp3_storer_, get_info_storer_) , tmp5_storer_(tmp4_storer_, resend_storer_) , tmp6_storer_(tmp5_storer_, cancel_storer_) - , concat_storer_(tmp6_storer_, ping_storer_) + , tmp7_storer_(tmp6_storer_, destroy_key_storer_) + , concat_storer_(tmp7_storer_, ping_storer_) , cnt_(static_cast(to_send.size()) + ack_storer_.not_empty() + ping_storer_.not_empty() + http_wait_storer_.not_empty() + get_future_salts_storer_.not_empty() + get_info_storer_.not_empty() + - resend_storer_.not_empty() + cancel_cnt_) + resend_storer_.not_empty() + cancel_cnt_ + destroy_key_storer_.not_empty()) , container_storer_(cnt_, concat_storer_) { CHECK(cnt_ != 0); if (get_info_storer_.not_empty() && get_info_id) { @@ -252,6 +261,9 @@ class CryptoImpl { } else if (cancel_storer_.not_empty()) { type_ = OnlyCancel; *parent_message_id = cancel_storer_.get_message_id(); + } else if (destroy_key_storer_.not_empty()) { + type_ = OnlyDestroyKey; + *parent_message_id = destroy_key_storer_.get_message_id(); } else { UNREACHABLE(); } @@ -284,6 +296,9 @@ class CryptoImpl { case OnlyGetInfo: return storer.store_storer(get_info_storer_); + case OnlyDestroyKey: + return storer.store_storer(destroy_key_storer_); + default: storer.store_binary(message_id_); storer.store_binary(seq_no_); @@ -306,12 +321,14 @@ class CryptoImpl { bool cancel_not_empty_; int32 cancel_cnt_; PacketStorer cancel_storer_; + PacketStorer destroy_key_storer_; ConcatStorer tmp_storer_; ConcatStorer tmp2_storer_; ConcatStorer tmp3_storer_; ConcatStorer tmp4_storer_; ConcatStorer tmp5_storer_; ConcatStorer tmp6_storer_; + ConcatStorer tmp7_storer_; ConcatStorer concat_storer_; int32 cnt_; PacketStorer container_storer_; @@ -324,6 +341,7 @@ class CryptoImpl { OnlyResend, OnlyCancel, OnlyGetInfo, + OnlyDestroyKey, Mixed }; Type type_; diff --git a/td/mtproto/SessionConnection.cpp b/td/mtproto/SessionConnection.cpp index 89acc64d2..8aebb1045 100644 --- a/td/mtproto/SessionConnection.cpp +++ b/td/mtproto/SessionConnection.cpp @@ -270,14 +270,29 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) { LOG(ERROR) << "Unsupported: " << to_string(packet); return Status::OK(); } +Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) { + return on_destroy_auth_key(destroy_auth_key); +} +Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) { + return on_destroy_auth_key(destroy_auth_key); +} +Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) { + return on_destroy_auth_key(destroy_auth_key); +} + +Status SessionConnection::on_destroy_auth_key(const mtproto_api::DestroyAuthKeyRes &destroy_auth_key) { + CHECK(need_destroy_auth_key_); + LOG(INFO) << to_string(destroy_auth_key); + return callback_->on_destroy_auth_key(); +} Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::rpc_error &rpc_error) { return on_packet(info, 0, rpc_error); } Status SessionConnection::on_packet(const MsgInfo &info, uint64 req_msg_id, const mtproto_api::rpc_error &rpc_error) { - VLOG(mtproto) << "ERROR [code:" << rpc_error.error_code_ << "] [msg:" << rpc_error.error_message_.str().c_str() - << "]"; + VLOG(mtproto) << "ERROR [code:" << rpc_error.error_code_ << "] [msg:" << rpc_error.error_message_.str().c_str() << "]" + << " " << tag("req_msg_id", req_msg_id); if (req_msg_id != 0) { callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, as_buffer_slice(rpc_error.error_message_)); } else { @@ -524,6 +539,8 @@ Status SessionConnection::on_main_packet(const PacketInfo &info, Slice packet) { void SessionConnection::on_message_failed(uint64 id, Status status) { callback_->on_message_failed(id, std::move(status)); + sent_destroy_auth_key_ = false; + if (id == last_ping_message_id_ || id == last_ping_container_id_) { // restart ping immediately last_ping_at_ = 0; @@ -613,6 +630,10 @@ bool SessionConnection::must_flush_packet() { relax_timeout_at(&flush_packet_at_, get_future_salts_at); } + if (has_salt && need_destroy_auth_key_ && !sent_destroy_auth_key_) { + return true; + } + return false; } @@ -741,6 +762,11 @@ void SessionConnection::cancel_answer(int64 message_id) { to_cancel_answer_.push_back(message_id); } +void SessionConnection::destroy_key() { + LOG(INFO) << "need_destroy_key = true"; + need_destroy_auth_key_ = true; +} + std::pair SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expire_at) { int64 temp_key = auth_data_->get_tmp_auth_key().id(); @@ -839,17 +865,21 @@ void SessionConnection::flush_packet() { to_send_.erase(to_send_.begin(), to_send_.begin() + send_till); } + bool destroy_auth_key = need_destroy_auth_key_ && !sent_destroy_auth_key_; + if (queries.empty() && to_ack_.empty() && ping_id == 0 && max_delay < 0 && future_salt_n == 0 && - to_resend_answer_.empty() && to_cancel_answer_.empty() && to_get_state_info_.empty()) { + to_resend_answer_.empty() && to_cancel_answer_.empty() && to_get_state_info_.empty() && !destroy_auth_key) { force_send_at_ = 0; return; } + sent_destroy_auth_key_ |= destroy_auth_key; + VLOG(mtproto) << "Sent packet: " << tag("query_count", queries.size()) << tag("ack_cnt", to_ack_.size()) << tag("ping", ping_id != 0) << tag("http_wait", max_delay >= 0) << tag("future_salt", future_salt_n > 0) << tag("get_info", to_get_state_info_.size()) << tag("resend", to_resend_answer_.size()) << tag("cancel", to_cancel_answer_.size()) - << tag("auth_id", auth_data_->get_auth_key().id()); + << tag("destroy_key", destroy_auth_key) << tag("auth_id", auth_data_->get_auth_key().id()); auto cut_tail = [](auto &v, size_t size, Slice name) { if (size >= v.size()) { @@ -878,8 +908,8 @@ void SessionConnection::flush_packet() { uint64 parent_message_id = 0; auto storer = PacketStorer( queries, auth_data_->get_header(), std::move(to_ack), ping_id, ping_disconnect_delay() + 2, max_delay, - max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, auth_data_, - &container_id, &get_state_info_id, &resend_answer_id, &ping_message_id, &parent_message_id); + max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, destroy_auth_key, + auth_data_, &container_id, &get_state_info_id, &resend_answer_id, &ping_message_id, &parent_message_id); auto quick_ack_token = use_quick_ack ? parent_message_id : 0; send_crypto(storer, quick_ack_token); diff --git a/td/mtproto/SessionConnection.h b/td/mtproto/SessionConnection.h index b5c965b00..4f5e196a8 100644 --- a/td/mtproto/SessionConnection.h +++ b/td/mtproto/SessionConnection.h @@ -25,11 +25,6 @@ namespace td { namespace mtproto_api { -class msg_container { - public: - static const int32 ID = 0x73f1f8dc; -}; - class rpc_error; class new_session_created; class bad_msg_notification; @@ -42,6 +37,10 @@ class msgs_state_info; class msgs_all_info; class msg_detailed_info; class msg_new_detailed_info; +class DestroyAuthKeyRes; +class destroy_auth_key_ok; +class destroy_auth_key_fail; +class destroy_auth_key_none; } // namespace mtproto_api namespace mtproto { @@ -78,6 +77,7 @@ class SessionConnection void get_state_info(int64 message_id); void resend_answer(int64 message_id); void cancel_answer(int64 message_id); + void destroy_key(); void set_online(bool online_flag); @@ -109,6 +109,8 @@ class SessionConnection virtual void on_message_result_error(uint64 id, int code, BufferSlice descr) = 0; virtual void on_message_failed(uint64 id, Status status) = 0; virtual void on_message_info(uint64 id, int32 state, uint64 answer_id, int32 answer_size) = 0; + + virtual Status on_destroy_auth_key() = 0; }; double flush(SessionConnection::Callback *callback); @@ -168,6 +170,9 @@ class SessionConnection uint64 last_ping_message_id_ = 0; uint64 last_ping_container_id_ = 0; + bool need_destroy_auth_key_{false}; + bool sent_destroy_auth_key_{false}; + double wakeup_at_ = 0; double flush_packet_at_ = 0; @@ -222,6 +227,12 @@ class SessionConnection Status on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) TD_WARN_UNUSED_RESULT; Status on_packet(const MsgInfo &info, const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) TD_WARN_UNUSED_RESULT; + Status on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) TD_WARN_UNUSED_RESULT; + Status on_packet(const MsgInfo &info, + const mtproto_api::destroy_auth_key_none &destroy_auth_key) TD_WARN_UNUSED_RESULT; + Status on_packet(const MsgInfo &info, + const mtproto_api::destroy_auth_key_fail &destroy_auth_key) TD_WARN_UNUSED_RESULT; + Status on_destroy_auth_key(const mtproto_api::DestroyAuthKeyRes &destroy_auth_key); Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT; Status on_main_packet(const PacketInfo &info, Slice packet) TD_WARN_UNUSED_RESULT; diff --git a/td/telegram/AuthManager.cpp b/td/telegram/AuthManager.cpp index 5dc833172..7d1024e1c 100644 --- a/td/telegram/AuthManager.cpp +++ b/td/telegram/AuthManager.cpp @@ -431,6 +431,8 @@ AuthManager::AuthManager(int32 api_id, const string &api_hash, ActorShared<> par } } else if (auth_str == "logout") { update_state(State::LoggingOut); + } else if (auth_str == "destroy") { + update_state(State::DestroyingKeys); } else { if (!load_state()) { update_state(State::WaitPhoneNumber); @@ -441,6 +443,8 @@ AuthManager::AuthManager(int32 api_id, const string &api_hash, ActorShared<> par void AuthManager::start_up() { if (state_ == State::LoggingOut) { start_net_query(NetQueryType::LogOut, G()->net_query_creator().create(create_storer(telegram_api::auth_logOut()))); + } else if (state_ == State::DestroyingKeys) { + destroy_auth_keys(); } } void AuthManager::tear_down() { @@ -475,6 +479,7 @@ tl_object_ptr AuthManager::get_authorization_state_o return make_tl_object( wait_password_state_.hint_, wait_password_state_.has_recovery_, wait_password_state_.email_address_pattern_); case State::LoggingOut: + case State::DestroyingKeys: return make_tl_object(); case State::Closing: return make_tl_object(); @@ -655,7 +660,7 @@ void AuthManager::logout(uint64 query_id) { if (state_ == State::Closing) { return on_query_error(query_id, Status::Error(8, "Already logged out")); } - if (state_ == State::LoggingOut) { + if (state_ == State::LoggingOut || state_ == State::DestroyingKeys) { return on_query_error(query_id, Status::Error(8, "Already logging out")); } on_new_query(query_id); @@ -663,7 +668,6 @@ void AuthManager::logout(uint64 query_id) { update_state(State::LoggingOut); // TODO: could skip full logout if still no authorization // TODO: send auth.cancelCode if state_ == State::WaitCode - send_closure_later(G()->td(), &Td::destroy); on_query_ok(); } else { LOG(INFO) << "Logging out"; @@ -844,11 +848,29 @@ void AuthManager::on_log_out_result(NetQueryPtr &result) { } LOG_IF(ERROR, status.is_error()) << "auth.logOut failed: " << status; // state_ will stay logout, so no queries will work. - send_closure_later(G()->td(), &Td::destroy); + destroy_auth_keys(); if (query_id_ != 0) { on_query_ok(); } } +void AuthManager::on_authorization_lost() { + destroy_auth_keys(); +} + +void AuthManager::destroy_auth_keys() { + if (state_ == State::Closing) { + return; + } + update_state(State::DestroyingKeys); + auto promise = PromiseCreator::lambda( + [](Unit) { + G()->net_query_dispatcher().destroy_auth_keys(PromiseCreator::lambda( + [](Unit) { send_closure_later(G()->td(), &Td::destroy); }, PromiseCreator::Ignore())); + }, + PromiseCreator::Ignore()); + G()->td_db()->get_binlog_pmc()->set("auth", "destroy"); + G()->td_db()->get_binlog_pmc()->force_sync(std::move(promise)); +} void AuthManager::on_delete_account_result(NetQueryPtr &result) { Status status; @@ -871,8 +893,7 @@ void AuthManager::on_delete_account_result(NetQueryPtr &result) { on_query_error(std::move(status)); } } else { - update_state(State::LoggingOut); - send_closure_later(G()->td(), &Td::destroy); + destroy_auth_keys(); if (query_id_ != 0) { on_query_ok(); } diff --git a/td/telegram/AuthManager.h b/td/telegram/AuthManager.h index 6d6181858..1e8cccf85 100644 --- a/td/telegram/AuthManager.h +++ b/td/telegram/AuthManager.h @@ -166,6 +166,7 @@ class AuthManager : public NetActor { void logout(uint64 query_id); void delete_account(uint64 query_id, const string &reason); + void on_authorization_lost(); void on_closing(); // can return nullptr if state isn't initialized yet @@ -181,6 +182,7 @@ class AuthManager : public NetActor { WaitPassword, Ok, LoggingOut, + DestroyingKeys, Closing } state_ = State::None; enum class NetQueryType : int32 { @@ -291,6 +293,8 @@ class AuthManager : public NetActor { void on_query_ok(); void start_net_query(NetQueryType net_query_type, NetQueryPtr net_query); + void destroy_auth_keys(); + void on_send_code_result(NetQueryPtr &result); void on_get_password_result(NetQueryPtr &result); void on_request_password_recovery_result(NetQueryPtr &result); diff --git a/td/telegram/ConfigManager.cpp b/td/telegram/ConfigManager.cpp index 83b87f016..2cd6751ed 100644 --- a/td/telegram/ConfigManager.cpp +++ b/td/telegram/ConfigManager.cpp @@ -311,9 +311,10 @@ ActorOwn<> get_full_config(DcId dc_id, IPAddress ip_address, Promise if (G()->is_test_dc()) { int_dc_id += 10000; } - session_ = create_actor("ConfigSession", std::move(session_callback), std::move(auth_data), int_dc_id, - false /*is_main*/, true /*use_pfs*/, false /*is_cdn*/, mtproto::AuthKey(), - std::vector()); + session_ = + create_actor("ConfigSession", std::move(session_callback), std::move(auth_data), int_dc_id, + false /*is_main*/, true /*use_pfs*/, false /*is_cdn*/, false /*need_destroy_auth_key*/, + mtproto::AuthKey(), std::vector()); auto query = G()->net_query_creator().create(create_storer(telegram_api::help_getConfig()), DcId::empty(), NetQuery::Type::Common, NetQuery::AuthFlag::Off, NetQuery::GzipFlag::On, 60 * 60 * 24); diff --git a/td/telegram/net/DcAuthManager.cpp b/td/telegram/net/DcAuthManager.cpp index ed2f8db20..2f4827bcb 100644 --- a/td/telegram/net/DcAuthManager.cpp +++ b/td/telegram/net/DcAuthManager.cpp @@ -192,11 +192,34 @@ void DcAuthManager::dc_loop(DcInfo &dc) { } } +void DcAuthManager::destroy(Promise<> promise) { + destroy_promise_ = std::move(promise); + loop(); +} + +void DcAuthManager::destroy_loop() { + if (!destroy_promise_) { + return; + } + bool is_ready{true}; + for (auto &dc : dcs_) { + is_ready &= dc.auth_state == AuthState::Empty; + } + + if (is_ready) { + LOG(INFO) << "Destroy auth keys loop is ready, all keys are destroyed"; + destroy_promise_.set_value(Unit()); + } else { + LOG(ERROR) << "NOT READY"; + } +} + void DcAuthManager::loop() { if (close_flag_) { VLOG(dc) << "Skip loop because close_flag"; return; } + destroy_loop(); if (!main_dc_id_.is_exact()) { VLOG(dc) << "Skip loop because main_dc_id is unknown"; return; @@ -205,6 +228,7 @@ void DcAuthManager::loop() { if (!main_dc || main_dc->auth_state != AuthState::OK) { if (was_auth_) { G()->shared_config().set_option_boolean("auth", false); + destroy_loop(); } VLOG(dc) << "Skip loop because auth state of main dc " << main_dc_id_.get_raw_id() << " is " << (main_dc != nullptr ? (PSTRING() << main_dc->auth_state) : "unknown"); diff --git a/td/telegram/net/DcAuthManager.h b/td/telegram/net/DcAuthManager.h index 58ff39276..52df29a7e 100644 --- a/td/telegram/net/DcAuthManager.h +++ b/td/telegram/net/DcAuthManager.h @@ -10,7 +10,6 @@ #include "td/telegram/net/AuthDataShared.h" #include "td/telegram/net/DcId.h" #include "td/telegram/net/NetQuery.h" - #include "td/actor/actor.h" #include "td/utils/buffer.h" @@ -26,6 +25,7 @@ class DcAuthManager : public NetQueryCallback { void add_dc(std::shared_ptr auth_data); void update_main_dc(DcId new_main_dc_id); + void destroy(Promise<> promise); private: struct DcInfo { @@ -43,9 +43,10 @@ class DcAuthManager : public NetQueryCallback { ActorShared<> parent_; std::vector dcs_; - bool was_auth_ = false; + bool was_auth_{false}; DcId main_dc_id_; - bool close_flag_ = false; + bool close_flag_{false}; + Promise<> destroy_promise_; DcInfo &get_dc(int32 dc_id); DcInfo *find_dc(int32 dc_id); @@ -55,6 +56,7 @@ class DcAuthManager : public NetQueryCallback { void on_result(NetQueryPtr result) override; void dc_loop(DcInfo &dc); + void destroy_loop(); void loop() override; }; diff --git a/td/telegram/net/NetQueryDispatcher.cpp b/td/telegram/net/NetQueryDispatcher.cpp index a71ea7621..8ce6e50d7 100644 --- a/td/telegram/net/NetQueryDispatcher.cpp +++ b/td/telegram/net/NetQueryDispatcher.cpp @@ -127,12 +127,14 @@ Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) { if (should_init) { std::lock_guard guard(main_dc_id_mutex_); - if (stop_flag_.load(std::memory_order_relaxed)) { + if (stop_flag_.load(std::memory_order_relaxed) || need_destroy_auth_key_) { return Status::Error("Closing"); } // init dc + dc.id_ = dc_id; decltype(common_public_rsa_key_) public_rsa_key; bool is_cdn = false; + bool need_destroy_key = false; if (dc_id.is_internal()) { public_rsa_key = common_public_rsa_key_; } else { @@ -150,18 +152,18 @@ Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) { int32 upload_session_count = raw_dc_id != 2 && raw_dc_id != 4 ? 8 : 4; int32 download_session_count = 2; int32 download_small_session_count = 2; - dc.main_session_ = - create_actor(PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":main", session_count, - auth_data, raw_dc_id == main_dc_id_, use_pfs, false, false, is_cdn); + dc.main_session_ = create_actor(PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":main", + session_count, auth_data, raw_dc_id == main_dc_id_, use_pfs, + false, false, is_cdn, need_destroy_key); dc.upload_session_ = create_actor_on_scheduler( PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":upload", slow_net_scheduler_id, upload_session_count, - auth_data, false, use_pfs, false, true, is_cdn); + auth_data, false, use_pfs, false, true, is_cdn, need_destroy_key); dc.download_session_ = create_actor_on_scheduler( PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download", slow_net_scheduler_id, download_session_count, - auth_data, false, use_pfs, true, true, is_cdn); + auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key); dc.download_small_session_ = create_actor_on_scheduler( PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download_small", slow_net_scheduler_id, - download_small_session_count, auth_data, false, use_pfs, true, true, is_cdn); + download_small_session_count, auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key); dc.is_inited_ = true; if (dc_id.is_internal()) { send_closure_later(dc_auth_manager_, &DcAuthManager::add_dc, std::move(auth_data)); @@ -212,6 +214,18 @@ void NetQueryDispatcher::update_session_count() { } } } +void NetQueryDispatcher::destroy_auth_keys(Promise<> promise) { + std::lock_guard guard(main_dc_id_mutex_); + LOG(INFO) << "Destory auth keys"; + need_destroy_auth_key_ = true; + for (size_t i = 1; i < MAX_DC_COUNT; i++) { + if (is_dc_inited(narrow_cast(i)) && dcs_[i - 1].id_.is_internal()) { + send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_destroy_auth_key, + need_destroy_auth_key_); + } + } + send_closure_later(dc_auth_manager_, &DcAuthManager::destroy, std::move(promise)); +} void NetQueryDispatcher::update_use_pfs() { std::lock_guard guard(main_dc_id_mutex_); diff --git a/td/telegram/net/NetQueryDispatcher.h b/td/telegram/net/NetQueryDispatcher.h index 39ae0de0c..f6930bc25 100644 --- a/td/telegram/net/NetQueryDispatcher.h +++ b/td/telegram/net/NetQueryDispatcher.h @@ -11,6 +11,7 @@ #include "td/telegram/net/NetQuery.h" #include "td/actor/actor.h" +#include "td/actor/PromiseFuture.h" #include "td/utils/common.h" #include "td/utils/ScopeGuard.h" @@ -46,6 +47,7 @@ class NetQueryDispatcher { void stop(); void update_session_count(); + void destroy_auth_keys(Promise<> promise); void update_use_pfs(); void update_mtproto_header(); @@ -57,9 +59,11 @@ class NetQueryDispatcher { private: std::atomic stop_flag_{false}; + bool need_destroy_auth_key_{false}; ActorOwn delayer_; ActorOwn dc_auth_manager_; struct Dc { + DcId id_; std::atomic is_valid_{false}; std::atomic is_inited_{false}; // TODO: cache in scheduler local storage :D diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index 9daa0ef98..bf1a15464 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -108,10 +108,15 @@ class GenAuthKeyActor : public Actor { } // namespace detail Session::Session(unique_ptr callback, std::shared_ptr shared_auth_data, int32 dc_id, - bool is_main, bool use_pfs, bool is_cdn, const mtproto::AuthKey &tmp_auth_key, + bool is_main, bool use_pfs, bool is_cdn, bool need_destroy, const mtproto::AuthKey &tmp_auth_key, std::vector server_salts) : dc_id_(dc_id), is_main_(is_main), is_cdn_(is_cdn) { VLOG(dc) << "Start connection"; + need_destroy_ = need_destroy; + if (need_destroy) { + use_pfs = false; + CHECK(!is_cdn); + } shared_auth_data_ = std::move(shared_auth_data); auth_data_.set_use_pfs(use_pfs); @@ -141,6 +146,10 @@ Session::Session(unique_ptr callback, std::shared_ptr last_activity_timestamp_ = Time::now(); } +bool Session::can_destroy_auth_key() { + return need_destroy_; +} + void Session::start_up() { class StateCallback : public StateManager::Callback { public: @@ -415,6 +424,9 @@ void Session::on_closed(Status status) { auth_data_.drop_main_auth_key(); on_auth_key_updated(); on_session_failed(std::move(status)); + } else if (need_destroy_) { + auth_data_.drop_main_auth_key(); + on_auth_key_updated(); } } @@ -774,6 +786,11 @@ void Session::on_message_info(uint64 id, int32 state, uint64 answer_id, int32 an current_info_->connection->resend_answer(answer_id); } } +Status Session::on_destroy_auth_key() { + auth_data_.drop_main_auth_key(); + on_auth_key_updated(); + return Status::Error("Close because of on_destroy_auth_key"); +} bool Session::has_queries() const { return !pending_invoke_after_queries_.empty() || !pending_queries_.empty() || !sent_queries_.empty(); @@ -993,7 +1010,8 @@ bool Session::need_send_bind_key() { return auth_data_.use_pfs() && !auth_data_.get_bind_flag() && auth_data_.get_tmp_auth_key().id() != tmp_auth_key_id_; } bool Session::need_send_query() { - return !close_flag_ && (!auth_data_.use_pfs() || auth_data_.get_bind_flag()) && !pending_queries_.empty(); + return !close_flag_ && (!auth_data_.use_pfs() || auth_data_.get_bind_flag()) && !pending_queries_.empty() && + !can_destroy_auth_key(); } bool Session::connection_send_bind_key(ConnectionInfo *info) { CHECK(info->state != ConnectionInfo::State::Empty); @@ -1116,6 +1134,9 @@ void Session::create_gen_auth_key_actor(HandshakeId handshake_id) { } void Session::auth_loop() { + if (can_destroy_auth_key()) { + return; + } if (auth_data_.need_main_auth_key()) { create_gen_auth_key_actor(MainAuthKeyHandshake); } @@ -1133,7 +1154,8 @@ void Session::loop() { if (cached_connection_timestamp_ < Time::now_cached() - 10) { cached_connection_.reset(); } - if (!is_main_ && !has_queries() && last_activity_timestamp_ < Time::now_cached() - ACTIVITY_TIMEOUT) { + if (!is_main_ && !has_queries() && !need_destroy_ && + last_activity_timestamp_ < Time::now_cached() - ACTIVITY_TIMEOUT) { on_session_failed(Status::OK()); } @@ -1179,6 +1201,11 @@ void Session::loop() { connection_send_bind_key(&main_connection_); need_flush = true; } + if (can_destroy_auth_key()) { + if (main_connection_.connection) { + main_connection_.connection->destroy_key(); + } + } } if (need_flush) { connection_flush(&main_connection_); diff --git a/td/telegram/net/Session.h b/td/telegram/net/Session.h index d7c7544d1..eb422f727 100644 --- a/td/telegram/net/Session.h +++ b/td/telegram/net/Session.h @@ -62,7 +62,7 @@ class Session final }; Session(unique_ptr callback, std::shared_ptr shared_auth_data, int32 dc_id, bool is_main, - bool use_pfs, bool is_cdn, const mtproto::AuthKey &tmp_auth_key, + bool use_pfs, bool is_cdn, bool need_destroy, const mtproto::AuthKey &tmp_auth_key, std::vector server_salts); void send(NetQueryPtr &&query); void on_network(bool network_flag, uint32 network_generation); @@ -101,6 +101,7 @@ class Session final enum class Mode : int8 { Tcp, Http } mode_ = Mode::Tcp; bool is_main_; bool is_cdn_; + bool need_destroy_; bool was_on_network_ = false; bool network_flag_ = false; uint32 network_generation_ = 0; @@ -193,6 +194,8 @@ class Session final void on_message_info(uint64 id, int32 state, uint64 answer_id, int32 answer_size) override; + Status on_destroy_auth_key() override; + void flush_pending_invoke_after_queries(); bool has_queries() const; @@ -221,6 +224,7 @@ class Session final void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id = 0); bool need_send_bind_key(); bool need_send_query(); + bool can_destroy_auth_key(); bool connection_send_bind_key(ConnectionInfo *info); void on_result(NetQueryPtr query) override; diff --git a/td/telegram/net/SessionMultiProxy.cpp b/td/telegram/net/SessionMultiProxy.cpp index e0a557b46..4eb72e5d9 100644 --- a/td/telegram/net/SessionMultiProxy.cpp +++ b/td/telegram/net/SessionMultiProxy.cpp @@ -18,14 +18,16 @@ SessionMultiProxy::SessionMultiProxy() = default; SessionMultiProxy::~SessionMultiProxy() = default; SessionMultiProxy::SessionMultiProxy(int32 session_count, std::shared_ptr shared_auth_data, - bool is_main, bool use_pfs, bool allow_media_only, bool is_media, bool is_cdn) + bool is_main, bool use_pfs, bool allow_media_only, bool is_media, bool is_cdn, + bool need_destroy_auth_key) : session_count_(session_count) , auth_data_(std::move(shared_auth_data)) , is_main_(is_main) , use_pfs_(use_pfs) , allow_media_only_(allow_media_only) , is_media_(is_media) - , is_cdn_(is_cdn) { + , is_cdn_(is_cdn) + , need_destroy_auth_key_(need_destroy_auth_key) { if (allow_media_only_) { CHECK(is_media_); } @@ -52,6 +54,13 @@ void SessionMultiProxy::update_main_flag(bool is_main) { send_closure(session, &SessionProxy::update_main_flag, is_main); } } + +void SessionMultiProxy::update_destroy_auth_key(bool need_destroy_auth_key) { + need_destroy_auth_key_ = need_destroy_auth_key; + for (auto &session : sessions_) { + send_closure(session, &SessionProxy::update_destroy, need_destroy_auth_key_); + } +} void SessionMultiProxy::update_session_count(int32 session_count) { update_options(session_count, use_pfs_); } @@ -110,7 +119,8 @@ void SessionMultiProxy::init() { string name = PSTRING() << "Session" << get_name().substr(Slice("SessionMulti").size()) << format::cond(session_count_ > 1, format::concat("#", i)); sessions_.push_back(create_actor(name, auth_data_, is_main_, allow_media_only_, is_media_, - get_pfs_flag(), is_main_ && i != 0, is_cdn_)); + get_pfs_flag(), is_main_ && i != 0, is_cdn_, + need_destroy_auth_key_)); } } diff --git a/td/telegram/net/SessionMultiProxy.h b/td/telegram/net/SessionMultiProxy.h index c318a22ea..d0423f8f9 100644 --- a/td/telegram/net/SessionMultiProxy.h +++ b/td/telegram/net/SessionMultiProxy.h @@ -24,7 +24,7 @@ class SessionMultiProxy : public Actor { SessionMultiProxy &operator=(const SessionMultiProxy &other) = delete; ~SessionMultiProxy() override; SessionMultiProxy(int32 session_count, std::shared_ptr shared_auth_data, bool is_main, bool use_pfs, - bool allow_media_only, bool is_media, bool is_cdn); + bool allow_media_only, bool is_media, bool is_cdn, bool need_destroy_auth_key); void send(NetQueryPtr query); void update_main_flag(bool is_main); @@ -34,6 +34,8 @@ class SessionMultiProxy : public Actor { void update_options(int32 session_count, bool use_pfs); void update_mtproto_header(); + void update_destroy_auth_key(bool need_destroy_auth_key); + private: size_t pos_ = 0; int32 session_count_ = 0; @@ -43,6 +45,7 @@ class SessionMultiProxy : public Actor { bool allow_media_only_ = false; bool is_media_ = false; bool is_cdn_ = false; + bool need_destroy_auth_key_ = false; std::vector> sessions_; void start_up() override; diff --git a/td/telegram/net/SessionProxy.cpp b/td/telegram/net/SessionProxy.cpp index 8f5582159..86b9f07d3 100644 --- a/td/telegram/net/SessionProxy.cpp +++ b/td/telegram/net/SessionProxy.cpp @@ -63,14 +63,15 @@ class SessionCallback : public Session::Callback { }; SessionProxy::SessionProxy(std::shared_ptr shared_auth_data, bool is_main, bool allow_media_only, - bool is_media, bool use_pfs, bool need_wait_for_key, bool is_cdn) + bool is_media, bool use_pfs, bool need_wait_for_key, bool is_cdn, bool need_destroy) : auth_data_(std::move(shared_auth_data)) , is_main_(is_main) , allow_media_only_(allow_media_only) , is_media_(is_media) , use_pfs_(use_pfs) , need_wait_for_key_(need_wait_for_key) - , is_cdn_(is_cdn) { + , is_cdn_(is_cdn) + , need_destroy_(need_destroy) { } void SessionProxy::start_up() { @@ -91,9 +92,7 @@ void SessionProxy::start_up() { }; auth_state_ = auth_data_->get_auth_state().first; auth_data_->add_auth_key_listener(make_unique(actor_shared(this))); - if (is_main_ && !need_wait_for_key_) { - open_session(); - } + open_session(); } void SessionProxy::tear_down() { @@ -110,9 +109,7 @@ void SessionProxy::send(NetQueryPtr query) { pending_queries_.emplace_back(std::move(query)); return; } - if (session_.empty()) { - open_session(true); - } + open_session(true); query->debug(PSTRING() << get_name() << ": sent to session"); send_closure(session_, &Session::send, std::move(query)); } @@ -127,6 +124,12 @@ void SessionProxy::update_main_flag(bool is_main) { open_session(); } +void SessionProxy::update_destroy(bool need_destroy) { + need_destroy_ = need_destroy; + close_session(); + open_session(); +} + void SessionProxy::on_failed() { if (session_generation_ != get_link_token()) { return; @@ -148,9 +151,19 @@ void SessionProxy::close_session() { session_generation_++; } void SessionProxy::open_session(bool force) { - if (!force && !is_main_) { + if (!session_.empty()) { return; } + if (auth_state_ == AuthState::Empty && need_destroy_) { + return; + } + if (auth_state_ != AuthState::OK && need_wait_for_key_) { + return; + } + if (!is_main_ && pending_queries_.empty() && !need_destroy_) { + return; + } + CHECK(session_.empty()); auto dc_id = auth_data_->dc_id(); string name = PSTRING() << "Session" << get_name().substr(Slice("SessionProxy").size()); @@ -166,20 +179,12 @@ void SessionProxy::open_session(bool force) { session_ = create_actor( name, make_unique(actor_shared(this, session_generation_), dc_id, allow_media_only_, is_media_, hash), - auth_data_, int_dc_id, is_main_, use_pfs_, is_cdn_, tmp_auth_key_, server_salts_); + auth_data_, int_dc_id, is_main_, use_pfs_, is_cdn_, need_destroy_, tmp_auth_key_, server_salts_); } void SessionProxy::update_auth_state() { auth_state_ = auth_data_->get_auth_state().first; - if (pending_queries_.empty() && !need_wait_for_key_) { - return; - } - if (auth_state_ != AuthState::OK) { - return; - } - if (session_.empty()) { - open_session(true); - } + open_session(true); for (auto &query : pending_queries_) { query->debug(PSTRING() << get_name() << ": sent to session"); send_closure(session_, &Session::send, std::move(query)); diff --git a/td/telegram/net/SessionProxy.h b/td/telegram/net/SessionProxy.h index d61c487d9..40e97e553 100644 --- a/td/telegram/net/SessionProxy.h +++ b/td/telegram/net/SessionProxy.h @@ -22,11 +22,12 @@ class SessionProxy : public Actor { friend class SessionCallback; SessionProxy(std::shared_ptr shared_auth_data, bool is_main, bool allow_media_only, bool is_media, - bool use_pfs, bool need_wait_for_key, bool is_cdn); + bool use_pfs, bool need_wait_for_key, bool is_cdn, bool need_destroy); void send(NetQueryPtr query); void update_main_flag(bool is_main); void update_mtproto_header(); + void update_destroy(bool need_destroy); private: std::shared_ptr auth_data_; @@ -39,6 +40,7 @@ class SessionProxy : public Actor { std::vector server_salts_; bool need_wait_for_key_; bool is_cdn_; + bool need_destroy_; ActorOwn session_; std::vector pending_queries_; uint64 session_generation_ = 1;