diff --git a/td/telegram/ConfigManager.cpp b/td/telegram/ConfigManager.cpp index d359b3b21..ef1235761 100644 --- a/td/telegram/ConfigManager.cpp +++ b/td/telegram/ConfigManager.cpp @@ -221,6 +221,9 @@ ActorOwn<> get_full_config(DcId dc_id, IPAddress ip_address, Promise void on_tmp_auth_key_updated(mtproto::AuthKey auth_key) final { // nop } + void on_result(NetQueryPtr net_query) final { + G()->net_query_dispatcher().dispatch(std::move(net_query)); + } private: ActorShared<> parent_; diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index 5bcb000b8..8a563602a 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -275,7 +275,7 @@ void Session::return_query(NetQueryPtr &&query) { last_activity_timestamp_ = Time::now(); query->set_session_id(0); - G()->net_query_dispatcher().dispatch(std::move(query)); + callback_->on_result(std::move(query)); } void Session::flush_pending_invoke_after_queries() { diff --git a/td/telegram/net/Session.h b/td/telegram/net/Session.h index 97bb8dbd6..3b843b9bc 100644 --- a/td/telegram/net/Session.h +++ b/td/telegram/net/Session.h @@ -60,6 +60,7 @@ class Session final virtual void on_server_salt_updated(std::vector server_salts) { } // one still have to call close after on_closed + virtual void on_result(NetQueryPtr net_query) = 0; }; Session(unique_ptr callback, std::shared_ptr shared_auth_data, int32 dc_id, bool is_main, diff --git a/td/telegram/net/SessionMultiProxy.cpp b/td/telegram/net/SessionMultiProxy.cpp index f82505fbc..007afe878 100644 --- a/td/telegram/net/SessionMultiProxy.cpp +++ b/td/telegram/net/SessionMultiProxy.cpp @@ -45,20 +45,21 @@ void SessionMultiProxy::send(NetQueryPtr query) { } } query->debug(PSTRING() << get_name() << ": send to proxy #" << pos); - send_closure(sessions_[pos], &SessionProxy::send, std::move(query)); + sessions_[pos].queries_count++; + send_closure(sessions_[pos].proxy, &SessionProxy::send, std::move(query)); } void SessionMultiProxy::update_main_flag(bool is_main) { LOG(INFO) << "Update " << get_name() << " is_main to " << is_main; is_main_ = is_main; for (auto &session : sessions_) { - send_closure(session, &SessionProxy::update_main_flag, is_main); + send_closure(session.proxy, &SessionProxy::update_main_flag, is_main); } } void SessionMultiProxy::update_destroy_auth_key(bool need_destroy_auth_key) { need_destroy_auth_key_ = need_destroy_auth_key; - send_closure(sessions_[0], &SessionProxy::update_destroy, need_destroy_auth_key_); + send_closure(sessions_[0].proxy, &SessionProxy::update_destroy, need_destroy_auth_key_); } void SessionMultiProxy::update_session_count(int32 session_count) { update_options(session_count, use_pfs_); @@ -97,7 +98,7 @@ void SessionMultiProxy::update_options(int32 session_count, bool use_pfs) { void SessionMultiProxy::update_mtproto_header() { for (auto &session : sessions_) { - send_closure_later(session, &SessionProxy::update_mtproto_header); + send_closure_later(session.proxy, &SessionProxy::update_mtproto_header); } } @@ -110,6 +111,7 @@ bool SessionMultiProxy::get_pfs_flag() const { } void SessionMultiProxy::init() { + sessions_generation_++; sessions_.clear(); if (is_main_) { LOG(WARNING) << tag("session_count", session_count_); @@ -117,9 +119,35 @@ void SessionMultiProxy::init() { for (int32 i = 0; i < session_count_; i++) { string name = PSTRING() << "Session" << get_name().substr(Slice("SessionMulti").size()) << format::cond(session_count_ > 1, format::concat("#", i)); - sessions_.push_back(create_actor(name, auth_data_, is_main_, allow_media_only_, is_media_, - get_pfs_flag(), is_cdn_, need_destroy_auth_key_ && i == 0)); + + SessionInfo info; + class Callback : public SessionProxy::Callback { + public: + Callback(ActorId parent, uint32 generation, int32 session_id) + : parent_(parent), generation_(generation), session_id_(session_id) { + } + void on_query_finished() override { + send_closure(parent_, &SessionMultiProxy::on_query_finished, generation_, session_id_); + } + + private: + ActorId parent_; + uint32 generation_; + int32 session_id_; + }; + info.proxy = create_actor(name, make_unique(actor_id(this), sessions_generation_, i), + auth_data_, is_main_, allow_media_only_, is_media_, get_pfs_flag(), is_cdn_, + need_destroy_auth_key_ && i == 0); + sessions_.push_back(std::move(info)); } } +void SessionMultiProxy::on_query_finished(uint32 generation, int session_id) { + if (generation != sessions_generation_) { + return; + } + sessions_.at(session_id).queries_count--; + CHECK(sessions_.at(session_id).queries_count >= 0); +} + } // namespace td diff --git a/td/telegram/net/SessionMultiProxy.h b/td/telegram/net/SessionMultiProxy.h index 45fba1830..2883c8af4 100644 --- a/td/telegram/net/SessionMultiProxy.h +++ b/td/telegram/net/SessionMultiProxy.h @@ -46,7 +46,12 @@ class SessionMultiProxy : public Actor { bool is_media_ = false; bool is_cdn_ = false; bool need_destroy_auth_key_ = false; - std::vector> sessions_; + struct SessionInfo { + ActorOwn proxy; + int queries_count{0}; + }; + uint32 sessions_generation_{0}; + std::vector sessions_; void start_up() override; void init(); @@ -54,6 +59,8 @@ class SessionMultiProxy : public Actor { bool get_pfs_flag() const; void update_auth_state(); + + void on_query_finished(uint32 generation, int session_id); }; } // namespace td diff --git a/td/telegram/net/SessionProxy.cpp b/td/telegram/net/SessionProxy.cpp index 7b2eabbb7..af0446cab 100644 --- a/td/telegram/net/SessionProxy.cpp +++ b/td/telegram/net/SessionProxy.cpp @@ -55,6 +55,11 @@ class SessionCallback : public Session::Callback { send_closure(parent_, &SessionProxy::on_server_salt_updated, std::move(server_salts)); } + void on_result(NetQueryPtr query) override { + G()->net_query_dispatcher().dispatch(std::move(query)); + send_closure(parent_, &SessionProxy::on_query_finished); + } + private: ActorShared parent_; DcId dc_id_; @@ -63,9 +68,11 @@ class SessionCallback : public Session::Callback { size_t hash_ = 0; }; -SessionProxy::SessionProxy(std::shared_ptr shared_auth_data, bool is_main, bool allow_media_only, - bool is_media, bool use_pfs, bool is_cdn, bool need_destroy) - : auth_data_(std::move(shared_auth_data)) +SessionProxy::SessionProxy(unique_ptr callback, std::shared_ptr shared_auth_data, + bool is_main, bool allow_media_only, bool is_media, bool use_pfs, bool is_cdn, + bool need_destroy) + : callback_(std::move(callback)) + , auth_data_(std::move(shared_auth_data)) , is_main_(is_main) , allow_media_only_(allow_media_only) , is_media_(is_media) @@ -225,4 +232,8 @@ void SessionProxy::on_server_salt_updated(std::vector serve server_salts_ = std::move(server_salts); } +void SessionProxy::on_query_finished() { + callback_->on_query_finished(); +} + } // namespace td diff --git a/td/telegram/net/SessionProxy.h b/td/telegram/net/SessionProxy.h index d9512cdca..e7348ff7b 100644 --- a/td/telegram/net/SessionProxy.h +++ b/td/telegram/net/SessionProxy.h @@ -23,9 +23,14 @@ class Session; class SessionProxy : public Actor { public: friend class SessionCallback; + class Callback { + public: + virtual ~Callback() = default; + virtual void on_query_finished() = 0; + }; - SessionProxy(std::shared_ptr shared_auth_data, bool is_main, bool allow_media_only, bool is_media, - bool use_pfs, bool is_cdn, bool need_destroy); + SessionProxy(unique_ptr callback, std::shared_ptr shared_auth_data, bool is_main, + bool allow_media_only, bool is_media, bool use_pfs, bool is_cdn, bool need_destroy); void send(NetQueryPtr query); void update_main_flag(bool is_main); @@ -33,6 +38,7 @@ class SessionProxy : public Actor { void update_destroy(bool need_destroy); private: + unique_ptr callback_; std::shared_ptr auth_data_; AuthState auth_state_; bool is_main_; @@ -56,6 +62,8 @@ class SessionProxy : public Actor { void on_tmp_auth_key_updated(mtproto::AuthKey auth_key); void on_server_salt_updated(std::vector server_salts); + void on_query_finished(); + void start_up() override; void tear_down() override; };