#pragma once #include "td/utils/algorithm.h" #include "td/utils/Container.h" #include "td/utils/List.h" #include "td/utils/optional.h" #include "td/utils/Span.h" #include "td/utils/tests.h" #include "td/utils/VectorQueue.h" #include "td/utils/Random.h" #include #include #include #include namespace td { template class ChainScheduler { public: using TaskId = uint64; using ChainId = uint64; TaskId create_task(td::Span chains, ExtraT extra = {}); ExtraT *get_task_extra(TaskId task_id); struct TaskWithParents { TaskId task_id{}; std::vector parents; }; optional start_next_task(); void finish_task(TaskId task_id); void reset_task(TaskId task_id); template friend td::StringBuilder &operator<<(StringBuilder &sb, ChainScheduler &scheduler); template void for_each(F &&f) { tasks_.for_each([&f](auto, Task &task) { f(task.extra) ; }); } private: struct ChainNode : ListNode { TaskId task_id{}; }; class Chain { public: void add_task(ChainNode *node) { head_.put_back(node); } optional get_first() { if (head_.empty()) { return {}; } return static_cast(*head_.get_next()).task_id; } optional get_child(ChainNode *chain_node) { if (chain_node->get_next() == head_.end()) { return {}; } return static_cast(*chain_node->get_next()).task_id; } optional get_parent(ChainNode *chain_node) { if (chain_node->get_prev() == head_.end()) { return {}; } return static_cast(*chain_node->get_prev()).task_id; } void finish_task(ChainNode *node) { node->remove(); } bool empty() const { return head_.empty(); } void foreach(std::function f) const { for (auto it = head_.begin(); it != head_.end(); it = it->get_next()) { f(static_cast(*it).task_id); } } private: ListNode head_; }; struct ChainInfo { Chain chain; uint32 active_tasks{}; }; struct TaskChainInfo { ChainNode chain_node; ChainId chain_id{}; ChainInfo *chain_info{}; bool waiting_for_parent{}; }; struct Task { enum class State { Pending, Active } state{State::Pending}; std::vector chains; ExtraT extra; }; std::map chains_; std::map limited_tasks_; Container tasks_; VectorQueue pending_tasks_; void on_parent_is_ready(TaskId task_id, ChainId chain_id) { auto *task = tasks_.get(task_id); CHECK(task); for (TaskChainInfo &task_chain_info : task->chains) { if (task_chain_info.chain_id == chain_id) { task_chain_info.waiting_for_parent = false; } } try_start_task(task_id, task); } void try_start_task(TaskId task_id, Task *task) { if (task->state != Task::State::Pending) { return; } for (TaskChainInfo &task_chain_info : task->chains) { if (task_chain_info.waiting_for_parent) { return; } ChainInfo &chain_info = chains_[task_chain_info.chain_id]; if (chain_info.active_tasks >= 10) { limited_tasks_[task_chain_info.chain_id] = task_id; return; } } do_start_task(task_id, task); } void do_start_task(TaskId task_id, Task *task) { for (TaskChainInfo &task_chain_info : task->chains) { ChainInfo &chain_info = chains_[task_chain_info.chain_id]; chain_info.active_tasks++; } task->state = Task::State::Active; pending_tasks_.push(task_id); notify_children(task); } void notify_children(Task *task) { for (TaskChainInfo &task_chain_info : task->chains) { ChainInfo &chain_info = chains_[task_chain_info.chain_id]; auto o_child = chain_info.chain.get_child(&task_chain_info.chain_node); if (o_child) { on_parent_is_ready(o_child.value(), task_chain_info.chain_id); } } } void inactivate_task(TaskId task_id, Task *task) { CHECK(task->state == Task::State::Active); task->state = Task::State::Pending; for (TaskChainInfo &task_chain_info : task->chains) { ChainInfo &chain_info = chains_[task_chain_info.chain_id]; chain_info.active_tasks--; auto it = limited_tasks_.find(task_chain_info.chain_id); if (it != limited_tasks_.end()) { auto limited_task_id = it->second; limited_tasks_.erase(it); if (limited_task_id != task_id) { try_start_task(limited_task_id, tasks_.get(limited_task_id)); } } auto o_first = chain_info.chain.get_first(); if (o_first) { auto first_task_id = o_first.unwrap(); if (first_task_id != task_id) { try_start_task(first_task_id, tasks_.get(first_task_id)); } } } } void finish_chain_task(TaskChainInfo &task_chain_info) { auto &chain = task_chain_info.chain_info->chain; chain.finish_task(&task_chain_info.chain_node); if (chain.empty()) { chains_.erase(task_chain_info.chain_id); } } }; template typename ChainScheduler::TaskId ChainScheduler::create_task(Span chains, ExtraT extra) { auto task_id = tasks_.create(); Task &task = *tasks_.get(task_id); task.extra = std::move(extra); task.chains = transform(chains, [&](auto chain_id) { TaskChainInfo task_chain_info; ChainInfo &chain_info = chains_[chain_id]; task_chain_info.chain_id = chain_id; task_chain_info.chain_info = &chain_info; task_chain_info.chain_node.task_id = task_id; return task_chain_info; }); for (TaskChainInfo &task_chain_info : task.chains) { auto &chain = task_chain_info.chain_info->chain; chain.add_task(&task_chain_info.chain_node); task_chain_info.waiting_for_parent = bool(chain.get_parent(&task_chain_info.chain_node)); } try_start_task(task_id, &task); return task_id; } template ExtraT *ChainScheduler::get_task_extra(ChainScheduler::TaskId task_id) { // may return nullptr auto *task = tasks_.get(task_id); if (!task) { return nullptr; } return &task->extra; } template optional::TaskWithParents> ChainScheduler::start_next_task() { if (pending_tasks_.empty()) { return {}; } auto task_id = pending_tasks_.pop(); TaskWithParents res; res.task_id = task_id; auto *task = tasks_.get(task_id); CHECK(task); for (TaskChainInfo &task_chain_info : task->chains) { Chain &chain = task_chain_info.chain_info->chain; auto o_parent = chain.get_parent(&task_chain_info.chain_node); if (o_parent) { res.parents.push_back(o_parent.value()); } } return res; } template void ChainScheduler::finish_task(ChainScheduler::TaskId task_id) { auto *task = tasks_.get(task_id); CHECK(task); inactivate_task(task_id, task); notify_children(task); for (TaskChainInfo &task_chain_info : task->chains) { finish_chain_task(task_chain_info); } tasks_.erase(task_id); } template void ChainScheduler::reset_task(ChainScheduler::TaskId task_id) { auto *task = tasks_.get(task_id); CHECK(task); inactivate_task(task_id, task); for (TaskChainInfo &task_chain_info : task->chains) { ChainInfo &chain_info = chains_[task_chain_info.chain_id]; task_chain_info.waiting_for_parent = bool(chain_info.chain.get_parent(&task_chain_info.chain_node)); } try_start_task(task_id, task); } template td::StringBuilder &operator<<(StringBuilder &sb, ChainScheduler &scheduler) { // 1 print chains sb << "\n"; for (auto &it : scheduler.chains_) { sb << "ChainId{" << it.first << "} "; sb << " active_cnt=" << it.second.active_tasks; sb << " : "; it.second.chain.foreach([&](auto task_id) { sb << *scheduler.get_task_extra(task_id); }); sb << "\n"; } scheduler.tasks_.for_each([&](auto id, auto &task) { sb << "Task: " << task.extra; sb << " state =" << static_cast(task.state); for (auto& task_chain_info : task.chains) { if (task_chain_info.waiting_for_parent) { sb << " wait " << *scheduler.get_task_extra(task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node).value()); } } sb << "\n"; }); return sb; } } // namespace td