ChainScheduler - new implementation of SequenceDispatcher
This commit is contained in:
parent
4c98811b03
commit
355c2950ad
@ -5998,7 +5998,7 @@ MessagesManager::MessagesManager(Td *td, ActorShared<> parent)
|
||||
preload_folder_dialog_list_timeout_.set_callback(on_preload_folder_dialog_list_timeout_callback);
|
||||
preload_folder_dialog_list_timeout_.set_callback_data(static_cast<void *>(this));
|
||||
|
||||
sequence_dispatcher_ = create_actor<MultiSequenceDispatcher>("multi sequence dispatcher");
|
||||
sequence_dispatcher_ = MultiSequenceDispatcher::create("multi sequence dispatcher");
|
||||
}
|
||||
|
||||
MessagesManager::~MessagesManager() = default;
|
||||
|
@ -53,6 +53,7 @@
|
||||
#include "td/telegram/SecretChatId.h"
|
||||
#include "td/telegram/SecretInputMedia.h"
|
||||
#include "td/telegram/ServerMessageId.h"
|
||||
#include "td/telegram/SequenceDispatcher.h"
|
||||
#include "td/telegram/td_api.h"
|
||||
#include "td/telegram/telegram_api.h"
|
||||
#include "td/telegram/UserId.h"
|
||||
@ -92,7 +93,6 @@ class DraftMessage;
|
||||
struct InputMessageContent;
|
||||
class MessageContent;
|
||||
struct MessageReactions;
|
||||
class MultiSequenceDispatcher;
|
||||
class Td;
|
||||
|
||||
class MessagesManager final : public Actor {
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
#include "td/actor/PromiseFuture.h"
|
||||
|
||||
#include "td/utils/ChainScheduler.h"
|
||||
#include "td/utils/format.h"
|
||||
#include "td/utils/logging.h"
|
||||
#include "td/utils/misc.h"
|
||||
@ -241,7 +242,7 @@ void SequenceDispatcher::close_silent() {
|
||||
}
|
||||
|
||||
/*** MultiSequenceDispatcher ***/
|
||||
void MultiSequenceDispatcher::send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback,
|
||||
void MultiSequenceDispatcherOld::send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback,
|
||||
uint64 sequence_id) {
|
||||
CHECK(sequence_id != 0);
|
||||
auto it_ok = dispatchers_.emplace(sequence_id, Data{0, ActorOwn<SequenceDispatcher>()});
|
||||
@ -255,13 +256,13 @@ void MultiSequenceDispatcher::send_with_callback(NetQueryPtr query, ActorShared<
|
||||
send_closure(data.dispatcher_, &SequenceDispatcher::send_with_callback, std::move(query), std::move(callback));
|
||||
}
|
||||
|
||||
void MultiSequenceDispatcher::on_result() {
|
||||
void MultiSequenceDispatcherOld::on_result() {
|
||||
auto it = dispatchers_.find(get_link_token());
|
||||
CHECK(it != dispatchers_.end());
|
||||
it->second.cnt_--;
|
||||
}
|
||||
|
||||
void MultiSequenceDispatcher::ready_to_close() {
|
||||
void MultiSequenceDispatcherOld::ready_to_close() {
|
||||
auto it = dispatchers_.find(get_link_token());
|
||||
CHECK(it != dispatchers_.end());
|
||||
if (it->second.cnt_ == 0) {
|
||||
@ -270,4 +271,104 @@ void MultiSequenceDispatcher::ready_to_close() {
|
||||
}
|
||||
}
|
||||
|
||||
class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew {
|
||||
public:
|
||||
void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, uint64 sequence_id) final {
|
||||
LOG(ERROR) << "send " << query;
|
||||
Node node;
|
||||
node.net_query = std::move(query);
|
||||
node.net_query->debug("Waiting at SequenceDispatcher");
|
||||
node.net_query_ref = node.net_query.get_weak();
|
||||
node.callback = std::move(callback);
|
||||
scheduler_.create_task({ChainId{sequence_id}}, std::move(node));
|
||||
loop();
|
||||
}
|
||||
|
||||
private:
|
||||
struct Node {
|
||||
NetQueryRef net_query_ref;
|
||||
NetQueryPtr net_query;
|
||||
ActorShared<NetQueryCallback> callback;
|
||||
friend StringBuilder &operator << (StringBuilder &sb, const Node &node) {
|
||||
return sb << node.net_query;
|
||||
}
|
||||
};
|
||||
ChainScheduler<Node> scheduler_;
|
||||
using TaskId = ChainScheduler<NetQueryPtr>::TaskId;
|
||||
using ChainId = ChainScheduler<NetQueryPtr>::ChainId;
|
||||
|
||||
void on_result(NetQueryPtr query) override {
|
||||
auto task_id = TaskId(get_link_token());
|
||||
auto &node = *scheduler_.get_task_extra(task_id);
|
||||
|
||||
if (query->is_error() && (query->error().code() == NetQuery::ResendInvokeAfter ||
|
||||
(query->error().code() == 400 && (query->error().message() == "MSG_WAIT_FAILED" ||
|
||||
query->error().message() == "MSG_WAIT_TIMEOUT")))) {
|
||||
VLOG(net_query) << "Resend " << query;
|
||||
query->resend();
|
||||
return on_resend(std::move(query));
|
||||
}
|
||||
auto promise = promise_send_closure(actor_shared(this, task_id), &MultiSequenceDispatcherNewImpl::on_resend);
|
||||
send_closure(node.callback, &NetQueryCallback::on_result_resendable, std::move(query), std::move(promise));
|
||||
}
|
||||
|
||||
// TODO: without td::Result?
|
||||
void on_resend(td::Result<NetQueryPtr> query) {
|
||||
auto task_id = TaskId(get_link_token());
|
||||
auto &node = *scheduler_.get_task_extra(task_id);
|
||||
if (query.is_error()) {
|
||||
scheduler_.finish_task(task_id);
|
||||
} else {
|
||||
node.net_query = query.move_as_ok();
|
||||
node.net_query->debug("Waiting at SequenceDispatcher");
|
||||
node.net_query_ref = node.net_query.get_weak();
|
||||
scheduler_.reset_task(task_id);
|
||||
}
|
||||
loop();
|
||||
}
|
||||
|
||||
void loop() override {
|
||||
flush_pending_queries();
|
||||
}
|
||||
|
||||
void flush_pending_queries() {
|
||||
while (true) {
|
||||
auto o_task = scheduler_.start_next_task();
|
||||
if (!o_task) {
|
||||
LOG(ERROR) << " no more tasks " << scheduler_;
|
||||
break;
|
||||
}
|
||||
auto task = o_task.unwrap();
|
||||
LOG(ERROR) << " next task = " << task.task_id;
|
||||
auto &node = *scheduler_.get_task_extra(task.task_id);
|
||||
CHECK(!node.net_query.empty());
|
||||
|
||||
auto query = std::move(node.net_query);
|
||||
std::vector<NetQueryRef> parents;
|
||||
for (auto parent_id : task.parents) {
|
||||
auto &parent_node = *scheduler_.get_task_extra(parent_id);
|
||||
parents.push_back(parent_node.net_query_ref);
|
||||
}
|
||||
|
||||
if (parents.empty()) {
|
||||
query->set_invoke_after({});
|
||||
} else if (parents.size() == 1) {
|
||||
query->set_invoke_after(parents[0]);
|
||||
} else if (parents.size() > 1){
|
||||
LOG(FATAL) << "TODO: support invokeAfterMsgs";
|
||||
}
|
||||
query->last_timeout_ = 0; // TODO: flood
|
||||
VLOG(net_query) << "Send " << query;
|
||||
query->debug("send to Td::send_with_callback");
|
||||
query->set_session_rand(123); // TODO: chain_rand
|
||||
G()->net_query_dispatcher().dispatch_with_callback(std::move(query),
|
||||
actor_shared(this, task.task_id));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ActorOwn<MultiSequenceDispatcherNew> MultiSequenceDispatcherNew::create(Slice name) {
|
||||
return ActorOwn<MultiSequenceDispatcherNew>(create_actor<MultiSequenceDispatcherNewImpl>(name));
|
||||
}
|
||||
|
||||
} // namespace td
|
||||
|
@ -73,9 +73,12 @@ class SequenceDispatcher final : public NetQueryCallback {
|
||||
void tear_down() final;
|
||||
};
|
||||
|
||||
class MultiSequenceDispatcher final : public SequenceDispatcher::Parent {
|
||||
class MultiSequenceDispatcherOld final : public SequenceDispatcher::Parent {
|
||||
public:
|
||||
void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, uint64 sequence_id);
|
||||
static ActorOwn<MultiSequenceDispatcherOld> create(td::Slice name) {
|
||||
return create_actor<MultiSequenceDispatcherOld>(name);
|
||||
}
|
||||
|
||||
private:
|
||||
struct Data {
|
||||
@ -87,4 +90,12 @@ class MultiSequenceDispatcher final : public SequenceDispatcher::Parent {
|
||||
void ready_to_close() final;
|
||||
};
|
||||
|
||||
class MultiSequenceDispatcherNew : public NetQueryCallback {
|
||||
public:
|
||||
virtual void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, uint64 sequence_id) = 0;
|
||||
static ActorOwn<MultiSequenceDispatcherNew> create(Slice name);
|
||||
};
|
||||
|
||||
using MultiSequenceDispatcher = MultiSequenceDispatcherOld;
|
||||
|
||||
} // namespace td
|
||||
|
@ -383,6 +383,9 @@ inline StringBuilder &operator<<(StringBuilder &stream, const NetQuery &net_quer
|
||||
}
|
||||
|
||||
inline StringBuilder &operator<<(StringBuilder &stream, const NetQueryPtr &net_query_ptr) {
|
||||
if (net_query_ptr.empty()) {
|
||||
return stream << "[Query: null]";
|
||||
}
|
||||
return stream << *net_query_ptr;
|
||||
}
|
||||
|
||||
|
@ -289,6 +289,7 @@ endif()
|
||||
set(TDUTILS_TEST_SOURCE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/bitmask.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/buffer.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/ChainScheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/ConcurrentHashMap.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/crypto.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test/Enumerator.cpp
|
||||
|
291
tdutils/td/utils/ChainScheduler.h
Normal file
291
tdutils/td/utils/ChainScheduler.h
Normal file
@ -0,0 +1,291 @@
|
||||
#pragma once
|
||||
|
||||
#include "td/utils/algorithm.h"
|
||||
#include "td/utils/Container.h"
|
||||
#include "td/utils/List.h"
|
||||
#include "td/utils/optional.h"
|
||||
#include "td/utils/Span.h"
|
||||
#include "td/utils/tests.h"
|
||||
#include "td/utils/VectorQueue.h"
|
||||
#include "td/utils/Random.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <numeric>
|
||||
|
||||
namespace td {
|
||||
|
||||
template <class ExtraT = Unit>
|
||||
class ChainScheduler {
|
||||
public:
|
||||
using TaskId = uint64;
|
||||
using ChainId = uint64;
|
||||
TaskId create_task(td::Span<ChainId> chains, ExtraT extra = {});
|
||||
ExtraT *get_task_extra(TaskId task_id);
|
||||
|
||||
struct TaskWithParents {
|
||||
TaskId task_id{};
|
||||
std::vector<TaskId> parents;
|
||||
};
|
||||
|
||||
optional<TaskWithParents> start_next_task();
|
||||
void finish_task(TaskId task_id);
|
||||
void reset_task(TaskId task_id);
|
||||
template <class ExtraTT>
|
||||
friend td::StringBuilder &operator<<(StringBuilder &sb, ChainScheduler<ExtraTT> &scheduler);
|
||||
|
||||
private:
|
||||
struct ChainNode : ListNode {
|
||||
TaskId task_id{};
|
||||
};
|
||||
|
||||
class Chain {
|
||||
public:
|
||||
void add_task(ChainNode *node) {
|
||||
head_.put_back(node);
|
||||
}
|
||||
|
||||
optional<TaskId> get_first() {
|
||||
if (head_.empty()) {
|
||||
return {};
|
||||
}
|
||||
return static_cast<ChainNode &>(*head_.get_next()).task_id;
|
||||
}
|
||||
|
||||
optional<TaskId> get_child(ChainNode *chain_node) {
|
||||
if (chain_node->get_next() == head_.end()) {
|
||||
return {};
|
||||
}
|
||||
return static_cast<ChainNode &>(*chain_node->get_next()).task_id;
|
||||
}
|
||||
optional<TaskId> get_parent(ChainNode *chain_node) {
|
||||
if (chain_node->get_prev() == head_.end()) {
|
||||
return {};
|
||||
}
|
||||
return static_cast<ChainNode &>(*chain_node->get_prev()).task_id;
|
||||
}
|
||||
|
||||
void finish_task(ChainNode *node) {
|
||||
node->remove();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return head_.empty();
|
||||
}
|
||||
|
||||
void foreach(std::function<void(TaskId)> f) const {
|
||||
for (auto it = head_.begin(); it != head_.end(); it = it->get_next()) {
|
||||
f(static_cast<const ChainNode &>(*it).task_id);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
ListNode head_;
|
||||
};
|
||||
struct ChainInfo {
|
||||
Chain chain;
|
||||
uint32 active_tasks{};
|
||||
};
|
||||
struct TaskChainInfo {
|
||||
ChainNode chain_node;
|
||||
ChainId chain_id{};
|
||||
ChainInfo *chain_info{};
|
||||
bool waiting_for_parent{};
|
||||
};
|
||||
struct Task {
|
||||
enum class State { Pending, Active } state{State::Pending};
|
||||
std::vector<TaskChainInfo> chains;
|
||||
ExtraT extra;
|
||||
};
|
||||
std::map<ChainId, ChainInfo> chains_;
|
||||
std::map<ChainId, TaskId> limited_tasks_;
|
||||
Container<Task> tasks_;
|
||||
VectorQueue<TaskId> pending_tasks_;
|
||||
|
||||
void on_parent_is_ready(TaskId task_id, ChainId chain_id) {
|
||||
auto *task = tasks_.get(task_id);
|
||||
CHECK(task);
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
if (task_chain_info.chain_id == chain_id) {
|
||||
task_chain_info.waiting_for_parent = false;
|
||||
}
|
||||
}
|
||||
|
||||
try_start_task(task_id, task);
|
||||
}
|
||||
|
||||
void try_start_task(TaskId task_id, Task *task) {
|
||||
if (task->state != Task::State::Pending) {
|
||||
return;
|
||||
}
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
if (task_chain_info.waiting_for_parent) {
|
||||
return;
|
||||
}
|
||||
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
|
||||
if (chain_info.active_tasks >= 10) {
|
||||
limited_tasks_[task_chain_info.chain_id] = task_id;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
do_start_task(task_id, task);
|
||||
}
|
||||
|
||||
void do_start_task(TaskId task_id, Task *task) {
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
|
||||
chain_info.active_tasks++;
|
||||
}
|
||||
task->state = Task::State::Active;
|
||||
|
||||
pending_tasks_.push(task_id);
|
||||
notify_children(task);
|
||||
}
|
||||
|
||||
void notify_children(Task *task) {
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
|
||||
auto o_child = chain_info.chain.get_child(&task_chain_info.chain_node);
|
||||
if (o_child) {
|
||||
on_parent_is_ready(o_child.value(), task_chain_info.chain_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void inactivate_task(TaskId task_id, Task *task) {
|
||||
CHECK(task->state == Task::State::Active);
|
||||
task->state = Task::State::Pending;
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
|
||||
chain_info.active_tasks--;
|
||||
|
||||
auto it = limited_tasks_.find(task_chain_info.chain_id);
|
||||
if (it != limited_tasks_.end()) {
|
||||
auto limited_task_id = it->second;
|
||||
limited_tasks_.erase(it);
|
||||
if (limited_task_id != task_id) {
|
||||
try_start_task(limited_task_id, tasks_.get(limited_task_id));
|
||||
}
|
||||
}
|
||||
auto o_first = chain_info.chain.get_first();
|
||||
if (o_first) {
|
||||
auto first_task_id = o_first.unwrap();
|
||||
if (first_task_id != task_id) {
|
||||
try_start_task(first_task_id, tasks_.get(first_task_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void finish_chain_task(TaskChainInfo &task_chain_info) {
|
||||
auto &chain = task_chain_info.chain_info->chain;
|
||||
chain.finish_task(&task_chain_info.chain_node);
|
||||
if (chain.empty()) {
|
||||
chains_.erase(task_chain_info.chain_id);
|
||||
}
|
||||
}
|
||||
};
|
||||
template <class ExtraT>
|
||||
typename ChainScheduler<ExtraT>::TaskId ChainScheduler<ExtraT>::create_task(Span<ChainScheduler::ChainId> chains, ExtraT extra) {
|
||||
auto task_id = tasks_.create();
|
||||
Task &task = *tasks_.get(task_id);
|
||||
task.extra = std::move(extra);
|
||||
task.chains = transform(chains, [&](auto chain_id) {
|
||||
TaskChainInfo task_chain_info;
|
||||
ChainInfo &chain_info = chains_[chain_id];
|
||||
task_chain_info.chain_id = chain_id;
|
||||
task_chain_info.chain_info = &chain_info;
|
||||
task_chain_info.chain_node.task_id = task_id;
|
||||
return task_chain_info;
|
||||
});
|
||||
|
||||
for (TaskChainInfo &task_chain_info : task.chains) {
|
||||
auto &chain = task_chain_info.chain_info->chain;
|
||||
chain.add_task(&task_chain_info.chain_node);
|
||||
task_chain_info.waiting_for_parent = bool(chain.get_parent(&task_chain_info.chain_node));
|
||||
}
|
||||
|
||||
try_start_task(task_id, &task);
|
||||
return task_id;
|
||||
}
|
||||
template <class ExtraT>
|
||||
ExtraT *ChainScheduler<ExtraT>::get_task_extra(ChainScheduler::TaskId task_id) { // may return nullptr
|
||||
auto *task = tasks_.get(task_id);
|
||||
if (!task) {
|
||||
return nullptr;
|
||||
}
|
||||
return &task->extra;
|
||||
}
|
||||
template <class ExtraT>
|
||||
optional<typename ChainScheduler<ExtraT>::TaskWithParents> ChainScheduler<ExtraT>::start_next_task() {
|
||||
if (pending_tasks_.empty()) {
|
||||
return {};
|
||||
}
|
||||
auto task_id = pending_tasks_.pop();
|
||||
TaskWithParents res;
|
||||
res.task_id = task_id;
|
||||
auto *task = tasks_.get(task_id);
|
||||
CHECK(task);
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
Chain &chain = task_chain_info.chain_info->chain;
|
||||
auto o_parent = chain.get_parent(&task_chain_info.chain_node);
|
||||
if (o_parent) {
|
||||
res.parents.push_back(o_parent.value());
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
template <class ExtraT>
|
||||
void ChainScheduler<ExtraT>::finish_task(ChainScheduler::TaskId task_id) {
|
||||
auto *task = tasks_.get(task_id);
|
||||
CHECK(task);
|
||||
|
||||
inactivate_task(task_id, task);
|
||||
notify_children(task);
|
||||
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
finish_chain_task(task_chain_info);
|
||||
}
|
||||
tasks_.erase(task_id);
|
||||
}
|
||||
template <class ExtraT>
|
||||
void ChainScheduler<ExtraT>::reset_task(ChainScheduler::TaskId task_id) {
|
||||
auto *task = tasks_.get(task_id);
|
||||
CHECK(task);
|
||||
inactivate_task(task_id, task);
|
||||
|
||||
for (TaskChainInfo &task_chain_info : task->chains) {
|
||||
ChainInfo &chain_info = chains_[task_chain_info.chain_id];
|
||||
task_chain_info.waiting_for_parent = bool(chain_info.chain.get_parent(&task_chain_info.chain_node));
|
||||
}
|
||||
|
||||
try_start_task(task_id, task);
|
||||
}
|
||||
template <class ExtraT>
|
||||
td::StringBuilder &operator<<(StringBuilder &sb, ChainScheduler<ExtraT> &scheduler) {
|
||||
// 1 print chains
|
||||
sb << "\n";
|
||||
for (auto &it : scheduler.chains_) {
|
||||
sb << "ChainId{" << it.first << "} ";
|
||||
sb << " active_cnt=" << it.second.active_tasks;
|
||||
sb << " : ";
|
||||
it.second.chain.foreach([&](auto task_id) {
|
||||
sb << *scheduler.get_task_extra(task_id);
|
||||
});
|
||||
sb << "\n";
|
||||
}
|
||||
scheduler.tasks_.for_each([&](auto id, auto &task) {
|
||||
sb << "Task: " << task.extra;
|
||||
sb << " state =" << static_cast<int>(task.state);
|
||||
for (auto& task_chain_info : task.chains) {
|
||||
if (task_chain_info.waiting_for_parent) {
|
||||
sb << " wait " << *scheduler.get_task_extra(task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node).value());
|
||||
}
|
||||
}
|
||||
sb << "\n";
|
||||
});
|
||||
return sb;
|
||||
}
|
||||
} // namespace td
|
@ -89,12 +89,24 @@ struct ListNode {
|
||||
ListNode *end() {
|
||||
return this;
|
||||
}
|
||||
const ListNode *begin() const {
|
||||
return next;
|
||||
}
|
||||
const ListNode *end() const {
|
||||
return this;
|
||||
}
|
||||
ListNode *get_next() {
|
||||
return next;
|
||||
}
|
||||
ListNode *get_prev() {
|
||||
return prev;
|
||||
}
|
||||
const ListNode *get_next() const {
|
||||
return next;
|
||||
}
|
||||
const ListNode *get_prev() const {
|
||||
return prev;
|
||||
}
|
||||
|
||||
protected:
|
||||
void clear() {
|
||||
|
@ -117,6 +117,15 @@ bool contains(const V &v, const T &value) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
template <class V, class F>
|
||||
bool all_of(const V &v, F &&f) {
|
||||
for (auto &x : v) {
|
||||
if (!f(x)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void reset_to_empty(T &value) {
|
||||
|
164
tdutils/test/ChainScheduler.cpp
Normal file
164
tdutils/test/ChainScheduler.cpp
Normal file
@ -0,0 +1,164 @@
|
||||
#include "td/utils/algorithm.h"
|
||||
#include "td/utils/optional.h"
|
||||
#include "td/utils/Span.h"
|
||||
#include "td/utils/tests.h"
|
||||
#include "td/utils/Random.h"
|
||||
|
||||
#include "td/utils/ChainScheduler.h"
|
||||
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
TEST(ChainScheduler, Basic) {
|
||||
td::ChainScheduler<int> scheduler;
|
||||
using ChainId = td::ChainScheduler<int>::ChainId;
|
||||
using TaskId = td::ChainScheduler<int>::TaskId;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
scheduler.create_task({ChainId{1}}, i);
|
||||
}
|
||||
int j = 0;
|
||||
while (j != 100) {
|
||||
std::vector<TaskId> tasks;
|
||||
while (true) {
|
||||
auto o_task_id = scheduler.start_next_task();
|
||||
if (!o_task_id) {
|
||||
break;
|
||||
}
|
||||
auto task_id = o_task_id.value().task_id;
|
||||
auto extra = *scheduler.get_task_extra(task_id);
|
||||
auto parents = td::transform(o_task_id.value().parents,
|
||||
[&](auto parent) { return *scheduler.get_task_extra(parent); });
|
||||
LOG(ERROR) << "start " << extra << parents;
|
||||
CHECK(extra == j);
|
||||
j++;
|
||||
tasks.push_back(task_id);
|
||||
}
|
||||
for (auto &task_id : tasks) {
|
||||
auto extra = *scheduler.get_task_extra(task_id);
|
||||
LOG(ERROR) << "finish " << extra;
|
||||
scheduler.finish_task(task_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Query;
|
||||
using QueryPtr = std::shared_ptr<Query>;
|
||||
using ChainId = td::ChainScheduler<QueryPtr>::ChainId;
|
||||
using TaskId = td::ChainScheduler<QueryPtr>::TaskId;
|
||||
struct Query {
|
||||
int id{};
|
||||
TaskId task_id{};
|
||||
bool is_ok{};
|
||||
friend td::StringBuilder &operator << (td::StringBuilder &sb, const Query &q) {
|
||||
return sb << "Q{" << q.id << "}";
|
||||
}
|
||||
};
|
||||
TEST(ChainScheduler, Stress) {
|
||||
td::Random::Xorshift128plus rnd(123);
|
||||
int max_query_id = 1000;
|
||||
int MAX_INFLIGHT_QUERIES = 20;
|
||||
int ChainsN = 4;
|
||||
|
||||
struct QueryWithParents {
|
||||
QueryPtr id;
|
||||
std::vector<QueryPtr> parents;
|
||||
};
|
||||
std::vector<QueryWithParents> active_queries;
|
||||
|
||||
td::ChainScheduler<QueryPtr> scheduler;
|
||||
std::vector<std::vector<QueryPtr>> chains(ChainsN);
|
||||
int inflight_queries{};
|
||||
int current_query_id{};
|
||||
bool done = false;
|
||||
|
||||
auto schedule_new_query = [&] {
|
||||
if (current_query_id > max_query_id) {
|
||||
if (inflight_queries == 0) {
|
||||
done = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (inflight_queries >= MAX_INFLIGHT_QUERIES) {
|
||||
return;
|
||||
}
|
||||
auto query_id = current_query_id++;
|
||||
auto query = std::make_shared<Query>();
|
||||
query->id = query_id;
|
||||
int chain_n = rnd.fast(1, ChainsN);
|
||||
std::vector<ChainId> chain_ids(ChainsN);
|
||||
std::iota(chain_ids.begin(), chain_ids.end(), 0);
|
||||
td::random_shuffle(td::as_mutable_span(chain_ids), rnd);
|
||||
chain_ids.resize(chain_n);
|
||||
for (auto chain_id : chain_ids) {
|
||||
chains[chain_id].push_back(query);
|
||||
}
|
||||
auto task_id = scheduler.create_task(chain_ids, query);
|
||||
query->task_id = task_id;
|
||||
inflight_queries++;
|
||||
};
|
||||
|
||||
auto check_parents_ok = [&] (const QueryWithParents &query_with_parents) -> bool {
|
||||
return td::all_of(query_with_parents.parents, [](auto &parent) { return parent->is_ok; });
|
||||
};
|
||||
|
||||
auto to_query_ptr = [&](TaskId task_id) {
|
||||
return *scheduler.get_task_extra(task_id);
|
||||
};
|
||||
auto flush_pending_queries = [&]{
|
||||
while (true) {
|
||||
auto o_task_with_parents = scheduler.start_next_task();
|
||||
if (!o_task_with_parents) {
|
||||
break;
|
||||
}
|
||||
auto task_with_parents = o_task_with_parents.unwrap();
|
||||
QueryWithParents query_with_parents;
|
||||
query_with_parents.id = to_query_ptr(task_with_parents.task_id);
|
||||
query_with_parents.parents = td::transform(task_with_parents.parents, to_query_ptr);
|
||||
active_queries.push_back(query_with_parents);
|
||||
}
|
||||
};
|
||||
auto execute_one_query = [&]() {
|
||||
if (active_queries.empty()) {
|
||||
return;
|
||||
}
|
||||
auto it = active_queries.begin() + rnd.fast(0, (int)active_queries.size() - 1);
|
||||
auto query_with_parents = *it;
|
||||
active_queries.erase(it);
|
||||
|
||||
auto query = query_with_parents.id;
|
||||
if (rnd.fast(0, 20) == 0) {
|
||||
scheduler.finish_task(query->task_id);
|
||||
inflight_queries--;
|
||||
LOG(ERROR) << "Fail " << query->id;
|
||||
} else if (check_parents_ok(query_with_parents)) {
|
||||
query->is_ok = true;
|
||||
scheduler.finish_task(query->task_id);
|
||||
inflight_queries--;
|
||||
LOG(ERROR) << "OK " << query->id;
|
||||
} else {
|
||||
scheduler.reset_task(query->task_id);
|
||||
}
|
||||
};
|
||||
|
||||
td::RandomSteps steps({{schedule_new_query, 100}, {execute_one_query, 100}});
|
||||
while (!done) {
|
||||
steps.step(rnd);
|
||||
flush_pending_queries();
|
||||
// LOG(ERROR) << scheduler;
|
||||
}
|
||||
for (auto &chain : chains) {
|
||||
int prev_ok = -1;
|
||||
int failed_cnt = 0;
|
||||
int ok_cnt = 0;
|
||||
for (auto &q : chain) {
|
||||
if (q->is_ok) {
|
||||
CHECK(prev_ok < q->id) ;
|
||||
prev_ok = q->id;
|
||||
ok_cnt++;
|
||||
} else {
|
||||
failed_cnt++;
|
||||
}
|
||||
}
|
||||
LOG(ERROR) << "Chain ok " << ok_cnt << " failed " << failed_cnt;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user