ChainScheduler: pass new test

This commit is contained in:
Arseny Smirnov 2022-02-01 11:13:59 +03:00
parent b4396f18c6
commit 24766fdad8
3 changed files with 153 additions and 65 deletions

View File

@ -300,6 +300,8 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew {
struct Node {
NetQueryRef net_query_ref;
NetQueryPtr net_query;
double total_timeout{0};
double last_timeout{0};
ActorShared<NetQueryCallback> callback;
friend StringBuilder &operator<<(StringBuilder &sb, const Node &node) {
return sb << node.net_query;
@ -313,6 +315,23 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew {
auto task_id = TaskId(get_link_token());
auto &node = *scheduler_.get_task_extra(task_id);
// if (query->last_timeout_ != 0) {
// for (auto i = pos + 1; i < data_.size(); i++) {
// data_[i].total_timeout_ += query->last_timeout_;
// data_[i].last_timeout_ = query->last_timeout_;
// check_timeout(data_[i]);
// if (data.query_->total_timeout_ > data.query_->total_timeout_limit_) {
// LOG(WARNING) << "Fail " << data.query_ << " to " << data.query_->source_ << " because total_timeout "
// << data.query_->total_timeout_ << " is greater than total_timeout_limit "
// << data.query_->total_timeout_limit_;
// data.query_->set_error(Status::Error(
// 429, PSLICE() << "Too Many Requests: retry after " << static_cast<int32>(data.last_timeout_ + 0.999)));
// data.state_ = State::Dummy;
// try_resend_query(data, std::move(data.query_));
// }
// }
// }
if (query->is_error() && (query->error().code() == NetQuery::ResendInvokeAfter ||
(query->error().code() == 400 && (query->error().message() == "MSG_WAIT_FAILED" ||
query->error().message() == "MSG_WAIT_TIMEOUT")))) {

View File

@ -20,20 +20,22 @@
namespace td {
struct ChainSchedulerTaskWithParents {
uint64 task_id{};
vector<uint64> parents;
struct ChainSchedulerBase {
struct TaskWithParents {
uint64 task_id{};
vector<uint64> parents;
};
};
template <class ExtraT = Unit>
class ChainScheduler {
class ChainScheduler : public ChainSchedulerBase {
public:
using TaskId = uint64;
using ChainId = uint64;
TaskId create_task(Span<ChainId> chains, ExtraT extra = {});
ExtraT *get_task_extra(TaskId task_id);
optional<ChainSchedulerTaskWithParents> start_next_task();
optional<TaskWithParents> start_next_task();
void finish_task(TaskId task_id);
void reset_task(TaskId task_id);
template <class ExtraTT>
@ -47,6 +49,7 @@ class ChainScheduler {
private:
// List node that links one task into one chain (embedded in TaskChainInfo as chain_node).
struct ChainNode : ListNode {
// Identifier of the task occupying this position in the chain.
TaskId task_id{};
// Copy of the owning chain's ChainInfo::generation, taken when the task is started.
// It is 0 for a task that has never been started. try_start_task lets a child start
// only if its parent's value equals the chain's current generation.
uint64 generation{};
};
class Chain {
@ -68,11 +71,11 @@ class ChainScheduler {
}
return static_cast<ChainNode &>(*chain_node->get_next()).task_id;
}
optional<TaskId> get_parent(ChainNode *chain_node) {
optional<ChainNode*> get_parent(ChainNode *chain_node) {
if (chain_node->get_prev() == head_.end()) {
return {};
}
return static_cast<ChainNode &>(*chain_node->get_prev()).task_id;
return static_cast<ChainNode *>(chain_node->get_prev());
}
void finish_task(ChainNode *node) {
@ -83,9 +86,10 @@ class ChainScheduler {
return head_.empty();
}
void foreach(std::function<void(TaskId)> f) const {
void foreach(std::function<void(TaskId, uint64)> f) const {
for (auto it = head_.begin(); it != head_.end(); it = it->get_next()) {
f(static_cast<const ChainNode &>(*it).task_id);
auto &node = static_cast<const ChainNode &>(*it);
f(node.task_id, node.generation);
}
}
@ -95,12 +99,12 @@ class ChainScheduler {
// Bookkeeping kept for each chain.
struct ChainInfo {
// List of the nodes of the tasks that belong to this chain.
Chain chain;
// Count of this chain's tasks that are currently Active: incremented when a task is started,
// decremented in inactivate_task. try_start_task does not start a task while this is >= 10.
uint32 active_tasks{};
// Current generation of the chain. Starts at 1, whereas ChainNode::generation starts at 0, so a
// node that was never started does not match it. reset_task raises it to at least the reset
// node's generation + 1.
uint64 generation{1};
};
struct TaskChainInfo {
ChainNode chain_node;
ChainId chain_id{};
ChainInfo *chain_info{};
bool waiting_for_parent{};
};
struct Task {
enum class State { Pending, Active } state{State::Pending};
@ -112,28 +116,20 @@ class ChainScheduler {
Container<Task> tasks_;
VectorQueue<TaskId> pending_tasks_;
void on_parent_is_ready(TaskId task_id, ChainId chain_id) {
auto *task = tasks_.get(task_id);
CHECK(task);
for (TaskChainInfo &task_chain_info : task->chains) {
if (task_chain_info.chain_id == chain_id) {
task_chain_info.waiting_for_parent = false;
}
}
try_start_task(task_id, task);
}
void try_start_task(TaskId task_id, Task *task) {
if (task->state != Task::State::Pending) {
return;
}
for (TaskChainInfo &task_chain_info : task->chains) {
if (task_chain_info.waiting_for_parent) {
return;
auto o_parent = task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node);
if (o_parent) {
if (o_parent.value()->generation != task_chain_info.chain_info->generation) {
return;
}
}
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
if (chain_info.active_tasks >= 10) {
if (task_chain_info.chain_info->active_tasks >= 10) {
limited_tasks_[task_chain_info.chain_id] = task_id;
return;
}
@ -146,44 +142,50 @@ class ChainScheduler {
for (TaskChainInfo &task_chain_info : task->chains) {
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
chain_info.active_tasks++;
task_chain_info.waiting_for_parent = true;
task_chain_info.chain_node.generation = chain_info.generation;
}
task->state = Task::State::Active;
pending_tasks_.push(task_id);
notify_children(task);
for_each_child(task, [&](auto task_id) {
try_start_task(task_id, tasks_.get(task_id));
});
}
void notify_children(Task *task) {
template <class F>
void for_each_child(Task *task, F &&f) {
for (TaskChainInfo &task_chain_info : task->chains) {
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
ChainInfo &chain_info = *task_chain_info.chain_info;
auto o_child = chain_info.chain.get_child(&task_chain_info.chain_node);
if (o_child) {
on_parent_is_ready(o_child.value(), task_chain_info.chain_id);
f(o_child.value());
}
}
}
void inactivate_task(TaskId task_id, Task *task) {
CHECK(task->state == Task::State::Active);
LOG(ERROR) << "inactivate " << task_id;
bool was_active = task->state == Task::State::Active;
task->state = Task::State::Pending;
for (TaskChainInfo &task_chain_info : task->chains) {
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
chain_info.active_tasks--;
ChainInfo &chain_info = *task_chain_info.chain_info;
if (was_active) {
chain_info.active_tasks--;
}
auto it = limited_tasks_.find(task_chain_info.chain_id);
if (it != limited_tasks_.end()) {
auto limited_task_id = it->second;
limited_tasks_.erase(it);
if (limited_task_id != task_id) {
try_start_task(limited_task_id, tasks_.get(limited_task_id));
try_start_task_later(limited_task_id);
}
}
auto o_first = chain_info.chain.get_first();
if (o_first) {
auto first_task_id = o_first.unwrap();
if (first_task_id != task_id) {
try_start_task(first_task_id, tasks_.get(first_task_id));
try_start_task_later(first_task_id);
}
}
}
@ -196,6 +198,18 @@ class ChainScheduler {
chains_.erase(task_chain_info.chain_id);
}
}
// Ids of tasks whose start attempt has been postponed until flush_try_start_task().
std::vector<TaskId> to_start;

// Queues task_id for a later try_start_task call instead of attempting the start right away.
void try_start_task_later(TaskId task_id) {
  to_start.emplace_back(task_id);
}
void flush_try_start_task() {
auto moved_to_start = std::move(to_start);
for (auto task_id : moved_to_start) {
try_start_task(task_id, tasks_.get(task_id));
}
CHECK(to_start.empty());
}
};
template <class ExtraT>
@ -214,15 +228,8 @@ typename ChainScheduler<ExtraT>::TaskId ChainScheduler<ExtraT>::create_task(Span
});
for (TaskChainInfo &task_chain_info : task.chains) {
auto &chain = task_chain_info.chain_info->chain;
chain.add_task(&task_chain_info.chain_node);
auto o_parent = chain.get_parent(&task_chain_info.chain_node);
if (o_parent) {
auto parent = o_parent.unwrap();
if (tasks_.get(parent)->state == Task::State::Pending) {
task_chain_info.waiting_for_parent = true;
}
}
ChainInfo &chain_info = *task_chain_info.chain_info;
chain_info.chain.add_task(&task_chain_info.chain_node);
}
try_start_task(task_id, &task);
@ -239,12 +246,12 @@ ExtraT *ChainScheduler<ExtraT>::get_task_extra(ChainScheduler::TaskId task_id) {
}
template <class ExtraT>
optional<ChainSchedulerTaskWithParents> ChainScheduler<ExtraT>::start_next_task() {
optional<ChainSchedulerBase::TaskWithParents> ChainScheduler<ExtraT>::start_next_task() {
if (pending_tasks_.empty()) {
return {};
}
auto task_id = pending_tasks_.pop();
ChainSchedulerTaskWithParents res;
TaskWithParents res;
res.task_id = task_id;
auto *task = tasks_.get(task_id);
CHECK(task);
@ -252,7 +259,7 @@ optional<ChainSchedulerTaskWithParents> ChainScheduler<ExtraT>::start_next_task(
Chain &chain = task_chain_info.chain_info->chain;
auto o_parent = chain.get_parent(&task_chain_info.chain_node);
if (o_parent) {
res.parents.push_back(o_parent.value());
res.parents.push_back(o_parent.value()->task_id);
}
}
return res;
@ -263,13 +270,19 @@ void ChainScheduler<ExtraT>::finish_task(ChainScheduler::TaskId task_id) {
auto *task = tasks_.get(task_id);
CHECK(task);
CHECK(to_start.empty());
inactivate_task(task_id, task);
notify_children(task);
for_each_child(task, [&](auto task_id) {
try_start_task_later(task_id);
});
for (TaskChainInfo &task_chain_info : task->chains) {
finish_chain_task(task_chain_info);
}
auto task_copy = std::move(*task);
tasks_.erase(task_id);
flush_try_start_task();
}
template <class ExtraT>
@ -279,11 +292,12 @@ void ChainScheduler<ExtraT>::reset_task(ChainScheduler::TaskId task_id) {
inactivate_task(task_id, task);
for (TaskChainInfo &task_chain_info : task->chains) {
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
task_chain_info.waiting_for_parent &= bool(chain_info.chain.get_parent(&task_chain_info.chain_node));
ChainInfo &chain_info = *task_chain_info.chain_info;
chain_info.generation = td::max(chain_info.generation, task_chain_info.chain_node.generation + 1);
}
try_start_task(task_id, task);
try_start_task_later(task_id);
flush_try_start_task();
}
template <class ExtraT>
@ -293,18 +307,20 @@ StringBuilder &operator<<(StringBuilder &sb, ChainScheduler<ExtraT> &scheduler)
for (auto &it : scheduler.chains_) {
sb << "ChainId{" << it.first << "} ";
sb << " active_cnt=" << it.second.active_tasks;
sb << " : ";
it.second.chain.foreach([&](auto task_id) { sb << *scheduler.get_task_extra(task_id); });
sb << " g=" << it.second.generation;
sb << " :";
it.second.chain.foreach([&](auto task_id, auto generation) {
sb << " " << *scheduler.get_task_extra(task_id) << ":" << generation;
});
sb << "\n";
}
scheduler.tasks_.for_each([&](auto id, auto &task) {
sb << "Task: " << task.extra;
sb << " state =" << static_cast<int>(task.state);
for (auto &task_chain_info : task.chains) {
if (task_chain_info.waiting_for_parent) {
sb << " wait "
<< *scheduler.get_task_extra(
task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node).value());
for (auto& task_chain_info : task.chains) {
sb << " g=" << task_chain_info.chain_node.generation;
if (task_chain_info.chain_info->generation != task_chain_info.chain_node.generation) {
sb << " chain_g=" << task_chain_info.chain_info->generation;
}
}
sb << "\n";

View File

@ -47,6 +47,25 @@ TEST(ChainScheduler, RestartAfterActive) {
ASSERT_EQ(second_task_id, scheduler.start_next_task().unwrap().task_id);
}
// Two tasks on chain 1 are created and started, the first one is reset, then a third task is
// added to the same chain. The scheduler must hand out the first task again and must not offer
// any other task after it.
TEST(ChainScheduler, SendAfterRestart) {
  td::ChainScheduler<int> scheduler;
  using ChainId = td::ChainScheduler<int>::ChainId;
  std::vector<ChainId> chains{1};

  auto first_task_id = scheduler.create_task(chains, 1);
  auto second_task_id = scheduler.create_task(chains, 2);
  ASSERT_EQ(first_task_id, scheduler.start_next_task().unwrap().task_id);
  ASSERT_EQ(second_task_id, scheduler.start_next_task().unwrap().task_id);
  scheduler.reset_task(first_task_id);

  // The id of the third task is not needed: the test only checks that it is never offered.
  scheduler.create_task(chains, 3);

  ASSERT_EQ(first_task_id, scheduler.start_next_task().unwrap().task_id);
  ASSERT_TRUE(!scheduler.start_next_task());
}
TEST(ChainScheduler, Basic) {
td::ChainScheduler<int> scheduler;
using ChainId = td::ChainScheduler<int>::ChainId;
@ -87,18 +106,22 @@ struct Query {
int id{};
TaskId task_id{};
bool is_ok{};
friend td::StringBuilder &operator<<(td::StringBuilder &sb, const Query &q) {
bool skipped{};
friend td::StringBuilder &operator << (td::StringBuilder &sb, const Query &q) {
return sb << "Q{" << q.id << "}";
}
};
// Prints the object that query_ptr points to, so a QueryPtr can be streamed directly.
td::StringBuilder &operator<<(td::StringBuilder &sb, const QueryPtr &query_ptr) {
  const auto &pointee = *query_ptr;
  return sb << pointee;
}
TEST(ChainScheduler, Stress) {
td::Random::Xorshift128plus rnd(123);
int max_query_id = 1000;
int max_query_id = 100000;
int MAX_INFLIGHT_QUERIES = 20;
int ChainsN = 4;
struct QueryWithParents {
TaskId task_id;
QueryPtr id;
td::vector<QueryPtr> parents;
};
@ -108,7 +131,9 @@ TEST(ChainScheduler, Stress) {
td::vector<td::vector<QueryPtr>> chains(ChainsN);
int inflight_queries{};
int current_query_id{};
int sent_cnt{};
bool done = false;
std::vector<TaskId> pending_queries;
auto schedule_new_query = [&] {
if (current_query_id > max_query_id) {
@ -133,6 +158,7 @@ TEST(ChainScheduler, Stress) {
}
auto task_id = scheduler.create_task(chain_ids, query);
query->task_id = task_id;
pending_queries.push_back(task_id);
inflight_queries++;
};
@ -151,11 +177,28 @@ TEST(ChainScheduler, Stress) {
}
auto task_with_parents = o_task_with_parents.unwrap();
QueryWithParents query_with_parents;
query_with_parents.task_id = task_with_parents.task_id;
query_with_parents.id = to_query_ptr(task_with_parents.task_id);
query_with_parents.parents = td::transform(task_with_parents.parents, to_query_ptr);
active_queries.push_back(query_with_parents);
sent_cnt++;
}
};
// Random step: finishes one randomly chosen not-yet-finished query without executing it.
// pending_queries holds the task ids that were created and not yet finished, so the chosen
// task may or may not have been handed out by the scheduler already.
auto skip_one_query = [&]() {
if (pending_queries.empty()) {
return;
}
// Pick a uniformly random position in [0, size - 1].
auto it = pending_queries.begin() + rnd.fast(0, (int)pending_queries.size() - 1);
auto task_id = *it;
pending_queries.erase(it);
// If the query was already handed out, drop it from the active list so it is not executed later.
td::remove_if(active_queries, [&](auto &q) {return q.task_id == task_id;});
// Copy the extra before finish_task: finish_task erases the task from the scheduler,
// and the query is still needed afterwards for the log line.
auto query = *scheduler.get_task_extra(task_id);
query->skipped = true;
scheduler.finish_task(task_id);
inflight_queries--;
LOG(ERROR) << "Skip " << query->id;
};
auto execute_one_query = [&] {
if (active_queries.empty()) {
return;
@ -167,37 +210,47 @@ TEST(ChainScheduler, Stress) {
auto query = query_with_parents.id;
if (rnd.fast(0, 20) == 0) {
scheduler.finish_task(query->task_id);
td::remove(pending_queries, query->task_id);
inflight_queries--;
LOG(INFO) << "Fail " << query->id;
} else if (check_parents_ok(query_with_parents)) {
query->is_ok = true;
scheduler.finish_task(query->task_id);
td::remove(pending_queries, query->task_id);
inflight_queries--;
LOG(INFO) << "OK " << query->id;
} else {
scheduler.reset_task(query->task_id);
LOG(ERROR) << "Reset " << query->id;
}
};
td::RandomSteps steps({{schedule_new_query, 100}, {execute_one_query, 100}});
td::RandomSteps steps({{schedule_new_query, 100}, {execute_one_query, 100}, {skip_one_query, 10}});
while (!done) {
steps.step(rnd);
flush_pending_queries();
// LOG(INFO) << scheduler;
}
LOG(ERROR) << "Sent queries count " << sent_cnt;
LOG(ERROR) << "Total queries " << current_query_id;
for (auto &chain : chains) {
int prev_ok = -1;
int failed_cnt = 0;
int ok_cnt = 0;
int skipped_cnt = 0;
for (auto &q : chain) {
if (q->is_ok) {
CHECK(prev_ok < q->id);
prev_ok = q->id;
ok_cnt++;
} else {
failed_cnt++;
if (q->skipped) {
skipped_cnt++;
} else {
failed_cnt++;
}
}
}
LOG(INFO) << "Chain ok " << ok_cnt << " failed " << failed_cnt;
LOG(INFO) << "Chain ok " << ok_cnt << " failed " << failed_cnt << " skipped " << skipped_cnt;
}
}