Use FlatHashMap instead of unordered_map in ChainScheduler.

This commit is contained in:
levlam 2023-07-27 14:18:39 +03:00
parent d8116aa796
commit b914b28bf0
2 changed files with 18 additions and 9 deletions

View File

@ -20,7 +20,6 @@
#include "td/utils/VectorQueue.h" #include "td/utils/VectorQueue.h"
#include <functional> #include <functional>
#include <unordered_map>
namespace td { namespace td {
@ -142,11 +141,19 @@ class ChainScheduler final : public ChainSchedulerBase {
vector<TaskChainInfo> chains; vector<TaskChainInfo> chains;
ExtraT extra; ExtraT extra;
}; };
std::unordered_map<ChainId, ChainInfo, Hash<ChainId>> chains_; FlatHashMap<ChainId, unique_ptr<ChainInfo>> chains_;
FlatHashMap<ChainId, TaskId> limited_tasks_; FlatHashMap<ChainId, TaskId> limited_tasks_;
Container<Task> tasks_; Container<Task> tasks_;
VectorQueue<TaskId> pending_tasks_; VectorQueue<TaskId> pending_tasks_;
ChainInfo &get_chain_info(ChainId chain_id) {
auto &chain = chains_[chain_id];
if (chain == nullptr) {
chain = make_unique<ChainInfo>();
}
return *chain;
}
void try_start_task(TaskId task_id) { void try_start_task(TaskId task_id) {
auto *task = tasks_.get(task_id); auto *task = tasks_.get(task_id);
CHECK(task != nullptr); CHECK(task != nullptr);
@ -173,7 +180,7 @@ class ChainScheduler final : public ChainSchedulerBase {
void do_start_task(TaskId task_id, Task *task) { void do_start_task(TaskId task_id, Task *task) {
for (TaskChainInfo &task_chain_info : task->chains) { for (TaskChainInfo &task_chain_info : task->chains) {
ChainInfo &chain_info = chains_[task_chain_info.chain_id]; ChainInfo &chain_info = get_chain_info(task_chain_info.chain_id);
chain_info.active_tasks++; chain_info.active_tasks++;
task_chain_info.chain_node.generation = chain_info.generation; task_chain_info.chain_node.generation = chain_info.generation;
} }
@ -262,8 +269,9 @@ typename ChainScheduler<ExtraT>::TaskId ChainScheduler<ExtraT>::create_task(Span
Task &task = *tasks_.get(task_id); Task &task = *tasks_.get(task_id);
task.extra = std::move(extra); task.extra = std::move(extra);
task.chains = transform(chains, [&](auto chain_id) { task.chains = transform(chains, [&](auto chain_id) {
CHECK(chain_id != 0);
TaskChainInfo task_chain_info; TaskChainInfo task_chain_info;
ChainInfo &chain_info = chains_[chain_id]; ChainInfo &chain_info = get_chain_info(chain_id);
task_chain_info.chain_id = chain_id; task_chain_info.chain_id = chain_id;
task_chain_info.chain_info = &chain_info; task_chain_info.chain_info = &chain_info;
task_chain_info.chain_node.task_id = task_id; task_chain_info.chain_node.task_id = task_id;
@ -352,11 +360,12 @@ StringBuilder &operator<<(StringBuilder &sb, ChainScheduler<ExtraT> &scheduler)
// 1 print chains // 1 print chains
sb << '\n'; sb << '\n';
for (auto &it : scheduler.chains_) { for (auto &it : scheduler.chains_) {
CHECK(it.second != nullptr);
sb << "ChainId{" << it.first << "}"; sb << "ChainId{" << it.first << "}";
sb << " active_cnt = " << it.second.active_tasks; sb << " active_cnt = " << it.second->active_tasks;
sb << " g = " << it.second.generation; sb << " g = " << it.second->generation;
sb << ':'; sb << ':';
it.second.chain.foreach( it.second->chain.foreach(
[&](auto task_id, auto generation) { sb << ' ' << *scheduler.get_task_extra(task_id) << ':' << generation; }); [&](auto task_id, auto generation) { sb << ' ' << *scheduler.get_task_extra(task_id) << ':' << generation; });
sb << '\n'; sb << '\n';
} }

View File

@ -116,7 +116,7 @@ TEST(ChainScheduler, Stress) {
td::vector<QueryWithParents> active_queries; td::vector<QueryWithParents> active_queries;
td::ChainScheduler<QueryPtr> scheduler; td::ChainScheduler<QueryPtr> scheduler;
td::vector<td::vector<QueryPtr>> chains(ChainsN); td::vector<td::vector<QueryPtr>> chains(ChainsN + 1);
int inflight_queries{}; int inflight_queries{};
int current_query_id{}; int current_query_id{};
int sent_cnt{}; int sent_cnt{};
@ -138,7 +138,7 @@ TEST(ChainScheduler, Stress) {
query->id = query_id; query->id = query_id;
int chain_n = rnd.fast(1, ChainsN); int chain_n = rnd.fast(1, ChainsN);
td::vector<ChainId> chain_ids(ChainsN); td::vector<ChainId> chain_ids(ChainsN);
std::iota(chain_ids.begin(), chain_ids.end(), 0); std::iota(chain_ids.begin(), chain_ids.end(), 1);
td::rand_shuffle(td::as_mutable_span(chain_ids), rnd); td::rand_shuffle(td::as_mutable_span(chain_ids), rnd);
chain_ids.resize(chain_n); chain_ids.resize(chain_n);
for (auto chain_id : chain_ids) { for (auto chain_id : chain_ids) {