From b914b28bf07e2207faf0f60dd064551aff66f77d Mon Sep 17 00:00:00 2001 From: levlam Date: Thu, 27 Jul 2023 14:18:39 +0300 Subject: [PATCH] Use FlatHashMap instead of unordered_map in ChainScheduler. --- tdutils/td/utils/ChainScheduler.h | 23 ++++++++++++++++------- tdutils/test/ChainScheduler.cpp | 4 ++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tdutils/td/utils/ChainScheduler.h b/tdutils/td/utils/ChainScheduler.h index ebc97d295..a257adc5e 100644 --- a/tdutils/td/utils/ChainScheduler.h +++ b/tdutils/td/utils/ChainScheduler.h @@ -20,7 +20,6 @@ #include "td/utils/VectorQueue.h" #include -#include namespace td { @@ -142,11 +141,19 @@ class ChainScheduler final : public ChainSchedulerBase { vector chains; ExtraT extra; }; - std::unordered_map> chains_; + FlatHashMap> chains_; FlatHashMap limited_tasks_; Container tasks_; VectorQueue pending_tasks_; + ChainInfo &get_chain_info(ChainId chain_id) { + auto &chain = chains_[chain_id]; + if (chain == nullptr) { + chain = make_unique(); + } + return *chain; + } + void try_start_task(TaskId task_id) { auto *task = tasks_.get(task_id); CHECK(task != nullptr); @@ -173,7 +180,7 @@ class ChainScheduler final : public ChainSchedulerBase { void do_start_task(TaskId task_id, Task *task) { for (TaskChainInfo &task_chain_info : task->chains) { - ChainInfo &chain_info = chains_[task_chain_info.chain_id]; + ChainInfo &chain_info = get_chain_info(task_chain_info.chain_id); chain_info.active_tasks++; task_chain_info.chain_node.generation = chain_info.generation; } @@ -262,8 +269,9 @@ typename ChainScheduler::TaskId ChainScheduler::create_task(Span Task &task = *tasks_.get(task_id); task.extra = std::move(extra); task.chains = transform(chains, [&](auto chain_id) { + CHECK(chain_id != 0); TaskChainInfo task_chain_info; - ChainInfo &chain_info = chains_[chain_id]; + ChainInfo &chain_info = get_chain_info(chain_id); task_chain_info.chain_id = chain_id; task_chain_info.chain_info = &chain_info; task_chain_info.chain_node.task_id = task_id; @@ -352,11 +360,12 @@ StringBuilder &operator<<(StringBuilder &sb, ChainScheduler &scheduler) // 1 print chains sb << '\n'; for (auto &it : scheduler.chains_) { + CHECK(it.second != nullptr); sb << "ChainId{" << it.first << "}"; - sb << " active_cnt = " << it.second.active_tasks; - sb << " g = " << it.second.generation; + sb << " active_cnt = " << it.second->active_tasks; + sb << " g = " << it.second->generation; sb << ':'; - it.second.chain.foreach( + it.second->chain.foreach( [&](auto task_id, auto generation) { sb << ' ' << *scheduler.get_task_extra(task_id) << ':' << generation; }); sb << '\n'; } diff --git a/tdutils/test/ChainScheduler.cpp b/tdutils/test/ChainScheduler.cpp index 029a4ff36..ba4cdbd15 100644 --- a/tdutils/test/ChainScheduler.cpp +++ b/tdutils/test/ChainScheduler.cpp @@ -116,7 +116,7 @@ TEST(ChainScheduler, Stress) { td::vector active_queries; td::ChainScheduler scheduler; - td::vector> chains(ChainsN); + td::vector> chains(ChainsN + 1); int inflight_queries{}; int current_query_id{}; int sent_cnt{}; @@ -138,7 +138,7 @@ TEST(ChainScheduler, Stress) { query->id = query_id; int chain_n = rnd.fast(1, ChainsN); td::vector chain_ids(ChainsN); - std::iota(chain_ids.begin(), chain_ids.end(), 0); + std::iota(chain_ids.begin(), chain_ids.end(), 1); td::rand_shuffle(td::as_mutable_span(chain_ids), rnd); chain_ids.resize(chain_n); for (auto chain_id : chain_ids) {