diff --git a/td/telegram/MessagesManager.cpp b/td/telegram/MessagesManager.cpp index 5f5dc75a3..35b7d7afd 100644 --- a/td/telegram/MessagesManager.cpp +++ b/td/telegram/MessagesManager.cpp @@ -5998,7 +5998,7 @@ MessagesManager::MessagesManager(Td *td, ActorShared<> parent) preload_folder_dialog_list_timeout_.set_callback(on_preload_folder_dialog_list_timeout_callback); preload_folder_dialog_list_timeout_.set_callback_data(static_cast(this)); - sequence_dispatcher_ = create_actor("multi sequence dispatcher"); + sequence_dispatcher_ = MultiSequenceDispatcher::create("multi sequence dispatcher"); } MessagesManager::~MessagesManager() = default; diff --git a/td/telegram/MessagesManager.h b/td/telegram/MessagesManager.h index 8a38a66f7..75077eeb1 100644 --- a/td/telegram/MessagesManager.h +++ b/td/telegram/MessagesManager.h @@ -53,6 +53,7 @@ #include "td/telegram/SecretChatId.h" #include "td/telegram/SecretInputMedia.h" #include "td/telegram/ServerMessageId.h" +#include "td/telegram/SequenceDispatcher.h" #include "td/telegram/td_api.h" #include "td/telegram/telegram_api.h" #include "td/telegram/UserId.h" @@ -92,7 +93,6 @@ class DraftMessage; struct InputMessageContent; class MessageContent; struct MessageReactions; -class MultiSequenceDispatcher; class Td; class MessagesManager final : public Actor { diff --git a/td/telegram/SequenceDispatcher.cpp b/td/telegram/SequenceDispatcher.cpp index 6e1cf04ca..4a129cfd1 100644 --- a/td/telegram/SequenceDispatcher.cpp +++ b/td/telegram/SequenceDispatcher.cpp @@ -11,6 +11,7 @@ #include "td/actor/PromiseFuture.h" +#include "td/utils/ChainScheduler.h" #include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/misc.h" @@ -241,7 +242,7 @@ void SequenceDispatcher::close_silent() { } /*** MultiSequenceDispatcher ***/ -void MultiSequenceDispatcher::send_with_callback(NetQueryPtr query, ActorShared callback, +void MultiSequenceDispatcherOld::send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id) { CHECK(sequence_id != 0); auto it_ok = dispatchers_.emplace(sequence_id, Data{0, ActorOwn()}); @@ -255,13 +256,13 @@ void MultiSequenceDispatcher::send_with_callback(NetQueryPtr query, ActorShared< send_closure(data.dispatcher_, &SequenceDispatcher::send_with_callback, std::move(query), std::move(callback)); } -void MultiSequenceDispatcher::on_result() { +void MultiSequenceDispatcherOld::on_result() { auto it = dispatchers_.find(get_link_token()); CHECK(it != dispatchers_.end()); it->second.cnt_--; } -void MultiSequenceDispatcher::ready_to_close() { +void MultiSequenceDispatcherOld::ready_to_close() { auto it = dispatchers_.find(get_link_token()); CHECK(it != dispatchers_.end()); if (it->second.cnt_ == 0) { @@ -270,4 +271,104 @@ void MultiSequenceDispatcher::ready_to_close() { } } +class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew { + public: + void send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id) final { + LOG(ERROR) << "send " << query; + Node node; + node.net_query = std::move(query); + node.net_query->debug("Waiting at SequenceDispatcher"); + node.net_query_ref = node.net_query.get_weak(); + node.callback = std::move(callback); + scheduler_.create_task({ChainId{sequence_id}}, std::move(node)); + loop(); + } + + private: + struct Node { + NetQueryRef net_query_ref; + NetQueryPtr net_query; + ActorShared callback; + friend StringBuilder &operator << (StringBuilder &sb, const Node &node) { + return sb << node.net_query; + } + }; + ChainScheduler scheduler_; + using TaskId = ChainScheduler::TaskId; + using ChainId = ChainScheduler::ChainId; + + void on_result(NetQueryPtr query) override { + auto task_id = TaskId(get_link_token()); + auto &node = *scheduler_.get_task_extra(task_id); + + if (query->is_error() && (query->error().code() == NetQuery::ResendInvokeAfter || + (query->error().code() == 400 && (query->error().message() == "MSG_WAIT_FAILED" || + query->error().message() == "MSG_WAIT_TIMEOUT")))) { + VLOG(net_query) << "Resend " << query; + query->resend(); + return on_resend(std::move(query)); + } + auto promise = promise_send_closure(actor_shared(this, task_id), &MultiSequenceDispatcherNewImpl::on_resend); + send_closure(node.callback, &NetQueryCallback::on_result_resendable, std::move(query), std::move(promise)); + } + + // TODO: without td::Result? + void on_resend(td::Result query) { + auto task_id = TaskId(get_link_token()); + auto &node = *scheduler_.get_task_extra(task_id); + if (query.is_error()) { + scheduler_.finish_task(task_id); + } else { + node.net_query = query.move_as_ok(); + node.net_query->debug("Waiting at SequenceDispatcher"); + node.net_query_ref = node.net_query.get_weak(); + scheduler_.reset_task(task_id); + } + loop(); + } + + void loop() override { + flush_pending_queries(); + } + + void flush_pending_queries() { + while (true) { + auto o_task = scheduler_.start_next_task(); + if (!o_task) { + LOG(ERROR) << " no more tasks " << scheduler_; + break; + } + auto task = o_task.unwrap(); + LOG(ERROR) << " next task = " << task.task_id; + auto &node = *scheduler_.get_task_extra(task.task_id); + CHECK(!node.net_query.empty()); + + auto query = std::move(node.net_query); + std::vector parents; + for (auto parent_id : task.parents) { + auto &parent_node = *scheduler_.get_task_extra(parent_id); + parents.push_back(parent_node.net_query_ref); + } + + if (parents.empty()) { + query->set_invoke_after({}); + } else if (parents.size() == 1) { + query->set_invoke_after(parents[0]); + } else if (parents.size() > 1){ + LOG(FATAL) << "TODO: support invokeAfterMsgs"; + } + query->last_timeout_ = 0; // TODO: flood + VLOG(net_query) << "Send " << query; + query->debug("send to Td::send_with_callback"); + query->set_session_rand(123); // TODO: chain_rand + G()->net_query_dispatcher().dispatch_with_callback(std::move(query), + actor_shared(this, task.task_id)); + } + } +}; + +ActorOwn MultiSequenceDispatcherNew::create(Slice name) { + return ActorOwn(create_actor(name)); +} + } // namespace td diff --git a/td/telegram/SequenceDispatcher.h b/td/telegram/SequenceDispatcher.h index 1018f3397..4b8e9d996 100644 --- a/td/telegram/SequenceDispatcher.h +++ b/td/telegram/SequenceDispatcher.h @@ -73,9 +73,12 @@ class SequenceDispatcher final : public NetQueryCallback { void tear_down() final; }; -class MultiSequenceDispatcher final : public SequenceDispatcher::Parent { +class MultiSequenceDispatcherOld final : public SequenceDispatcher::Parent { public: void send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id); + static ActorOwn create(td::Slice name) { + return create_actor(name); + } private: struct Data { @@ -87,4 +90,12 @@ class MultiSequenceDispatcher final : public SequenceDispatcher::Parent { void ready_to_close() final; }; +class MultiSequenceDispatcherNew : public NetQueryCallback { + public: + virtual void send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id) = 0; + static ActorOwn create(Slice name); +}; + +using MultiSequenceDispatcher = MultiSequenceDispatcherOld; + } // namespace td diff --git a/td/telegram/net/NetQuery.h b/td/telegram/net/NetQuery.h index ac34b1857..90c66190a 100644 --- a/td/telegram/net/NetQuery.h +++ b/td/telegram/net/NetQuery.h @@ -383,6 +383,9 @@ inline StringBuilder &operator<<(StringBuilder &stream, const NetQuery &net_quer } inline StringBuilder &operator<<(StringBuilder &stream, const NetQueryPtr &net_query_ptr) { + if (net_query_ptr.empty()) { + return stream << "[Query: null]"; + } return stream << *net_query_ptr; } diff --git a/tdutils/CMakeLists.txt b/tdutils/CMakeLists.txt index 039d5802d..feabbcf23 100644 --- a/tdutils/CMakeLists.txt +++ b/tdutils/CMakeLists.txt @@ -289,6 +289,7 @@ endif() set(TDUTILS_TEST_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/test/bitmask.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/ChainScheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/ConcurrentHashMap.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/crypto.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/Enumerator.cpp diff --git a/tdutils/td/utils/ChainScheduler.h b/tdutils/td/utils/ChainScheduler.h new file mode 100644 index 000000000..44dd2bb65 --- /dev/null +++ b/tdutils/td/utils/ChainScheduler.h @@ -0,0 +1,291 @@ +#pragma once + +#include "td/utils/algorithm.h" +#include "td/utils/Container.h" +#include "td/utils/List.h" +#include "td/utils/optional.h" +#include "td/utils/Span.h" +#include "td/utils/tests.h" +#include "td/utils/VectorQueue.h" +#include "td/utils/Random.h" + +#include +#include +#include +#include + +namespace td { + +template +class ChainScheduler { + public: + using TaskId = uint64; + using ChainId = uint64; + TaskId create_task(td::Span chains, ExtraT extra = {}); + ExtraT *get_task_extra(TaskId task_id); + + struct TaskWithParents { + TaskId task_id{}; + std::vector parents; + }; + + optional start_next_task(); + void finish_task(TaskId task_id); + void reset_task(TaskId task_id); + template + friend td::StringBuilder &operator<<(StringBuilder &sb, ChainScheduler &scheduler); + + private: + struct ChainNode : ListNode { + TaskId task_id{}; + }; + + class Chain { + public: + void add_task(ChainNode *node) { + head_.put_back(node); + } + + optional get_first() { + if (head_.empty()) { + return {}; + } + return static_cast(*head_.get_next()).task_id; + } + + optional get_child(ChainNode *chain_node) { + if (chain_node->get_next() == head_.end()) { + return {}; + } + return static_cast(*chain_node->get_next()).task_id; + } + optional get_parent(ChainNode *chain_node) { + if (chain_node->get_prev() == head_.end()) { + return {}; + } + return static_cast(*chain_node->get_prev()).task_id; + } + + void finish_task(ChainNode *node) { + node->remove(); + } + + bool empty() const { + return head_.empty(); + } + + void foreach(std::function f) const { + for (auto it = head_.begin(); it != head_.end(); it = it->get_next()) { + f(static_cast(*it).task_id); + } + } + + private: + ListNode head_; + }; + struct ChainInfo { + Chain chain; + uint32 active_tasks{}; + }; + struct TaskChainInfo { + ChainNode chain_node; + ChainId chain_id{}; + ChainInfo *chain_info{}; + bool waiting_for_parent{}; + }; + struct Task { + enum class State { Pending, Active } state{State::Pending}; + std::vector chains; + ExtraT extra; + }; + std::map chains_; + std::map limited_tasks_; + Container tasks_; + VectorQueue pending_tasks_; + + void on_parent_is_ready(TaskId task_id, ChainId chain_id) { + auto *task = tasks_.get(task_id); + CHECK(task); + for (TaskChainInfo &task_chain_info : task->chains) { + if (task_chain_info.chain_id == chain_id) { + task_chain_info.waiting_for_parent = false; + } + } + + try_start_task(task_id, task); + } + + void try_start_task(TaskId task_id, Task *task) { + if (task->state != Task::State::Pending) { + return; + } + for (TaskChainInfo &task_chain_info : task->chains) { + if (task_chain_info.waiting_for_parent) { + return; + } + ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + if (chain_info.active_tasks >= 10) { + limited_tasks_[task_chain_info.chain_id] = task_id; + return; + } + } + + do_start_task(task_id, task); + } + + void do_start_task(TaskId task_id, Task *task) { + for (TaskChainInfo &task_chain_info : task->chains) { + ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + chain_info.active_tasks++; + } + task->state = Task::State::Active; + + pending_tasks_.push(task_id); + notify_children(task); + } + + void notify_children(Task *task) { + for (TaskChainInfo &task_chain_info : task->chains) { + ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + auto o_child = chain_info.chain.get_child(&task_chain_info.chain_node); + if (o_child) { + on_parent_is_ready(o_child.value(), task_chain_info.chain_id); + } + } + } + + void inactivate_task(TaskId task_id, Task *task) { + CHECK(task->state == Task::State::Active); + task->state = Task::State::Pending; + for (TaskChainInfo &task_chain_info : task->chains) { + ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + chain_info.active_tasks--; + + auto it = limited_tasks_.find(task_chain_info.chain_id); + if (it != limited_tasks_.end()) { + auto limited_task_id = it->second; + limited_tasks_.erase(it); + if (limited_task_id != task_id) { + try_start_task(limited_task_id, tasks_.get(limited_task_id)); + } + } + auto o_first = chain_info.chain.get_first(); + if (o_first) { + auto first_task_id = o_first.unwrap(); + if (first_task_id != task_id) { + try_start_task(first_task_id, tasks_.get(first_task_id)); + } + } + } + } + + void finish_chain_task(TaskChainInfo &task_chain_info) { + auto &chain = task_chain_info.chain_info->chain; + chain.finish_task(&task_chain_info.chain_node); + if (chain.empty()) { + chains_.erase(task_chain_info.chain_id); + } + } +}; +template +typename ChainScheduler::TaskId ChainScheduler::create_task(Span chains, ExtraT extra) { + auto task_id = tasks_.create(); + Task &task = *tasks_.get(task_id); + task.extra = std::move(extra); + task.chains = transform(chains, [&](auto chain_id) { + TaskChainInfo task_chain_info; + ChainInfo &chain_info = chains_[chain_id]; + task_chain_info.chain_id = chain_id; + task_chain_info.chain_info = &chain_info; + task_chain_info.chain_node.task_id = task_id; + return task_chain_info; + }); + + for (TaskChainInfo &task_chain_info : task.chains) { + auto &chain = task_chain_info.chain_info->chain; + chain.add_task(&task_chain_info.chain_node); + task_chain_info.waiting_for_parent = bool(chain.get_parent(&task_chain_info.chain_node)); + } + + try_start_task(task_id, &task); + return task_id; +} +template +ExtraT *ChainScheduler::get_task_extra(ChainScheduler::TaskId task_id) { // may return nullptr + auto *task = tasks_.get(task_id); + if (!task) { + return nullptr; + } + return &task->extra; +} +template +optional::TaskWithParents> ChainScheduler::start_next_task() { + if (pending_tasks_.empty()) { + return {}; + } + auto task_id = pending_tasks_.pop(); + TaskWithParents res; + res.task_id = task_id; + auto *task = tasks_.get(task_id); + CHECK(task); + for (TaskChainInfo &task_chain_info : task->chains) { + Chain &chain = task_chain_info.chain_info->chain; + auto o_parent = chain.get_parent(&task_chain_info.chain_node); + if (o_parent) { + res.parents.push_back(o_parent.value()); + } + } + return res; +} +template +void ChainScheduler::finish_task(ChainScheduler::TaskId task_id) { + auto *task = tasks_.get(task_id); + CHECK(task); + + inactivate_task(task_id, task); + notify_children(task); + + for (TaskChainInfo &task_chain_info : task->chains) { + finish_chain_task(task_chain_info); + } + tasks_.erase(task_id); +} +template +void ChainScheduler::reset_task(ChainScheduler::TaskId task_id) { + auto *task = tasks_.get(task_id); + CHECK(task); + inactivate_task(task_id, task); + + for (TaskChainInfo &task_chain_info : task->chains) { + ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + task_chain_info.waiting_for_parent = bool(chain_info.chain.get_parent(&task_chain_info.chain_node)); + } + + try_start_task(task_id, task); +} +template +td::StringBuilder &operator<<(StringBuilder &sb, ChainScheduler &scheduler) { + // 1 print chains + sb << "\n"; + for (auto &it : scheduler.chains_) { + sb << "ChainId{" << it.first << "} "; + sb << " active_cnt=" << it.second.active_tasks; + sb << " : "; + it.second.chain.foreach([&](auto task_id) { + sb << *scheduler.get_task_extra(task_id); + }); + sb << "\n"; + } + scheduler.tasks_.for_each([&](auto id, auto &task) { + sb << "Task: " << task.extra; + sb << " state =" << static_cast(task.state); + for (auto& task_chain_info : task.chains) { + if (task_chain_info.waiting_for_parent) { + sb << " wait " << *scheduler.get_task_extra(task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node).value()); + } + } + sb << "\n"; + }); + return sb; +} +} // namespace td diff --git a/tdutils/td/utils/List.h b/tdutils/td/utils/List.h index 345353538..4f9bb9877 100644 --- a/tdutils/td/utils/List.h +++ b/tdutils/td/utils/List.h @@ -89,12 +89,24 @@ struct ListNode { ListNode *end() { return this; } + const ListNode *begin() const { + return next; + } + const ListNode *end() const { + return this; + } ListNode *get_next() { return next; } ListNode *get_prev() { return prev; } + const ListNode *get_next() const { + return next; + } + const ListNode *get_prev() const { + return prev; + } protected: void clear() { diff --git a/tdutils/td/utils/algorithm.h b/tdutils/td/utils/algorithm.h index f1764e313..a2b1515ea 100644 --- a/tdutils/td/utils/algorithm.h +++ b/tdutils/td/utils/algorithm.h @@ -117,6 +117,15 @@ bool contains(const V &v, const T &value) { } return false; } +template +bool all_of(const V &v, F &&f) { + for (auto &x : v) { + if (!f(x)) { + return false; + } + } + return true; +} template void reset_to_empty(T &value) { diff --git a/tdutils/test/ChainScheduler.cpp b/tdutils/test/ChainScheduler.cpp new file mode 100644 index 000000000..bdd5fc75c --- /dev/null +++ b/tdutils/test/ChainScheduler.cpp @@ -0,0 +1,164 @@ +#include "td/utils/algorithm.h" +#include "td/utils/optional.h" +#include "td/utils/Span.h" +#include "td/utils/tests.h" +#include "td/utils/Random.h" + +#include "td/utils/ChainScheduler.h" + +#include +#include + +TEST(ChainScheduler, Basic) { + td::ChainScheduler scheduler; + using ChainId = td::ChainScheduler::ChainId; + using TaskId = td::ChainScheduler::TaskId; + for (int i = 0; i < 100; i++) { + scheduler.create_task({ChainId{1}}, i); + } + int j = 0; + while (j != 100) { + std::vector tasks; + while (true) { + auto o_task_id = scheduler.start_next_task(); + if (!o_task_id) { + break; + } + auto task_id = o_task_id.value().task_id; + auto extra = *scheduler.get_task_extra(task_id); + auto parents = td::transform(o_task_id.value().parents, + [&](auto parent) { return *scheduler.get_task_extra(parent); }); + LOG(ERROR) << "start " << extra << parents; + CHECK(extra == j); + j++; + tasks.push_back(task_id); + } + for (auto &task_id : tasks) { + auto extra = *scheduler.get_task_extra(task_id); + LOG(ERROR) << "finish " << extra; + scheduler.finish_task(task_id); + } + } +} + +struct Query; +using QueryPtr = std::shared_ptr; +using ChainId = td::ChainScheduler::ChainId; +using TaskId = td::ChainScheduler::TaskId; +struct Query { + int id{}; + TaskId task_id{}; + bool is_ok{}; + friend td::StringBuilder &operator << (td::StringBuilder &sb, const Query &q) { + return sb << "Q{" << q.id << "}"; + } +}; +TEST(ChainScheduler, Stress) { + td::Random::Xorshift128plus rnd(123); + int max_query_id = 1000; + int MAX_INFLIGHT_QUERIES = 20; + int ChainsN = 4; + + struct QueryWithParents { + QueryPtr id; + std::vector parents; + }; + std::vector active_queries; + + td::ChainScheduler scheduler; + std::vector> chains(ChainsN); + int inflight_queries{}; + int current_query_id{}; + bool done = false; + + auto schedule_new_query = [&] { + if (current_query_id > max_query_id) { + if (inflight_queries == 0) { + done = true; + } + return; + } + if (inflight_queries >= MAX_INFLIGHT_QUERIES) { + return; + } + auto query_id = current_query_id++; + auto query = std::make_shared(); + query->id = query_id; + int chain_n = rnd.fast(1, ChainsN); + std::vector chain_ids(ChainsN); + std::iota(chain_ids.begin(), chain_ids.end(), 0); + td::random_shuffle(td::as_mutable_span(chain_ids), rnd); + chain_ids.resize(chain_n); + for (auto chain_id : chain_ids) { + chains[chain_id].push_back(query); + } + auto task_id = scheduler.create_task(chain_ids, query); + query->task_id = task_id; + inflight_queries++; + }; + + auto check_parents_ok = [&] (const QueryWithParents &query_with_parents) -> bool { + return td::all_of(query_with_parents.parents, [](auto &parent) { return parent->is_ok; }); + }; + + auto to_query_ptr = [&](TaskId task_id) { + return *scheduler.get_task_extra(task_id); + }; + auto flush_pending_queries = [&]{ + while (true) { + auto o_task_with_parents = scheduler.start_next_task(); + if (!o_task_with_parents) { + break; + } + auto task_with_parents = o_task_with_parents.unwrap(); + QueryWithParents query_with_parents; + query_with_parents.id = to_query_ptr(task_with_parents.task_id); + query_with_parents.parents = td::transform(task_with_parents.parents, to_query_ptr); + active_queries.push_back(query_with_parents); + } + }; + auto execute_one_query = [&]() { + if (active_queries.empty()) { + return; + } + auto it = active_queries.begin() + rnd.fast(0, (int)active_queries.size() - 1); + auto query_with_parents = *it; + active_queries.erase(it); + + auto query = query_with_parents.id; + if (rnd.fast(0, 20) == 0) { + scheduler.finish_task(query->task_id); + inflight_queries--; + LOG(ERROR) << "Fail " << query->id; + } else if (check_parents_ok(query_with_parents)) { + query->is_ok = true; + scheduler.finish_task(query->task_id); + inflight_queries--; + LOG(ERROR) << "OK " << query->id; + } else { + scheduler.reset_task(query->task_id); + } + }; + + td::RandomSteps steps({{schedule_new_query, 100}, {execute_one_query, 100}}); + while (!done) { + steps.step(rnd); + flush_pending_queries(); + // LOG(ERROR) << scheduler; + } + for (auto &chain : chains) { + int prev_ok = -1; + int failed_cnt = 0; + int ok_cnt = 0; + for (auto &q : chain) { + if (q->is_ok) { + CHECK(prev_ok < q->id) ; + prev_ok = q->id; + ok_cnt++; + } else { + failed_cnt++; + } + } + LOG(ERROR) << "Chain ok " << ok_cnt << " failed " << failed_cnt; + } +}