// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2019 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/telegram/net/SessionMultiProxy.h" #include "td/telegram/net/SessionProxy.h" #include "td/utils/common.h" #include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/Slice.h" #include <algorithm> namespace td { SessionMultiProxy::SessionMultiProxy() = default; SessionMultiProxy::~SessionMultiProxy() = default; SessionMultiProxy::SessionMultiProxy(int32 session_count, std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool use_pfs, bool allow_media_only, bool is_media, bool is_cdn, bool need_destroy_auth_key) : session_count_(session_count) , auth_data_(std::move(shared_auth_data)) , is_main_(is_main) , use_pfs_(use_pfs) , allow_media_only_(allow_media_only) , is_media_(is_media) , is_cdn_(is_cdn) , need_destroy_auth_key_(need_destroy_auth_key) { if (allow_media_only_) { CHECK(is_media_); } } void SessionMultiProxy::send(NetQueryPtr query) { size_t pos = 0; // TODO temporary hack with total_timeout_limit if (query->auth_flag() == NetQuery::AuthFlag::On && query->total_timeout_limit > 50) { if (query->session_rand()) { pos = query->session_rand() % sessions_.size(); } else { pos = std::min_element(sessions_.begin(), sessions_.end(), [](const auto &a, const auto &b) { return a.queries_count < b.queries_count; }) - sessions_.begin(); } } query->debug(PSTRING() << get_name() << ": send to proxy #" << pos); sessions_[pos].queries_count++; send_closure(sessions_[pos].proxy, &SessionProxy::send, std::move(query)); } void SessionMultiProxy::update_main_flag(bool is_main) { LOG(INFO) << "Update " << get_name() << " is_main to " << is_main; is_main_ = is_main; for (auto &session : sessions_) { send_closure(session.proxy, &SessionProxy::update_main_flag, is_main); } } void SessionMultiProxy::update_destroy_auth_key(bool need_destroy_auth_key) { need_destroy_auth_key_ = need_destroy_auth_key; send_closure(sessions_[0].proxy, &SessionProxy::update_destroy, need_destroy_auth_key_); } void SessionMultiProxy::update_session_count(int32 session_count) { update_options(session_count, use_pfs_); } void SessionMultiProxy::update_use_pfs(bool use_pfs) { update_options(session_count_, use_pfs); } void SessionMultiProxy::update_options(int32 session_count, bool use_pfs) { bool changed = false; if (session_count != session_count_) { session_count_ = session_count; if (session_count_ <= 0) { session_count_ = 1; } if (session_count_ > 100) { session_count_ = 100; } LOG(INFO) << "Update " << get_name() << " session_count to " << session_count_; changed = true; } if (use_pfs != use_pfs_) { bool old_pfs_flag = get_pfs_flag(); use_pfs_ = use_pfs; if (old_pfs_flag != get_pfs_flag()) { LOG(INFO) << "Update " << get_name() << " use_pfs to " << use_pfs_; changed = true; } } if (changed) { init(); } } void SessionMultiProxy::update_mtproto_header() { for (auto &session : sessions_) { send_closure_later(session.proxy, &SessionProxy::update_mtproto_header); } } void SessionMultiProxy::start_up() { init(); } bool SessionMultiProxy::get_pfs_flag() const { return use_pfs_; } void SessionMultiProxy::init() { sessions_generation_++; sessions_.clear(); if (is_main_) { LOG(WARNING) << tag("session_count", session_count_); } for (int32 i = 0; i < session_count_; i++) { string name = PSTRING() << "Session" << get_name().substr(Slice("SessionMulti").size()) << format::cond(session_count_ > 1, format::concat("#", i)); SessionInfo info; class Callback : public SessionProxy::Callback { public: Callback(ActorId<SessionMultiProxy> parent, uint32 generation, int32 session_id) : parent_(parent), generation_(generation), session_id_(session_id) { } void on_query_finished() override { send_closure(parent_, &SessionMultiProxy::on_query_finished, generation_, session_id_); } private: ActorId<SessionMultiProxy> parent_; uint32 generation_; int32 session_id_; }; info.proxy = create_actor<SessionProxy>(name, make_unique<Callback>(actor_id(this), sessions_generation_, i), auth_data_, is_main_, allow_media_only_, is_media_, get_pfs_flag(), is_cdn_, need_destroy_auth_key_ && i == 0); sessions_.push_back(std::move(info)); } } void SessionMultiProxy::on_query_finished(uint32 generation, int session_id) { if (generation != sessions_generation_) { return; } sessions_.at(session_id).queries_count--; CHECK(sessions_.at(session_id).queries_count >= 0); } } // namespace td