From b1222a9bb705332aefa33066fead202873789409 Mon Sep 17 00:00:00 2001
From: Arseny Smirnov <arseny30@gmail.com>
Date: Wed, 29 Jul 2020 16:49:35 +0300
Subject: [PATCH] MultiClient: draft

GitOrigin-RevId: 4d1bdd6ad99909ce7ad94cfd32a43262051a6d18
---
 td/telegram/Client.cpp | 500 +++++++++++++++++++++++++++--------------
 td/telegram/Client.h   |  31 +++
 test/tdclient.cpp      |  25 +++
 3 files changed, 388 insertions(+), 168 deletions(-)

diff --git a/td/telegram/Client.cpp b/td/telegram/Client.cpp
index b0fb29be7..a85b14d36 100644
--- a/td/telegram/Client.cpp
+++ b/td/telegram/Client.cpp
@@ -15,6 +15,7 @@
 #include "td/utils/logging.h"
 #include "td/utils/misc.h"
 #include "td/utils/MpscPollableQueue.h"
+#include "td/utils/port/RwMutex.h"
 #include "td/utils/port/thread.h"
 
 #include <algorithm>
@@ -26,91 +27,13 @@
 
 namespace td {
 
-#if TD_THREAD_UNSUPPORTED || TD_EVENTFD_UNSUPPORTED
-
-class Client::Impl final {
+class TdReceiver {
  public:
-  Impl() {
-    concurrent_scheduler_ = make_unique<ConcurrentScheduler>();
-    concurrent_scheduler_->init(0);
-    class Callback : public TdCallback {
-     public:
-      explicit Callback(Impl *client) : client_(client) {
-      }
-      void on_result(std::uint64_t id, td_api::object_ptr<td_api::Object> result) override {
-        client_->responses_.push_back({id, std::move(result)});
-      }
-      void on_error(std::uint64_t id, td_api::object_ptr<td_api::error> error) override {
-        client_->responses_.push_back({id, std::move(error)});
-      }
-
-      Callback(const Callback &) = delete;
-      Callback &operator=(const Callback &) = delete;
-      Callback(Callback &&) = delete;
-      Callback &operator=(Callback &&) = delete;
-      ~Callback() override {
-        client_->closed_ = true;
-        Scheduler::instance()->yield();
-      }
-
-     private:
-      Impl *client_;
-    };
-    td_ = concurrent_scheduler_->create_actor_unsafe<Td>(0, "Td", make_unique<Callback>(this));
-    concurrent_scheduler_->start();
-  }
-
-  void send(Request request) {
-    requests_.push_back(std::move(request));
-  }
-
-  Response receive(double timeout) {
-    if (!requests_.empty()) {
-      auto guard = concurrent_scheduler_->get_main_guard();
-      for (auto &request : requests_) {
-        send_closure_later(td_, &Td::request, request.id, std::move(request.function));
-      }
-      requests_.clear();
-    }
-
-    if (responses_.empty()) {
-      concurrent_scheduler_->run_main(0);
-    } else {
-      ConcurrentScheduler::emscripten_clear_main_timeout();
-    }
-    if (!responses_.empty()) {
-      auto result = std::move(responses_.front());
-      responses_.pop_front();
-      return result;
-    }
-    return {0, nullptr};
-  }
-
-  Impl(const Impl &) = delete;
-  Impl &operator=(const Impl &) = delete;
-  Impl(Impl &&) = delete;
-  Impl &operator=(Impl &&) = delete;
-  ~Impl() {
-    {
-      auto guard = concurrent_scheduler_->get_main_guard();
-      td_.reset();
-    }
-    while (!closed_) {
-      concurrent_scheduler_->run_main(0);
-    }
-    concurrent_scheduler_.reset();
-  }
-
- private:
-  std::deque<Response> responses_;
-  std::vector<Request> requests_;
-  unique_ptr<ConcurrentScheduler> concurrent_scheduler_;
-  ActorOwn<Td> td_;
-  bool closed_ = false;
+  virtual ~TdReceiver() = default;
+  virtual MultiClient::Response receive(double timeout) = 0;
+  virtual unique_ptr<TdCallback> create_callback(MultiClient::ClientId client_id) = 0;
 };
 
-#else
-
 class MultiTd : public Actor {
  public:
   void create(int32 td_id, unique_ptr<TdCallback> callback) {
@@ -134,10 +57,10 @@ class MultiTd : public Actor {
     set_context(old_context);
     set_tag(old_tag);
   }
-  void send(int32 td_id, Client::Request request) {
-    auto &td = tds_[td_id];
+  void send(MultiClient::ClientId client_id, MultiClient::RequestId request_id, MultiClient::Function function) {
+    auto &td = tds_[client_id];
     CHECK(!td.empty());
-    send_closure(td, &Td::request, request.id, std::move(request.function));
+    send_closure(td, &Td::request, request_id, std::move(function));
   }
   void destroy(int32 td_id) {
     auto size = tds_.erase(td_id);
@@ -148,25 +71,213 @@ class MultiTd : public Actor {
   std::unordered_map<int32, ActorOwn<Td> > tds_;
 };
 
-class MultiImpl {
+#if TD_THREAD_UNSUPPORTED || TD_EVENTFD_UNSUPPORTED
+class TdReceiverSimple : public TdReceiver {
  public:
-  static std::shared_ptr<MultiImpl> get() {
-    static std::mutex mutex;
-    static std::vector<std::weak_ptr<MultiImpl> > impls;
-    std::unique_lock<std::mutex> lock(mutex);
-    if (impls.size() == 0) {
-      impls.resize(clamp(thread::hardware_concurrency(), 8u, 1000u) * 5 / 4);
+  TdReceiverSimple() {
+  }
+  MultiClient::Response receive(double timeout) final {
+    if (!responses_.empty()) {
+      auto result = std::move(responses_.front());
+      responses_.pop_front();
+      return result;
     }
-    auto &impl = *std::min_element(impls.begin(), impls.end(),
-                                   [](auto &a, auto &b) { return a.lock().use_count() < b.lock().use_count(); });
-    auto res = impl.lock();
-    if (!res) {
-      res = std::make_shared<MultiImpl>();
-      impl = res;
-    }
-    return res;
+    return {0, 0, nullptr};
+  }
+  unique_ptr<TdCallback> create_callback(MultiClient::ClientId client_id) final {
+    class Callback : public TdCallback {
+     public:
+      explicit Callback(MultiClient::ClientId client_id, TdReceiverSimple *impl) : client_id_(client_id), impl_(impl) {
+      }
+      void on_result(std::uint64_t id, td_api::object_ptr<td_api::Object> result) override {
+        impl_->responses_.push_back({client_id_, id, std::move(result)});
+      }
+      void on_error(std::uint64_t id, td_api::object_ptr<td_api::error> error) override {
+        impl_->responses_.push_back({client_id_, id, std::move(error)});
+      }
+      Callback(const Callback &) = delete;
+      Callback &operator=(const Callback &) = delete;
+      Callback(Callback &&) = delete;
+      Callback &operator=(Callback &&) = delete;
+      ~Callback() override {
+        impl_->responses_.push_back({client_id_, 0, nullptr});
+      }
+
+     private:
+      MultiClient::ClientId client_id_;
+      TdReceiverSimple *impl_;
+    };
+    return td::make_unique<Callback>(client_id, this);
   }
 
+ private:
+  std::deque<MultiClient::Response> responses_;
+};
+
+class MultiClient::Impl final {
+ public:
+  Impl() {
+    concurrent_scheduler_ = make_unique<ConcurrentScheduler>();
+    concurrent_scheduler_->init(0);
+    receiver_ = make_unique<TdReceiverSimple>();
+    concurrent_scheduler_->start();
+  }
+
+  ClientId create_client() {
+    auto client_id = ++client_id_;
+    tds_[client_id] = concurrent_scheduler_->create_actor_unsafe<Td>(0, "Td", receiver_->create_callback(client_id));
+    return client_id;
+  }
+
+  void send(ClientId client_id, RequestId request_id, Function function) {
+    Request request;
+    request.client_id = client_id;
+    request.id = request_id;
+    request.function = std::move(function);
+    requests_.push_back(std::move(request));
+  }
+
+  Response receive(double timeout) {
+    if (!requests_.empty()) {
+      auto guard = concurrent_scheduler_->get_main_guard();
+      for (auto &request : requests_) {
+        auto &td = tds_[request.client_id];
+        CHECK(!td.empty());
+        send_closure_later(td, &Td::request, request.id, std::move(request.function));
+      }
+      requests_.clear();
+    }
+
+    auto response = receiver_->receive(0);
+    if (response.client_id == 0) {
+      concurrent_scheduler_->run_main(0);
+      response = receiver_->receive(0);
+    } else {
+      ConcurrentScheduler::emscripten_clear_main_timeout();
+    }
+    if (response.client_id != 0 && !response.object) {
+      auto guard = concurrent_scheduler_->get_main_guard();
+      tds_.erase(response.client_id);
+    }
+    return response;
+  }
+
+  static Object execute(Function &&function) {
+    return Td::static_request(std::move(function));
+  }
+
+  ~Impl() {
+    {
+      auto guard = concurrent_scheduler_->get_main_guard();
+      for (auto &td : tds_) {
+        td.second = {};
+      }
+    }
+    while (!tds_.empty()) {
+      receive(10);
+    }
+    concurrent_scheduler_->finish();
+  }
+
+ private:
+  friend class Client::Impl;
+  td::unique_ptr<TdReceiver> receiver_;
+  struct Request {
+    ClientId client_id;
+    RequestId id;
+    Function function;
+  };
+  std::vector<Request> requests_;
+  unique_ptr<ConcurrentScheduler> concurrent_scheduler_;
+  ClientId client_id_{0};
+  std::unordered_map<int32, ActorOwn<Td> > tds_;
+};
+
+class Client::Impl final {
+ public:
+  Impl() {
+    client_id_ = impl_.create_client();
+  }
+
+  void send(Request request) {
+    impl_.send(client_id_, request.id, std::move(request.function));
+  }
+
+  Response receive(double timeout) {
+    auto response = impl_.receive(timeout);
+    Response old_response;
+    old_response.id = response.id;
+    old_response.object = std::move(response.object);
+    return old_response;
+  }
+
+  Impl(const Impl &) = delete;
+  Impl &operator=(const Impl &) = delete;
+  Impl(Impl &&) = delete;
+  Impl &operator=(Impl &&) = delete;
+
+ private:
+  MultiClient::Impl impl_;
+  MultiClient::ClientId client_id_;
+};
+
+#else
+
+class TdReceiverTs : public TdReceiver {
+ public:
+  TdReceiverTs() {
+    output_queue_ = std::make_shared<OutputQueue>();
+    output_queue_->init();
+  }
+  MultiClient::Response receive(double timeout) final {
+    if (output_queue_ready_cnt_ == 0) {
+      output_queue_ready_cnt_ = output_queue_->reader_wait_nonblock();
+    }
+    if (output_queue_ready_cnt_ > 0) {
+      output_queue_ready_cnt_--;
+      return output_queue_->reader_get_unsafe();
+    }
+    if (timeout != 0) {
+      output_queue_->reader_get_event_fd().wait(static_cast<int>(timeout * 1000));
+      return receive(0);
+    }
+    return {0, 0, nullptr};
+  }
+  unique_ptr<TdCallback> create_callback(MultiClient::ClientId client_id) final {
+    class Callback : public TdCallback {
+     public:
+      explicit Callback(MultiClient::ClientId client_id, std::shared_ptr<OutputQueue> output_queue)
+          : client_id_(client_id), output_queue_(std::move(output_queue)) {
+      }
+      void on_result(std::uint64_t id, td_api::object_ptr<td_api::Object> result) override {
+        output_queue_->writer_put({client_id_, id, std::move(result)});
+      }
+      void on_error(std::uint64_t id, td_api::object_ptr<td_api::error> error) override {
+        output_queue_->writer_put({client_id_, id, std::move(error)});
+      }
+      Callback(const Callback &) = delete;
+      Callback &operator=(const Callback &) = delete;
+      Callback(Callback &&) = delete;
+      Callback &operator=(Callback &&) = delete;
+      ~Callback() override {
+        output_queue_->writer_put({client_id_, 0, nullptr});
+      }
+
+     private:
+      MultiClient::ClientId client_id_;
+      std::shared_ptr<OutputQueue> output_queue_;
+    };
+    return td::make_unique<Callback>(client_id, output_queue_);
+  }
+
+ private:
+  using OutputQueue = MpscPollableQueue<MultiClient::Response>;
+  std::shared_ptr<OutputQueue> output_queue_;
+  int output_queue_ready_cnt_{0};
+};
+
+class MultiImpl {
+ public:
   MultiImpl() {
     concurrent_scheduler_ = std::make_shared<ConcurrentScheduler>();
     concurrent_scheduler_->init(3);
@@ -187,19 +298,19 @@ class MultiImpl {
   MultiImpl(MultiImpl &&) = delete;
   MultiImpl &operator=(MultiImpl &&) = delete;
 
-  int32 create_id() {
-    static std::atomic<int32> id_{0};
-    return id_.fetch_add(1) + 1;
+  int32 create(TdReceiver &receiver) {
+    auto id = create_id();
+    create(id, receiver.create_callback(id));
+    return id;
   }
 
-  void create(int32 td_id, unique_ptr<TdCallback> callback) {
+  [[deprecated]] void send(int32 td_id, Client::Request request) {
     auto guard = concurrent_scheduler_->get_send_guard();
-    send_closure(multi_td_, &MultiTd::create, td_id, std::move(callback));
+    send_closure(multi_td_, &MultiTd::send, td_id, request.id, std::move(request.function));
   }
-
-  void send(int32 td_id, Client::Request request) {
+  void send(MultiClient::ClientId client_id, MultiClient::RequestId request_id, MultiClient::Function function) {
     auto guard = concurrent_scheduler_->get_send_guard();
-    send_closure(multi_td_, &MultiTd::send, td_id, std::move(request));
+    send_closure(multi_td_, &MultiTd::send, client_id, request_id, std::move(function));
   }
 
   void destroy(int32 td_id) {
@@ -221,40 +332,92 @@ class MultiImpl {
   std::shared_ptr<ConcurrentScheduler> concurrent_scheduler_;
   thread scheduler_thread_;
   ActorOwn<MultiTd> multi_td_;
+
+  int32 create_id() {
+    static std::atomic<int32> id_{0};
+    return id_.fetch_add(1) + 1;
+  }
+
+  void create(int32 td_id, unique_ptr<TdCallback> callback) {
+    auto guard = concurrent_scheduler_->get_send_guard();
+    send_closure(multi_td_, &MultiTd::create, td_id, std::move(callback));
+  }
+};
+
+class MultiImplPool {
+ public:
+  std::shared_ptr<MultiImpl> get() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (impls_.size() == 0) {
+      impls_.resize(clamp(thread::hardware_concurrency(), 8u, 1000u) * 5 / 4);
+    }
+    auto &impl = *std::min_element(impls_.begin(), impls_.end(),
+                                   [](auto &a, auto &b) { return a.lock().use_count() < b.lock().use_count(); });
+    auto res = impl.lock();
+    if (!res) {
+      res = std::make_shared<MultiImpl>();
+      impl = res;
+    }
+    return res;
+  }
+
+ private:
+  std::mutex mutex_;
+  std::vector<std::weak_ptr<MultiImpl> > impls_;
+};
+
+class MultiClient::Impl final {
+ public:
+  ClientId create_client() {
+    auto impl = pool_.get();
+    auto client_id = impl->create(*receiver_);
+    {
+      auto lock = impls_mutex_.lock_write().move_as_ok();
+      impls_[client_id] = std::move(impl);
+    }
+    return client_id;
+  }
+  void send(ClientId client_id, RequestId request_id, Function function) {
+    auto lock = impls_mutex_.lock_read().move_as_ok();
+    auto it = impls_.find(client_id);
+    CHECK(it != impls_.end());
+    it->second->send(client_id, request_id, std::move(function));
+  }
+  Response receive(double timeout) {
+    auto res = receiver_->receive(timeout);
+    if (res.client_id != 0 && !res.object) {
+      auto lock = impls_mutex_.lock_write().move_as_ok();
+      impls_.erase(res.client_id);
+    }
+    return res;
+  }
+  static Object execute(Function &&function) {
+    return Td::static_request(std::move(function));
+  }
+
+  ~Impl() {
+    for (auto &it : impls_) {
+      it.second->destroy(it.first);
+    }
+    while (!impls_.empty()) {
+      receive(10);
+    }
+  }
+
+ private:
+  MultiImplPool pool_;
+  td::RwMutex impls_mutex_;
+  std::unordered_map<ClientId, std::shared_ptr<MultiImpl> > impls_;
+  td::unique_ptr<TdReceiver> receiver_{td::make_unique<TdReceiverTs>()};
+};
 
 class Client::Impl final {
  public:
-  using OutputQueue = MpscPollableQueue<Client::Response>;
   Impl() {
-    multi_impl_ = MultiImpl::get();
-    td_id_ = multi_impl_->create_id();
-    output_queue_ = std::make_shared<OutputQueue>();
-    output_queue_->init();
-
-    class Callback : public TdCallback {
-     public:
-      explicit Callback(std::shared_ptr<OutputQueue> output_queue) : output_queue_(std::move(output_queue)) {
-      }
-      void on_result(std::uint64_t id, td_api::object_ptr<td_api::Object> result) override {
-        output_queue_->writer_put({id, std::move(result)});
-      }
-      void on_error(std::uint64_t id, td_api::object_ptr<td_api::error> error) override {
-        output_queue_->writer_put({id, std::move(error)});
-      }
-      Callback(const Callback &) = delete;
-      Callback &operator=(const Callback &) = delete;
-      Callback(Callback &&) = delete;
-      Callback &operator=(Callback &&) = delete;
-      ~Callback() override {
-        output_queue_->writer_put({0, nullptr});
-      }
-
-     private:
-      std::shared_ptr<OutputQueue> output_queue_;
-    };
-
-    multi_impl_->create(td_id_, td::make_unique<Callback>(output_queue_));
+    static MultiImplPool pool;
+    multi_impl_ = pool.get();
+    receiver_ = make_unique<TdReceiverTs>();
+    td_id_ = multi_impl_->create(*receiver_);
   }
 
   void send(Client::Request request) {
@@ -263,18 +426,20 @@ class Client::Impl final {
       return;
     }
 
-    multi_impl_->send(td_id_, std::move(request));
+    multi_impl_->send(td_id_, request.id, std::move(request.function));
  }
 
   Client::Response receive(double timeout) {
-    VLOG(td_requests) << "Begin to wait for updates with timeout " << timeout;
-    auto is_locked = receive_lock_.exchange(true);
-    CHECK(!is_locked);
-    auto response = receive_unlocked(timeout);
-    is_locked = receive_lock_.exchange(false);
-    CHECK(is_locked);
-    VLOG(td_requests) << "End to wait for updates, returning object " << response.id << ' ' << response.object.get();
-    return response;
+    auto res = receiver_->receive(timeout);
+
+    if (res.client_id != 0 && !res.object) {
+      is_closed_ = true;
+    }
+
+    Client::Response old_res;
+    old_res.id = res.id;
+    old_res.object = std::move(res.object);
+    return old_res;
   }
 
   Impl(const Impl &) = delete;
@@ -290,31 +455,10 @@ class Client::Impl final {
 
  private:
   std::shared_ptr<MultiImpl> multi_impl_;
+  td::unique_ptr<TdReceiver> receiver_;
 
-  std::shared_ptr<OutputQueue> output_queue_;
-  int output_queue_ready_cnt_{0};
-  std::atomic<bool> receive_lock_{false};
   bool is_closed_{false};
   int32 td_id_;
-
-  Client::Response receive_unlocked(double timeout) {
-    if (output_queue_ready_cnt_ == 0) {
-      output_queue_ready_cnt_ = output_queue_->reader_wait_nonblock();
-    }
-    if (output_queue_ready_cnt_ > 0) {
-      output_queue_ready_cnt_--;
-      auto res = output_queue_->reader_get_unsafe();
-      if (res.object == nullptr && res.id == 0) {
-        is_closed_ = true;
-      }
-      return res;
-    }
-    if (timeout != 0) {
-      output_queue_->reader_get_event_fd().wait(static_cast<int>(timeout * 1000));
-      return receive_unlocked(0);
-    }
-    return {0, nullptr};
-  }
 };
 
 #endif
@@ -342,4 +486,24 @@ Client::~Client() = default;
 Client::Client(Client &&other) = default;
 Client &Client::operator=(Client &&other) = default;
 
+MultiClient::MultiClient() : impl_(std::make_unique<Impl>()) {
+}
+
+MultiClient::ClientId MultiClient::create_client() {
+  return impl_->create_client();
+}
+void MultiClient::send(ClientId client_id, RequestId request_id, Function function) {
+  impl_->send(client_id, request_id, std::move(function));
+}
+MultiClient::Response MultiClient::receive(double timeout) {
+  return impl_->receive(timeout);
+}
+MultiClient::Object MultiClient::execute(Function &&function) {
+  return Impl::execute(std::move(function));
+}
+
+MultiClient::~MultiClient() = default;
+MultiClient::MultiClient(MultiClient &&other) = default;
+MultiClient &MultiClient::operator=(MultiClient &&other) = default;
+
 }  // namespace td
diff --git a/td/telegram/Client.h b/td/telegram/Client.h
index ac49f1550..6eeea6018 100644
--- a/td/telegram/Client.h
+++ b/td/telegram/Client.h
@@ -127,6 +127,37 @@ class Client final {
   Client &operator=(Client &&other);
 
  private:
+  friend class MultiClient;
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+// --- EXPERIMENTAL ---
+class MultiClient final {
+ public:
+  MultiClient();
+
+  using ClientId = std::int32_t;
+  using RequestId = std::uint64_t;
+  using Function = td_api::object_ptr<td_api::Function>;
+  using Object = td_api::object_ptr<td_api::Object>;
+  struct Response {
+    ClientId client_id;
+    RequestId id;
+    Object object;
+  };
+
+  ClientId create_client();
+  void send(ClientId client_id, RequestId request_id, Function function);
+  Response receive(double timeout);
+  static Object execute(Function &&function);
+
+  ~MultiClient();
+  MultiClient(MultiClient &&other);
+  MultiClient &operator=(MultiClient &&other);
+
+ private:
+  friend class Client;
   class Impl;
   std::unique_ptr<Impl> impl_;
 };
diff --git a/test/tdclient.cpp b/test/tdclient.cpp
index ed9e64823..27c22237e 100644
--- a/test/tdclient.cpp
+++ b/test/tdclient.cpp
@@ -895,6 +895,31 @@ TEST(Client, Multi) {
     thread.join();
   }
 }
+
+TEST(Client, MultiNew) {
+  std::vector<td::thread> threads;
+  td::MultiClient client;
+  int threads_n = 4;
+  int clients_n = 1000;
+  for (int i = 0; i < threads_n; i++) {
+    threads.emplace_back([&] {
+      for (int i = 0; i < clients_n; i++) {
+        auto id = client.create_client();
+        client.send(id, 3, td::make_tl_object<td::td_api::testSquareInt>(3));
+      }
+    });
+  }
+  for (auto &thread : threads) {
+    thread.join();
+  }
+
+  std::set<td::int32> ids;
+  while (ids.size() != static_cast<std::size_t>(threads_n) * clients_n) {
+    auto event = client.receive(10);
+    if (event.client_id != 0 && event.id == 3) {
+      ids.insert(event.client_id);
+    }
+  }
+}
 #endif
 
 TEST(PartsManager, hands) {
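
Usage sketch (not part of the patch): a minimal consumer of the new
MultiClient interface, written only against the declarations this patch adds
to td/telegram/Client.h. The getOption("version") request is an arbitrary
stand-in chosen for illustration; any td_api function is sent the same way.

    #include "td/telegram/Client.h"
    #include "td/telegram/td_api.h"

    #include <cstdio>

    int main() {
      td::MultiClient client;

      // create_client() returns a non-zero ClientId; the caller picks the
      // RequestId and matches it against Response::id later.
      auto client_id = client.create_client();
      client.send(client_id, 1, td::td_api::make_object<td::td_api::getOption>("version"));

      while (true) {
        auto response = client.receive(10.0);  // wait up to 10 seconds
        if (response.client_id == 0) {
          continue;  // timeout: nothing was received
        }
        if (!response.object) {
          break;  // a null object for a client signals that client is closed
        }
        if (response.id == 1) {
          // The answer to our request; id == 0 would be an update.
          std::printf("client %d answered request 1\n", response.client_id);
          break;
        }
      }
    }

One receive() loop drains responses and updates for every client created
through the same MultiClient, which is the point of the new interface: the
per-client MpscPollableQueue of the old Client is replaced by a single shared
TdReceiverTs queue keyed by client_id.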