diff --git a/td/telegram/Client.cpp b/td/telegram/Client.cpp index a85b14d36..55dcb67b0 100644 --- a/td/telegram/Client.cpp +++ b/td/telegram/Client.cpp @@ -7,6 +7,7 @@ #include "td/telegram/Client.h" #include "td/telegram/Td.h" +#include "td/telegram/TdCallback.h" #include "td/actor/actor.h" @@ -20,20 +21,13 @@ #include #include -#include #include #include +#include #include namespace td { -class TdReceiver { - public: - virtual ~TdReceiver() = default; - virtual MultiClient::Response receive(double timeout) = 0; - virtual unique_ptr create_callback(MultiClient::ClientId client_id) = 0; -}; - class MultiTd : public Actor { public: void create(int32 td_id, unique_ptr callback) { @@ -43,12 +37,12 @@ class MultiTd : public Actor { string name = "Td"; class TdActorContext : public ActorContext { public: - explicit TdActorContext(std::string tag) : tag_(std::move(tag)) { + explicit TdActorContext(string tag) : tag_(std::move(tag)) { } int32 get_id() const override { return 0x172ae58d; } - std::string tag_; + string tag_; }; auto context = std::make_shared(to_string(td_id)); auto old_context = set_context(context); @@ -57,26 +51,26 @@ class MultiTd : public Actor { set_context(old_context); set_tag(old_tag); } + void send(MultiClient::ClientId client_id, MultiClient::RequestId request_id, MultiClient::Function function) { auto &td = tds_[client_id]; CHECK(!td.empty()); send_closure(td, &Td::request, request_id, std::move(function)); } + void destroy(int32 td_id) { auto size = tds_.erase(td_id); CHECK(size == 1); } private: - std::unordered_map > tds_; + std::unordered_map> tds_; }; #if TD_THREAD_UNSUPPORTED || TD_EVENTFD_UNSUPPORTED -class TdReceiverSimple : public TdReceiver { +class TdReceiver { public: - TdReceiverSimple() { - } - MultiClient::Response receive(double timeout) final { + MultiClient::Response receive(double timeout) { if (!responses_.empty()) { auto result = std::move(responses_.front()); responses_.pop_front(); @@ -84,15 +78,16 @@ class TdReceiverSimple : public TdReceiver { } return {0, 0, nullptr}; } - unique_ptr create_callback(MultiClient::ClientId client_id) final { + + unique_ptr create_callback(MultiClient::ClientId client_id) { class Callback : public TdCallback { public: - explicit Callback(MultiClient::ClientId client_id, TdReceiverSimple *impl) : client_id_(client_id), impl_(impl) { + Callback(MultiClient::ClientId client_id, TdReceiver *impl) : client_id_(client_id), impl_(impl) { } - void on_result(std::uint64_t id, td_api::object_ptr result) override { + void on_result(uint64 id, td_api::object_ptr result) override { impl_->responses_.push_back({client_id_, id, std::move(result)}); } - void on_error(std::uint64_t id, td_api::object_ptr error) override { + void on_error(uint64 id, td_api::object_ptr error) override { impl_->responses_.push_back({client_id_, id, std::move(error)}); } Callback(const Callback &) = delete; @@ -105,13 +100,13 @@ class TdReceiverSimple : public TdReceiver { private: MultiClient::ClientId client_id_; - TdReceiverSimple *impl_; + TdReceiver *impl_; }; return td::make_unique(client_id, this); } private: - std::deque responses_; + std::queue responses_; }; class MultiClient::Impl final { @@ -119,7 +114,7 @@ class MultiClient::Impl final { Impl() { concurrent_scheduler_ = make_unique(); concurrent_scheduler_->init(0); - receiver_ = make_unique(); + receiver_ = make_unique(); concurrent_scheduler_->start(); } @@ -162,10 +157,11 @@ class MultiClient::Impl final { return response; } - static Object execute(Function &&function) { - return Td::static_request(std::move(function)); - } - + Impl() = default; + Impl(const Impl &) = delete; + Impl &operator=(const Impl &) = delete; + Impl(Impl &&) = delete; + Impl &operator=(Impl &&) = delete; ~Impl() { { auto guard = concurrent_scheduler_->get_main_guard(); @@ -180,8 +176,7 @@ class MultiClient::Impl final { } private: - friend class Client::Impl; - td::unique_ptr receiver_; + unique_ptr receiver_; struct Request { ClientId client_id; RequestId id; @@ -190,7 +185,7 @@ class MultiClient::Impl final { std::vector requests_; unique_ptr concurrent_scheduler_; ClientId client_id_{0}; - std::unordered_map > tds_; + std::unordered_map> tds_; }; class Client::Impl final { @@ -211,11 +206,6 @@ class Client::Impl final { return old_response; } - Impl(const Impl &) = delete; - Impl &operator=(const Impl &) = delete; - Impl(Impl &&) = delete; - Impl &operator=(Impl &&) = delete; - private: MultiClient::Impl impl_; MultiClient::ClientId client_id_; @@ -223,13 +213,14 @@ class Client::Impl final { #else -class TdReceiverTs : public TdReceiver { +class TdReceiver { public: - TdReceiverTs() { + TdReceiver() { output_queue_ = std::make_shared(); output_queue_->init(); } - MultiClient::Response receive(double timeout) final { + + MultiClient::Response receive(double timeout) { if (output_queue_ready_cnt_ == 0) { output_queue_ready_cnt_ = output_queue_->reader_wait_nonblock(); } @@ -243,16 +234,17 @@ class TdReceiverTs : public TdReceiver { } return {0, 0, nullptr}; } - unique_ptr create_callback(MultiClient::ClientId client_id) final { + + unique_ptr create_callback(MultiClient::ClientId client_id) { class Callback : public TdCallback { public: explicit Callback(MultiClient::ClientId client_id, std::shared_ptr output_queue) : client_id_(client_id), output_queue_(std::move(output_queue)) { } - void on_result(std::uint64_t id, td_api::object_ptr result) override { + void on_result(uint64 id, td_api::object_ptr result) override { output_queue_->writer_put({client_id_, id, std::move(result)}); } - void on_error(std::uint64_t id, td_api::object_ptr error) override { + void on_error(uint64 id, td_api::object_ptr error) override { output_queue_->writer_put({client_id_, id, std::move(error)}); } Callback(const Callback &) = delete; @@ -304,10 +296,6 @@ class MultiImpl { return id; } - [[deprecated]] void send(int32 td_id, Client::Request request) { - auto guard = concurrent_scheduler_->get_send_guard(); - send_closure(multi_td_, &MultiTd::send, td_id, request.id, std::move(request.function)); - } void send(MultiClient::ClientId client_id, MultiClient::RequestId request_id, MultiClient::Function function) { auto guard = concurrent_scheduler_->get_send_guard(); send_closure(multi_td_, &MultiTd::send, client_id, request_id, std::move(function)); @@ -333,9 +321,9 @@ class MultiImpl { thread scheduler_thread_; ActorOwn multi_td_; - int32 create_id() { - static std::atomic id_{0}; - return id_.fetch_add(1) + 1; + static int32 create_id() { + static std::atomic current_id{1}; + return current_id.fetch_add(1); } void create(int32 td_id, unique_ptr callback) { @@ -363,7 +351,7 @@ class MultiImplPool { private: std::mutex mutex_; - std::vector > impls_; + std::vector> impls_; }; class MultiClient::Impl final { @@ -377,12 +365,14 @@ class MultiClient::Impl final { } return client_id; } + void send(ClientId client_id, RequestId request_id, Function function) { auto lock = impls_mutex_.lock_read().move_as_ok(); auto it = impls_.find(client_id); CHECK(it != impls_.end()); it->second->send(client_id, request_id, std::move(function)); } + Response receive(double timeout) { auto res = receiver_->receive(timeout); if (res.client_id != 0 && !res.object) { @@ -391,10 +381,12 @@ class MultiClient::Impl final { } return res; } - static Object execute(Function &&function) { - return Td::static_request(std::move(function)); - } + Impl() = default; + Impl(const Impl &) = delete; + Impl &operator=(const Impl &) = delete; + Impl(Impl &&) = delete; + Impl &operator=(Impl &&) = delete; ~Impl() { for (auto &it : impls_) { it.second->destroy(it.first); @@ -406,9 +398,9 @@ class MultiClient::Impl final { private: MultiImplPool pool_; - td::RwMutex impls_mutex_; - std::unordered_map > impls_; - td::unique_ptr receiver_{td::make_unique()}; + RwMutex impls_mutex_; + std::unordered_map> impls_; + unique_ptr receiver_{make_unique()}; }; class Client::Impl final { @@ -416,7 +408,7 @@ class Client::Impl final { Impl() { static MultiImplPool pool; multi_impl_ = pool.get(); - receiver_ = make_unique(); + receiver_ = make_unique(); td_id_ = multi_impl_->create(*receiver_); } @@ -455,7 +447,7 @@ class Client::Impl final { private: std::shared_ptr multi_impl_; - td::unique_ptr receiver_; + unique_ptr receiver_; bool is_closed_{false}; int32 td_id_; @@ -492,14 +484,17 @@ MultiClient::MultiClient() : impl_(std::make_unique()) { MultiClient::ClientId MultiClient::create_client() { return impl_->create_client(); } -void MultiClient::send(ClientId client_id, RequestId request_id, Function function) { + +void MultiClient::send(ClientId client_id, RequestId request_id, Function &&function) { impl_->send(client_id, request_id, std::move(function)); } + MultiClient::Response MultiClient::receive(double timeout) { return impl_->receive(timeout); } + MultiClient::Object MultiClient::execute(Function &&function) { - return Impl::execute(std::move(function)); + return Td::static_request(std::move(function)); } MultiClient::~MultiClient() = default; diff --git a/td/telegram/Client.h b/td/telegram/Client.h index 6eeea6018..40f50613a 100644 --- a/td/telegram/Client.h +++ b/td/telegram/Client.h @@ -127,7 +127,6 @@ class Client final { Client &operator=(Client &&other); private: - friend class MultiClient; class Impl; std::unique_ptr impl_; }; @@ -148,16 +147,20 @@ class MultiClient final { }; ClientId create_client(); - void send(ClientId client_id, RequestId request_id, Function function); + + void send(ClientId client_id, RequestId request_id, Function &&function); + Response receive(double timeout); + static Object execute(Function &&function); ~MultiClient(); + MultiClient(MultiClient &&other); + MultiClient &operator=(MultiClient &&other); private: - friend class Client; class Impl; std::unique_ptr impl_; }; diff --git a/test/tdclient.cpp b/test/tdclient.cpp index 27c22237e..b3edcbfac 100644 --- a/test/tdclient.cpp +++ b/test/tdclient.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include REGISTER_TESTS(tdclient); @@ -875,7 +876,7 @@ TEST(Client, SimpleMulti) { #if !TD_THREAD_UNSUPPORTED TEST(Client, Multi) { - std::vector threads; + td::vector threads; for (int i = 0; i < 4; i++) { threads.emplace_back([] { for (int i = 0; i < 1000; i++) { @@ -895,8 +896,9 @@ TEST(Client, Multi) { thread.join(); } } + TEST(Client, MultiNew) { - std::vector threads; + td::vector threads; td::MultiClient client; int threads_n = 4; int clients_n = 1000; @@ -912,7 +914,7 @@ TEST(Client, MultiNew) { thread.join(); } - std::set ids; + std::set ids; while (ids.size() * threads_n * clients_n) { auto event = client.receive(10); if (event.client_id != 0 && event.id == 3) {