// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/telegram/SequenceDispatcher.h" #include "td/telegram/Global.h" #include "td/telegram/net/NetQueryDispatcher.h" #include "td/telegram/Td.h" #include "td/actor/PromiseFuture.h" #include "td/utils/algorithm.h" #include "td/utils/ChainScheduler.h" #include "td/utils/logging.h" #include "td/utils/misc.h" #include "td/utils/Promise.h" #include "td/utils/SliceBuilder.h" #include "td/utils/Status.h" #include "td/utils/StringBuilder.h" #include <limits> namespace td { /*** Sequence Dispatcher ***/ // Sends queries with invokeAfter. // // Each query has three main states: Start/Wait/Finish (Dummy is a transient state used while a result or failure is being handled). // // finish_i_ points to the first query that is not in the Finish state. // next_i_ points to the next query to be sent. // // Each query remembers the generation of the invokeAfter chain it was sent in. // // When a query is sent, its generation is set to the current chain generation. // // When a query fails and its generation is equal to the current generation, we must start a new chain: // increment the generation and set next_i_ to finish_i_. // // last_sent_i_ points to the last sent query in the current chain.
// void SequenceDispatcher::send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback) { cancel_timeout(); query->debug("Waiting at SequenceDispatcher"); auto query_weak_ref = query.get_weak(); data_.push_back(Data{State::Start, std::move(query_weak_ref), std::move(query), std::move(callback), 0, 0, 0}); loop(); } void SequenceDispatcher::check_timeout(Data &data) { if (data.state_ != State::Start) { return; } data.query_->total_timeout_ += data.total_timeout_; data.total_timeout_ = 0; if (data.query_->total_timeout_ > data.query_->total_timeout_limit_) { LOG(WARNING) << "Fail " << data.query_ << " to " << data.query_->source_ << " because total_timeout " << data.query_->total_timeout_ << " is greater than total_timeout_limit " << data.query_->total_timeout_limit_; data.query_->set_error(Status::Error(429, PSLICE() << "Too Many Requests: retry after " << data.last_timeout_)); data.state_ = State::Dummy; try_resend_query(data, std::move(data.query_)); } } void SequenceDispatcher::try_resend_query(Data &data, NetQueryPtr query) { size_t pos = &data - &data_[0]; CHECK(pos < data_.size()); CHECK(data.state_ == State::Dummy); data.state_ = State::Wait; wait_cnt_++; auto token = pos + id_offset_; // TODO: if query is ok, use NetQueryCallback::on_result if (data.callback_.empty()) { do_finish(data); send_closure_later(G()->td(), &Td::on_result, std::move(query)); loop(); return; } auto promise = PromiseCreator::lambda([&, self = actor_shared(this, token)](NetQueryPtr query) mutable { if (!query.empty()) { send_closure(std::move(self), &SequenceDispatcher::on_resend_ok, std::move(query)); } else { send_closure(std::move(self), &SequenceDispatcher::on_resend_error); } }); send_closure(data.callback_, &NetQueryCallback::on_result_resendable, std::move(query), std::move(promise)); } SequenceDispatcher::Data &SequenceDispatcher::data_from_token() { auto token = narrow_cast<size_t>(get_link_token()); auto pos = token - id_offset_; CHECK(pos < data_.size()); 
auto &data = data_[pos]; CHECK(data.state_ == State::Wait); CHECK(wait_cnt_ > 0); wait_cnt_--; data.state_ = State::Dummy; return data; } void SequenceDispatcher::on_resend_ok(NetQueryPtr query) { auto &data = data_from_token(); data.query_ = std::move(query); do_resend(data); loop(); } void SequenceDispatcher::on_resend_error() { auto &data = data_from_token(); do_finish(data); loop(); } void SequenceDispatcher::do_resend(Data &data) { CHECK(data.state_ == State::Dummy); data.state_ = State::Start; if (data.generation_ == generation_) { next_i_ = finish_i_; generation_++; last_sent_i_ = std::numeric_limits<size_t>::max(); } check_timeout(data); } void SequenceDispatcher::do_finish(Data &data) { CHECK(data.state_ == State::Dummy); data.state_ = State::Finish; if (!parent_.empty()) { send_closure(parent_, &Parent::on_result); } } void SequenceDispatcher::on_result(NetQueryPtr query) { auto &data = data_from_token(); size_t pos = &data - &data_[0]; CHECK(pos < data_.size()); if (query->last_timeout_ != 0) { for (auto i = pos + 1; i < data_.size(); i++) { data_[i].total_timeout_ += query->last_timeout_; data_[i].last_timeout_ = query->last_timeout_; check_timeout(data_[i]); } query->last_timeout_ = 0; } if (query->is_error() && (query->error().code() == NetQuery::ResendInvokeAfter || (query->error().code() == 400 && (query->error().message() == "MSG_WAIT_FAILED" || query->error().message() == "MSG_WAIT_TIMEOUT")))) { VLOG(net_query) << "Resend " << query; query->resend(); query->debug("Waiting at SequenceDispatcher"); data.query_ = std::move(query); do_resend(data); } else { try_resend_query(data, std::move(query)); } loop(); } void SequenceDispatcher::loop() { for (; finish_i_ < data_.size() && data_[finish_i_].state_ == State::Finish; finish_i_++) { } if (next_i_ < finish_i_) { next_i_ = finish_i_; } for (; next_i_ < data_.size() && data_[next_i_].state_ != State::Wait && wait_cnt_ < MAX_SIMULTANEOUS_WAIT; next_i_++) { if (data_[next_i_].state_ == State::Finish) { 
continue; } NetQueryRef invoke_after; if (last_sent_i_ != std::numeric_limits<size_t>::max() && data_[last_sent_i_].state_ == State::Wait) { invoke_after = data_[last_sent_i_].net_query_ref_; } if (!invoke_after.empty()) { data_[next_i_].query_->set_invoke_after({invoke_after}); } else { data_[next_i_].query_->set_invoke_after({}); } data_[next_i_].query_->last_timeout_ = 0; VLOG(net_query) << "Send " << data_[next_i_].query_; data_[next_i_].query_->debug("send to Td::send_with_callback"); G()->net_query_dispatcher().dispatch_with_callback(std::move(data_[next_i_].query_), actor_shared(this, next_i_ + id_offset_)); data_[next_i_].state_ = State::Wait; wait_cnt_++; data_[next_i_].generation_ = generation_; last_sent_i_ = next_i_; } try_shrink(); if (finish_i_ == data_.size() && !parent_.empty()) { set_timeout_in(5); } } void SequenceDispatcher::try_shrink() { if (finish_i_ * 2 > data_.size() && data_.size() > 5) { CHECK(finish_i_ <= next_i_); data_.erase(data_.begin(), data_.begin() + finish_i_); next_i_ -= finish_i_; if (last_sent_i_ != std::numeric_limits<size_t>::max()) { if (last_sent_i_ >= finish_i_) { last_sent_i_ -= finish_i_; } else { last_sent_i_ = std::numeric_limits<size_t>::max(); } } id_offset_ += finish_i_; finish_i_ = 0; } } void SequenceDispatcher::timeout_expired() { if (finish_i_ != data_.size()) { return; } CHECK(!parent_.empty()); set_timeout_in(1); LOG(DEBUG) << "SequenceDispatcher ready to close"; send_closure(parent_, &Parent::ready_to_close); } void SequenceDispatcher::hangup() { stop(); } void SequenceDispatcher::tear_down() { for (auto &data : data_) { if (data.query_.empty()) { continue; } data.state_ = State::Dummy; data.query_->set_error(Global::request_aborted_error()); do_finish(data); } } void SequenceDispatcher::close_silent() { for (auto &data : data_) { if (!data.query_.empty()) { data.query_->clear(); } } stop(); } void MultiSequenceDispatcherOld::send(NetQueryPtr query) { auto callback = query->move_callback(); auto chain_ids = 
query->get_chain_ids(); query->set_in_sequence_dispatcher(true); CHECK(all_of(chain_ids, [](auto chain_id) { return chain_id != 0; })); CHECK(!chain_ids.empty()); auto sequence_id = chain_ids[0]; auto it_ok = dispatchers_.emplace(sequence_id, Data{0, ActorOwn<SequenceDispatcher>()}); auto &data = it_ok.first->second; if (it_ok.second) { LOG(DEBUG) << "Create SequenceDispatcher " << sequence_id; data.dispatcher_ = create_actor<SequenceDispatcher>("SequenceDispatcher", actor_shared(this, sequence_id)); } data.cnt_++; query->debug(PSTRING() << "send to SequenceDispatcher " << sequence_id); send_closure(data.dispatcher_, &SequenceDispatcher::send_with_callback, std::move(query), std::move(callback)); } void MultiSequenceDispatcherOld::on_result() { auto it = dispatchers_.find(get_link_token()); CHECK(it != dispatchers_.end()); it->second.cnt_--; } void MultiSequenceDispatcherOld::ready_to_close() { auto it = dispatchers_.find(get_link_token()); CHECK(it != dispatchers_.end()); if (it->second.cnt_ == 0) { LOG(DEBUG) << "Close SequenceDispatcher " << get_link_token(); dispatchers_.erase(it); } } class MultiSequenceDispatcherImpl final : public MultiSequenceDispatcher { public: void send(NetQueryPtr query) final { auto callback = query->move_callback(); auto chain_ids = query->get_chain_ids(); query->set_in_sequence_dispatcher(true); CHECK(all_of(chain_ids, [](auto chain_id) { return chain_id != 0; })); Node node; node.net_query = std::move(query); node.net_query->debug("Waiting at SequenceDispatcher"); node.net_query_ref = node.net_query.get_weak(); node.callback = std::move(callback); scheduler_.create_task(chain_ids, std::move(node)); loop(); } private: struct Node { NetQueryRef net_query_ref; NetQueryPtr net_query; int32 total_timeout{0}; int32 last_timeout{0}; ActorShared<NetQueryCallback> callback; friend StringBuilder &operator<<(StringBuilder &sb, const Node &node) { return sb << node.net_query; } }; ChainScheduler<Node> scheduler_; using TaskId = 
ChainScheduler<Node>::TaskId; bool check_timeout(Node &node) { auto &net_query = node.net_query; if (net_query.empty() || net_query->is_ready()) { return false; } if (node.total_timeout > 0) { net_query->total_timeout_ += node.total_timeout; LOG(INFO) << "Set total_timeout to " << net_query->total_timeout_ << " for " << net_query->id(); node.total_timeout = 0; if (net_query->total_timeout_ > net_query->total_timeout_limit_) { LOG(WARNING) << "Fail " << net_query << " to " << net_query->source_ << " because total_timeout " << net_query->total_timeout_ << " is greater than total_timeout_limit " << net_query->total_timeout_limit_; net_query->set_error(Status::Error(429, PSLICE() << "Too Many Requests: retry after " << node.last_timeout)); return true; } } return false; } void on_result(NetQueryPtr query) final { auto task_id = TaskId(get_link_token()); auto &node = *scheduler_.get_task_extra(task_id); if (query->last_timeout_ != 0) { vector<TaskId> to_check_timeout; auto tl_constructor = query->tl_constructor(); scheduler_.for_each_dependent(task_id, [&](TaskId child_task_id) { auto &child_node = *scheduler_.get_task_extra(child_task_id); if (child_node.net_query_ref->tl_constructor() == tl_constructor && child_task_id != task_id) { child_node.total_timeout += query->last_timeout_; child_node.last_timeout = query->last_timeout_; to_check_timeout.push_back(child_task_id); } }); query->last_timeout_ = 0; for (auto dependent_task_id : to_check_timeout) { auto &child_node = *scheduler_.get_task_extra(dependent_task_id); if (check_timeout(child_node)) { scheduler_.pause_task(dependent_task_id); try_resend(dependent_task_id); } } } if (query->is_error() && (query->error().code() == NetQuery::ResendInvokeAfter || (query->error().code() == 400 && (query->error().message() == "MSG_WAIT_FAILED" || query->error().message() == "MSG_WAIT_TIMEOUT")))) { VLOG(net_query) << "Resend " << query; query->resend(); do_resend(task_id, node, std::move(query)); loop(); return; } 
node.net_query = std::move(query); try_resend(task_id); } void try_resend(TaskId task_id) { auto &node = *scheduler_.get_task_extra(task_id); if (node.callback.empty()) { auto query = std::move(node.net_query); scheduler_.finish_task(task_id); send_closure_later(G()->td(), &Td::on_result, std::move(query)); loop(); return; } auto promise = promise_send_closure(actor_shared(this, task_id), &MultiSequenceDispatcherImpl::on_resend); send_closure(node.callback, &NetQueryCallback::on_result_resendable, std::move(node.net_query), std::move(promise)); } void on_resend(Result<NetQueryPtr> r_query) { auto task_id = TaskId(get_link_token()); auto &node = *scheduler_.get_task_extra(task_id); if (r_query.is_error()) { scheduler_.finish_task(task_id); } else { do_resend(task_id, node, r_query.move_as_ok()); } loop(); } void do_resend(TaskId task_id, Node &node, NetQueryPtr &&query) { node.net_query = std::move(query); node.net_query->debug("Waiting at SequenceDispatcher"); node.net_query_ref = node.net_query.get_weak(); if (check_timeout(node)) { scheduler_.pause_task(task_id); try_resend(task_id); } else { scheduler_.reset_task(task_id); } } void loop() final { flush_pending_queries(); } void tear_down() final { // Leaves scheduler_ in an invalid state, but we are closing anyway scheduler_.for_each([](Node &node) { if (node.net_query.empty()) { return; } node.net_query->set_error(Global::request_aborted_error()); }); } void flush_pending_queries() { while (true) { auto o_task = scheduler_.start_next_task(); if (!o_task) { break; } auto task = o_task.unwrap(); auto &node = *scheduler_.get_task_extra(task.task_id); CHECK(!node.net_query.empty()); auto query = std::move(node.net_query); vector<NetQueryRef> parents; for (auto parent_id : task.parents) { auto &parent_node = *scheduler_.get_task_extra(parent_id); parents.push_back(parent_node.net_query_ref); CHECK(!parent_node.net_query_ref.empty()); } query->set_invoke_after(std::move(parents)); query->last_timeout_ = 0; 
query->debug("dispatch_with_callback"); G()->net_query_dispatcher().dispatch_with_callback(std::move(query), actor_shared(this, task.task_id)); } } }; ActorOwn<MultiSequenceDispatcher> MultiSequenceDispatcher::create(Slice name) { return ActorOwn<MultiSequenceDispatcher>(create_actor<MultiSequenceDispatcherImpl>(name)); } } // namespace td