From 24766fdad8b0372719f420b421802809ddd10153 Mon Sep 17 00:00:00 2001 From: Arseny Smirnov Date: Tue, 1 Feb 2022 11:13:59 +0300 Subject: [PATCH] ChainScheduler: pass new test --- td/telegram/SequenceDispatcher.cpp | 19 ++++ tdutils/td/utils/ChainScheduler.h | 134 ++++++++++++++++------------- tdutils/test/ChainScheduler.cpp | 65 ++++++++++++-- 3 files changed, 153 insertions(+), 65 deletions(-) diff --git a/td/telegram/SequenceDispatcher.cpp b/td/telegram/SequenceDispatcher.cpp index d5ecec832..30cc54065 100644 --- a/td/telegram/SequenceDispatcher.cpp +++ b/td/telegram/SequenceDispatcher.cpp @@ -300,6 +300,8 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew { struct Node { NetQueryRef net_query_ref; NetQueryPtr net_query; + double total_timeout{0}; + double last_timeout{0}; ActorShared callback; friend StringBuilder &operator<<(StringBuilder &sb, const Node &node) { return sb << node.net_query; @@ -313,6 +315,23 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew { auto task_id = TaskId(get_link_token()); auto &node = *scheduler_.get_task_extra(task_id); +// if (query->last_timeout_ != 0) { +// for (auto i = pos + 1; i < data_.size(); i++) { +// data_[i].total_timeout_ += query->last_timeout_; +// data_[i].last_timeout_ = query->last_timeout_; +// check_timeout(data_[i]); +// if (data.query_->total_timeout_ > data.query_->total_timeout_limit_) { +// LOG(WARNING) << "Fail " << data.query_ << " to " << data.query_->source_ << " because total_timeout " +// << data.query_->total_timeout_ << " is greater than total_timeout_limit " +// << data.query_->total_timeout_limit_; +// data.query_->set_error(Status::Error( +// 429, PSLICE() << "Too Many Requests: retry after " << static_cast(data.last_timeout_ + 0.999))); +// data.state_ = State::Dummy; +// try_resend_query(data, std::move(data.query_)); +// } +// } +// } + if (query->is_error() && (query->error().code() == NetQuery::ResendInvokeAfter || (query->error().code() == 400 && (query->error().message() == "MSG_WAIT_FAILED" || query->error().message() == "MSG_WAIT_TIMEOUT")))) { diff --git a/tdutils/td/utils/ChainScheduler.h b/tdutils/td/utils/ChainScheduler.h index 77e4a0787..8924c7bc0 100644 --- a/tdutils/td/utils/ChainScheduler.h +++ b/tdutils/td/utils/ChainScheduler.h @@ -20,20 +20,22 @@ namespace td { -struct ChainSchedulerTaskWithParents { - uint64 task_id{}; - vector parents; +struct ChainSchedulerBase { + struct TaskWithParents { + uint64 task_id{}; + vector parents; + }; }; template -class ChainScheduler { +class ChainScheduler : public ChainSchedulerBase { public: using TaskId = uint64; using ChainId = uint64; TaskId create_task(Span chains, ExtraT extra = {}); ExtraT *get_task_extra(TaskId task_id); - optional start_next_task(); + optional start_next_task(); void finish_task(TaskId task_id); void reset_task(TaskId task_id); template @@ -47,6 +49,7 @@ class ChainScheduler { private: struct ChainNode : ListNode { TaskId task_id{}; + uint64 generation{}; }; class Chain { @@ -68,11 +71,11 @@ class ChainScheduler { } return static_cast(*chain_node->get_next()).task_id; } - optional get_parent(ChainNode *chain_node) { + optional get_parent(ChainNode *chain_node) { if (chain_node->get_prev() == head_.end()) { return {}; } - return static_cast(*chain_node->get_prev()).task_id; + return static_cast(chain_node->get_prev()); } void finish_task(ChainNode *node) { @@ -83,9 +86,10 @@ class ChainScheduler { return head_.empty(); } - void foreach(std::function f) const { + void foreach(std::function f) const { for (auto it = head_.begin(); it != head_.end(); it = it->get_next()) { - f(static_cast(*it).task_id); + auto &node = static_cast(*it); + f(node.task_id, node.generation); } } @@ -95,12 +99,12 @@ class ChainScheduler { struct ChainInfo { Chain chain; uint32 active_tasks{}; + uint64 generation{1}; }; struct TaskChainInfo { ChainNode chain_node; ChainId chain_id{}; ChainInfo *chain_info{}; - bool waiting_for_parent{}; }; struct Task { enum class State { Pending, Active } state{State::Pending}; @@ -112,28 +116,20 @@ class ChainScheduler { Container tasks_; VectorQueue pending_tasks_; - void on_parent_is_ready(TaskId task_id, ChainId chain_id) { - auto *task = tasks_.get(task_id); - CHECK(task); - for (TaskChainInfo &task_chain_info : task->chains) { - if (task_chain_info.chain_id == chain_id) { - task_chain_info.waiting_for_parent = false; - } - } - - try_start_task(task_id, task); - } - void try_start_task(TaskId task_id, Task *task) { if (task->state != Task::State::Pending) { return; } for (TaskChainInfo &task_chain_info : task->chains) { - if (task_chain_info.waiting_for_parent) { - return; + auto o_parent = task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node); + + if (o_parent) { + if (o_parent.value()->generation != task_chain_info.chain_info->generation) { + return; + } } - ChainInfo &chain_info = chains_[task_chain_info.chain_id]; - if (chain_info.active_tasks >= 10) { + + if (task_chain_info.chain_info->active_tasks >= 10) { limited_tasks_[task_chain_info.chain_id] = task_id; return; } @@ -146,44 +142,50 @@ class ChainScheduler { for (TaskChainInfo &task_chain_info : task->chains) { ChainInfo &chain_info = chains_[task_chain_info.chain_id]; chain_info.active_tasks++; - task_chain_info.waiting_for_parent = true; + task_chain_info.chain_node.generation = chain_info.generation; } task->state = Task::State::Active; pending_tasks_.push(task_id); - notify_children(task); + for_each_child(task, [&](auto task_id) { + try_start_task(task_id, tasks_.get(task_id)); + }); } - void notify_children(Task *task) { + template + void for_each_child(Task *task, F &&f) { for (TaskChainInfo &task_chain_info : task->chains) { - ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + ChainInfo &chain_info = *task_chain_info.chain_info; auto o_child = chain_info.chain.get_child(&task_chain_info.chain_node); if (o_child) { - on_parent_is_ready(o_child.value(), task_chain_info.chain_id); + f(o_child.value()); } } } void inactivate_task(TaskId task_id, Task *task) { - CHECK(task->state == Task::State::Active); + LOG(ERROR) << "inactivate " << task_id; + bool was_active = task->state == Task::State::Active; task->state = Task::State::Pending; for (TaskChainInfo &task_chain_info : task->chains) { - ChainInfo &chain_info = chains_[task_chain_info.chain_id]; - chain_info.active_tasks--; + ChainInfo &chain_info = *task_chain_info.chain_info; + if (was_active) { + chain_info.active_tasks--; + } auto it = limited_tasks_.find(task_chain_info.chain_id); if (it != limited_tasks_.end()) { auto limited_task_id = it->second; limited_tasks_.erase(it); if (limited_task_id != task_id) { - try_start_task(limited_task_id, tasks_.get(limited_task_id)); + try_start_task_later(limited_task_id); } } auto o_first = chain_info.chain.get_first(); if (o_first) { auto first_task_id = o_first.unwrap(); if (first_task_id != task_id) { - try_start_task(first_task_id, tasks_.get(first_task_id)); + try_start_task_later(first_task_id); } } } @@ -196,6 +198,18 @@ class ChainScheduler { chains_.erase(task_chain_info.chain_id); } } + + std::vector to_start; + void try_start_task_later(TaskId task_id) { + to_start.push_back(task_id); + } + void flush_try_start_task() { + auto moved_to_start = std::move(to_start); + for (auto task_id : moved_to_start) { + try_start_task(task_id, tasks_.get(task_id)); + } + CHECK(to_start.empty()); + } }; template @@ -214,15 +228,8 @@ typename ChainScheduler::TaskId ChainScheduler::create_task(Span }); for (TaskChainInfo &task_chain_info : task.chains) { - auto &chain = task_chain_info.chain_info->chain; - chain.add_task(&task_chain_info.chain_node); - auto o_parent = chain.get_parent(&task_chain_info.chain_node); - if (o_parent) { - auto parent = o_parent.unwrap(); - if (tasks_.get(parent)->state == Task::State::Pending) { - task_chain_info.waiting_for_parent = true; - } - } + ChainInfo &chain_info = *task_chain_info.chain_info; + chain_info.chain.add_task(&task_chain_info.chain_node); } try_start_task(task_id, &task); @@ -239,12 +246,12 @@ ExtraT *ChainScheduler::get_task_extra(ChainScheduler::TaskId task_id) { } template -optional ChainScheduler::start_next_task() { +optional ChainScheduler::start_next_task() { if (pending_tasks_.empty()) { return {}; } auto task_id = pending_tasks_.pop(); - ChainSchedulerTaskWithParents res; + TaskWithParents res; res.task_id = task_id; auto *task = tasks_.get(task_id); CHECK(task); @@ -252,7 +259,7 @@ optional ChainScheduler::start_next_task( Chain &chain = task_chain_info.chain_info->chain; auto o_parent = chain.get_parent(&task_chain_info.chain_node); if (o_parent) { - res.parents.push_back(o_parent.value()); + res.parents.push_back(o_parent.value()->task_id); } } return res; @@ -263,13 +270,19 @@ void ChainScheduler::finish_task(ChainScheduler::TaskId task_id) { auto *task = tasks_.get(task_id); CHECK(task); + CHECK(to_start.empty()); inactivate_task(task_id, task); - notify_children(task); + for_each_child(task, [&](auto task_id) { + try_start_task_later(task_id); + }); for (TaskChainInfo &task_chain_info : task->chains) { finish_chain_task(task_chain_info); } + + auto task_copy = std::move(*task); tasks_.erase(task_id); + flush_try_start_task(); } template @@ -279,11 +292,12 @@ void ChainScheduler::reset_task(ChainScheduler::TaskId task_id) { inactivate_task(task_id, task); for (TaskChainInfo &task_chain_info : task->chains) { - ChainInfo &chain_info = chains_[task_chain_info.chain_id]; - task_chain_info.waiting_for_parent &= bool(chain_info.chain.get_parent(&task_chain_info.chain_node)); + ChainInfo &chain_info = *task_chain_info.chain_info; + chain_info.generation = td::max(chain_info.generation, task_chain_info.chain_node.generation + 1); } - try_start_task(task_id, task); + try_start_task_later(task_id); + flush_try_start_task(); } template @@ -293,18 +307,20 @@ StringBuilder &operator<<(StringBuilder &sb, ChainScheduler &scheduler) for (auto &it : scheduler.chains_) { sb << "ChainId{" << it.first << "} "; sb << " active_cnt=" << it.second.active_tasks; - sb << " : "; - it.second.chain.foreach([&](auto task_id) { sb << *scheduler.get_task_extra(task_id); }); + sb << " g=" << it.second.generation; + sb << " :"; + it.second.chain.foreach([&](auto task_id, auto generation) { + sb << " " << *scheduler.get_task_extra(task_id) << ":" << generation; + }); sb << "\n"; } scheduler.tasks_.for_each([&](auto id, auto &task) { sb << "Task: " << task.extra; sb << " state =" << static_cast(task.state); - for (auto &task_chain_info : task.chains) { - if (task_chain_info.waiting_for_parent) { - sb << " wait " - << *scheduler.get_task_extra( - task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node).value()); + for (auto& task_chain_info : task.chains) { + sb << " g=" << task_chain_info.chain_node.generation; + if (task_chain_info.chain_info->generation != task_chain_info.chain_node.generation) { + sb << " chain_g=" << task_chain_info.chain_info->generation; } } sb << "\n"; diff --git a/tdutils/test/ChainScheduler.cpp b/tdutils/test/ChainScheduler.cpp index c8160369b..8216fb508 100644 --- a/tdutils/test/ChainScheduler.cpp +++ b/tdutils/test/ChainScheduler.cpp @@ -47,6 +47,25 @@ TEST(ChainScheduler, RestartAfterActive) { ASSERT_EQ(second_task_id, scheduler.start_next_task().unwrap().task_id); } +TEST(ChainScheduler, SendAfterRestart) { + td::ChainScheduler scheduler; + using ChainId = td::ChainScheduler::ChainId; + using TaskId = td::ChainScheduler::TaskId; + std::vector chains{1}; + + auto first_task_id = scheduler.create_task( chains, 1); + auto second_task_id = scheduler.create_task( chains, 2); + ASSERT_EQ(first_task_id, scheduler.start_next_task().unwrap().task_id); + ASSERT_EQ(second_task_id, scheduler.start_next_task().unwrap().task_id); + + scheduler.reset_task(first_task_id); + + auto third_task_id = scheduler.create_task( chains, 3); + + ASSERT_EQ(first_task_id, scheduler.start_next_task().unwrap().task_id); + ASSERT_TRUE(!scheduler.start_next_task()); +} + TEST(ChainScheduler, Basic) { td::ChainScheduler scheduler; using ChainId = td::ChainScheduler::ChainId; @@ -87,18 +106,22 @@ struct Query { int id{}; TaskId task_id{}; bool is_ok{}; - friend td::StringBuilder &operator<<(td::StringBuilder &sb, const Query &q) { + bool skipped{}; + friend td::StringBuilder &operator << (td::StringBuilder &sb, const Query &q) { return sb << "Q{" << q.id << "}"; } }; - +td::StringBuilder &operator << (td::StringBuilder &sb, const QueryPtr &query_ptr) { + return sb << *query_ptr; +} TEST(ChainScheduler, Stress) { td::Random::Xorshift128plus rnd(123); - int max_query_id = 1000; + int max_query_id = 100000; int MAX_INFLIGHT_QUERIES = 20; int ChainsN = 4; struct QueryWithParents { + TaskId task_id; QueryPtr id; td::vector parents; }; @@ -108,7 +131,9 @@ TEST(ChainScheduler, Stress) { td::vector> chains(ChainsN); int inflight_queries{}; int current_query_id{}; + int sent_cnt{}; bool done = false; + std::vector pending_queries; auto schedule_new_query = [&] { if (current_query_id > max_query_id) { @@ -133,6 +158,7 @@ TEST(ChainScheduler, Stress) { } auto task_id = scheduler.create_task(chain_ids, query); query->task_id = task_id; + pending_queries.push_back(task_id); inflight_queries++; }; @@ -151,11 +177,28 @@ TEST(ChainScheduler, Stress) { } auto task_with_parents = o_task_with_parents.unwrap(); QueryWithParents query_with_parents; + query_with_parents.task_id = task_with_parents.task_id; query_with_parents.id = to_query_ptr(task_with_parents.task_id); query_with_parents.parents = td::transform(task_with_parents.parents, to_query_ptr); active_queries.push_back(query_with_parents); + sent_cnt++; } }; + auto skip_one_query = [&]() { + if (pending_queries.empty()) { + return; + } + auto it = pending_queries.begin() + rnd.fast(0, (int)pending_queries.size() - 1); + auto task_id = *it; + pending_queries.erase(it); + td::remove_if(active_queries, [&](auto &q) {return q.task_id == task_id;}); + + auto query = *scheduler.get_task_extra(task_id); + query->skipped = true; + scheduler.finish_task(task_id); + inflight_queries--; + LOG(ERROR) << "Skip " << query->id; + }; auto execute_one_query = [&] { if (active_queries.empty()) { return; @@ -167,37 +210,47 @@ TEST(ChainScheduler, Stress) { auto query = query_with_parents.id; if (rnd.fast(0, 20) == 0) { scheduler.finish_task(query->task_id); + td::remove(pending_queries, query->task_id); inflight_queries--; LOG(INFO) << "Fail " << query->id; } else if (check_parents_ok(query_with_parents)) { query->is_ok = true; scheduler.finish_task(query->task_id); + td::remove(pending_queries, query->task_id); inflight_queries--; LOG(INFO) << "OK " << query->id; } else { scheduler.reset_task(query->task_id); + LOG(ERROR) << "Reset " << query->id; } }; - td::RandomSteps steps({{schedule_new_query, 100}, {execute_one_query, 100}}); + td::RandomSteps steps({{schedule_new_query, 100}, {execute_one_query, 100}, {skip_one_query, 10}}); while (!done) { steps.step(rnd); flush_pending_queries(); // LOG(INFO) << scheduler; } + LOG(ERROR) << "Sent queries count " << sent_cnt; + LOG(ERROR) << "Total queries " << current_query_id; for (auto &chain : chains) { int prev_ok = -1; int failed_cnt = 0; int ok_cnt = 0; + int skipped_cnt = 0; for (auto &q : chain) { if (q->is_ok) { CHECK(prev_ok < q->id); prev_ok = q->id; ok_cnt++; } else { - failed_cnt++; + if (q->skipped) { + skipped_cnt++; + } else { + failed_cnt++; + } } } - LOG(INFO) << "Chain ok " << ok_cnt << " failed " << failed_cnt; + LOG(INFO) << "Chain ok " << ok_cnt << " failed " << failed_cnt << " skipped " << skipped_cnt; } }