// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2020 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/telegram/Client.h" #include "td/telegram/Td.h" #include "td/telegram/TdCallback.h" #include "td/actor/actor.h" #include "td/utils/common.h" #include "td/utils/crypto.h" #include "td/utils/logging.h" #include "td/utils/misc.h" #include "td/utils/MpscPollableQueue.h" #include "td/utils/port/RwMutex.h" #include "td/utils/port/thread.h" #include "td/utils/death_handler.h" #include #include #include #include #include #include namespace td { #if TD_THREAD_UNSUPPORTED || TD_EVENTFD_UNSUPPORTED class TdReceiver { public: ClientManager::Response receive(double timeout) { return receive(timeout, true, true); } ClientManager::Response receive(double timeout, bool include_responses, bool include_updates) { if (include_responses && !responses_.empty()) { auto result = std::move(responses_.front()); responses_.pop(); return result; } if (include_updates && !updates_.empty()) { auto result = std::move(updates_.front()); updates_.pop(); return result; } return {0, 0, nullptr}; } unique_ptr create_callback(ClientManager::ClientId client_id) { class Callback : public TdCallback { public: Callback(ClientManager::ClientId client_id, TdReceiver *impl) : client_id_(client_id), impl_(impl) { } void on_result(uint64 id, td_api::object_ptr result) override { if (id == 0) { impl_->responses_.push({client_id_, id, nullptr}); impl_->updates_.push({client_id_, 0, std::move(result)}); } else { impl_->responses_.push({client_id_, id, std::move(result)}); impl_->updates_.push({client_id_, id, nullptr}); } } void on_error(uint64 id, td_api::object_ptr error) override { if (id == 0) { impl_->responses_.push({client_id_, 0, nullptr}); impl_->updates_.push({client_id_, 0, std::move(error)}); } else { impl_->responses_.push({client_id_, id, std::move(error)}); impl_->updates_.push({client_id_, id, nullptr}); } } Callback(const Callback &) = delete; Callback &operator=(const Callback &) = delete; Callback(Callback &&) = delete; Callback &operator=(Callback &&) = delete; ~Callback() override { //impl_->responses_.push({0, 0, nullptr}); impl_->updates_.push({client_id_, 0, nullptr}); } private: ClientManager::ClientId client_id_; TdReceiver *impl_; }; return td::make_unique(client_id, this); } void add_response(ClientManager::ClientId client_id, uint64 id, td_api::object_ptr result) { responses_.push({client_id, id, std::move(result)}); } private: std::queue updates_; std::queue responses_; }; class ClientManager::Impl final { public: Impl() { options_.net_query_stats = std::make_shared(); concurrent_scheduler_ = make_unique(); concurrent_scheduler_->init(0); receiver_ = make_unique(); concurrent_scheduler_->start(); } ClientId create_client() { auto client_id = ++client_id_; tds_[client_id] = concurrent_scheduler_->create_actor_unsafe(0, "Td", receiver_->create_callback(client_id), options_); return client_id; } void send(ClientId client_id, RequestId request_id, td_api::object_ptr &&request) { requests_.push_back({client_id, request_id, std::move(request)}); } Response receive(double timeout) { return receive(timeout, true, true); } Response receive(double timeout, bool include_responses, bool include_updates) { if (!requests_.empty()) { auto guard = concurrent_scheduler_->get_main_guard(); for (size_t i = 0; i < requests_.size(); i++) { auto &request = requests_[i]; if (request.client_id <= 0 || request.client_id > client_id_) { receiver_->add_response(request.client_id, request.id, td_api::make_object(400, "Invalid TDLib instance specified")); continue; } auto it = tds_.find(request.client_id); if (it == tds_.end() || it->second.empty()) { receiver_->add_response(request.client_id, request.id, td_api::make_object(500, "Request aborted")); continue; } send_closure_later(it->second, &Td::request, request.id, std::move(request.request)); } requests_.clear(); } auto response = receiver_->receive(0, include_responses, include_updates); if (response.client_id == 0) { concurrent_scheduler_->run_main(0); response = receiver_->receive(0, include_responses, include_updates); } else { ConcurrentScheduler::emscripten_clear_main_timeout(); } if (response.request_id == 0 && response.object != nullptr && response.object->get_id() == td::td_api::updateAuthorizationState::ID && static_cast(response.object.get()) ->authorization_state_->get_id() == td::td_api::authorizationStateClosed::ID) { auto it = tds_.find(response.client_id); CHECK(it != tds_.end()); it->second.reset(); } if (response.object == nullptr && response.client_id != 0 && response.request_id == 0) { auto guard = concurrent_scheduler_->get_main_guard(); auto it = tds_.find(response.client_id); CHECK(it != tds_.end()); CHECK(it->second.empty()); tds_.erase(it); response.client_id = 0; } return response; } Impl(const Impl &) = delete; Impl &operator=(const Impl &) = delete; Impl(Impl &&) = delete; Impl &operator=(Impl &&) = delete; ~Impl() { { auto guard = concurrent_scheduler_->get_main_guard(); for (auto &td : tds_) { td.second.reset(); } } while (!tds_.empty()) { receive(10, false, true); } concurrent_scheduler_->finish(); } private: unique_ptr receiver_; struct Request { ClientId client_id; RequestId id; td_api::object_ptr request; }; vector requests_; unique_ptr concurrent_scheduler_; ClientId client_id_{0}; Td::Options options_; std::unordered_map> tds_; }; class Client::Impl final { public: Impl() : client_id_(impl_.create_client()) { } void send(Request request) { impl_.send(client_id_, request.id, std::move(request.function)); } Response receive(double timeout) { return receive(timeout, true, true); } Response receive(double timeout, bool include_responses, bool include_updates) { auto response = impl_.receive(timeout, include_responses, include_updates); Response old_response; old_response.id = response.request_id; old_response.object = std::move(response.object); return old_response; } private: ClientManager::Impl impl_; ClientManager::ClientId client_id_; }; #else class MultiTd : public Actor { public: explicit MultiTd(Td::Options options) : options_(std::move(options)) { } void create(int32 td_id, unique_ptr callback) { auto &td = tds_[td_id]; CHECK(td.empty()); string name = "Td"; auto context = std::make_shared(); auto old_context = set_context(context); auto old_tag = set_tag(to_string(td_id)); td = create_actor("Td", std::move(callback), options_); set_context(old_context); set_tag(old_tag); } void send(ClientManager::ClientId client_id, ClientManager::RequestId request_id, td_api::object_ptr &&request) { auto &td = tds_[client_id]; CHECK(!td.empty()); send_closure(td, &Td::request, request_id, std::move(request)); } void close(int32 td_id) { size_t erased = tds_.erase(td_id); CHECK(erased > 0); } private: Td::Options options_; std::unordered_map> tds_; }; class TdReceiver { public: TdReceiver() { output_responses_queue_ = std::make_shared(); output_responses_queue_->init(); output_updates_queue_ = std::make_shared(); output_updates_queue_->init(); } ClientManager::Response receive(double timeout) { return receive(timeout, true, true); } ClientManager::Response receive(double timeout, bool include_responses, bool include_updates) { VLOG(td_requests) << "Begin to wait for updates with timeout " << timeout; bool is_responses_locked = false; bool is_updates_locked = false; if (include_responses) { is_responses_locked = receive_responses_lock_.exchange(true); CHECK(!is_responses_locked); } if (include_updates) { is_updates_locked = receive_updates_lock_.exchange(true); CHECK(!is_updates_locked); } auto response = receive_unlocked(timeout, include_responses, include_updates); if (include_updates) { is_updates_locked = receive_updates_lock_.exchange(false); CHECK(is_updates_locked); } if (include_responses) { is_responses_locked = receive_responses_lock_.exchange(false); CHECK(is_responses_locked); } VLOG(td_requests) << "End to wait for updates, returning object " << response.request_id << ' ' << response.object.get(); return response; } unique_ptr create_callback(ClientManager::ClientId client_id) { class Callback : public TdCallback { public: explicit Callback(ClientManager::ClientId client_id, std::shared_ptr output_responses_queue, std::shared_ptr output_updates_queue) : client_id_(client_id), output_responses_queue_(std::move(output_responses_queue)), output_updates_queue_(std::move(output_updates_queue)) { } void on_result(uint64 id, td_api::object_ptr result) override { if (id == 0) { output_responses_queue_->writer_put({0, 0, nullptr}); output_updates_queue_->writer_put({client_id_, 0, std::move(result)}); } else { output_responses_queue_->writer_put({client_id_, id, std::move(result)}); output_updates_queue_->writer_put({0, 0, nullptr}); } } void on_error(uint64 id, td_api::object_ptr error) override { if (id == 0) { output_responses_queue_->writer_put({0, 0, nullptr}); output_updates_queue_->writer_put({client_id_, 0, std::move(error)}); } else { output_responses_queue_->writer_put({client_id_, id, std::move(error)}); output_updates_queue_->writer_put({0, 0, nullptr}); } } Callback(const Callback &) = delete; Callback &operator=(const Callback &) = delete; Callback(Callback &&) = delete; Callback &operator=(Callback &&) = delete; ~Callback() override { //output_responses_queue_->writer_put({0, 0, nullptr}); output_updates_queue_->writer_put({client_id_, 0, nullptr}); } private: ClientManager::ClientId client_id_; std::shared_ptr output_responses_queue_; std::shared_ptr output_updates_queue_; }; return td::make_unique(client_id, output_responses_queue_, output_updates_queue_); } void add_response(ClientManager::ClientId client_id, uint64 id, td_api::object_ptr result) { if (id == 0) { output_responses_queue_->writer_put({0, 0, nullptr}); output_updates_queue_->writer_put({client_id, id, std::move(result)}); } else { output_responses_queue_->writer_put({client_id, id, std::move(result)}); output_updates_queue_->writer_put({0, 0, nullptr}); } } private: using OutputQueue = MpscPollableQueue; std::shared_ptr output_responses_queue_; std::shared_ptr output_updates_queue_; int output_responses_queue_ready_cnt_{0}; int output_updates_queue_ready_cnt_{0}; std::atomic receive_responses_lock_{false}; std::atomic receive_updates_lock_{false}; ClientManager::Response receive_unlocked(double timeout, bool include_responses, bool include_updates) { if (include_responses) { if (output_responses_queue_ready_cnt_ == 0) { output_responses_queue_ready_cnt_ = output_responses_queue_->reader_wait_nonblock(); } if (output_responses_queue_ready_cnt_ > 0) { output_responses_queue_ready_cnt_--; return output_responses_queue_->reader_get_unsafe(); } } if (include_updates) { if (output_updates_queue_ready_cnt_ == 0) { output_updates_queue_ready_cnt_ = output_updates_queue_->reader_wait_nonblock(); } if (output_updates_queue_ready_cnt_ > 0) { output_updates_queue_ready_cnt_--; return output_updates_queue_->reader_get_unsafe(); } } if (timeout != 0) { if (include_responses && !include_updates) { output_responses_queue_->reader_get_event_fd().wait(static_cast(timeout * 1000)); } else if (!include_responses && include_updates) { output_updates_queue_->reader_get_event_fd().wait(static_cast(timeout * 1000)); } else if (include_responses && include_updates) { output_updates_queue_->reader_get_event_fd().wait(static_cast(timeout * 1000)); } else { // This configuration (include_responses = false and include_updates = false) shouldn't be used. } return receive_unlocked(0, include_responses, include_updates); } return {0, 0, nullptr}; } }; class MultiImpl { public: explicit MultiImpl(std::shared_ptr net_query_stats) { concurrent_scheduler_ = std::make_shared(); concurrent_scheduler_->init(3); concurrent_scheduler_->start(); { auto guard = concurrent_scheduler_->get_main_guard(); Td::Options options; options.net_query_stats = std::move(net_query_stats); multi_td_ = create_actor("MultiTd", std::move(options)); } scheduler_thread_ = thread([concurrent_scheduler = concurrent_scheduler_] { while (concurrent_scheduler->run_main(10)) { } }); } MultiImpl(const MultiImpl &) = delete; MultiImpl &operator=(const MultiImpl &) = delete; MultiImpl(MultiImpl &&) = delete; MultiImpl &operator=(MultiImpl &&) = delete; int32 create(TdReceiver &receiver) { auto id = create_id(); create(id, receiver.create_callback(id)); return id; } static bool is_valid_client_id(int32 client_id) { return client_id > 0 && client_id < current_id_.load(); } void send(ClientManager::ClientId client_id, ClientManager::RequestId request_id, td_api::object_ptr &&request) { auto guard = concurrent_scheduler_->get_send_guard(); send_closure(multi_td_, &MultiTd::send, client_id, request_id, std::move(request)); } void close(ClientManager::ClientId client_id) { auto guard = concurrent_scheduler_->get_send_guard(); send_closure(multi_td_, &MultiTd::close, client_id); } ~MultiImpl() { { auto guard = concurrent_scheduler_->get_send_guard(); multi_td_.reset(); Scheduler::instance()->finish(); } scheduler_thread_.join(); concurrent_scheduler_->finish(); } private: std::shared_ptr concurrent_scheduler_; thread scheduler_thread_; ActorOwn multi_td_; static std::atomic current_id_; static int32 create_id() { return current_id_.fetch_add(1); } void create(int32 td_id, unique_ptr callback) { auto guard = concurrent_scheduler_->get_send_guard(); send_closure(multi_td_, &MultiTd::create, td_id, std::move(callback)); } }; std::atomic MultiImpl::current_id_{1}; class MultiImplPool { public: std::shared_ptr get() { std::unique_lock lock(mutex_); if (impls_.empty()) { init_openssl_threads(); impls_.resize(clamp(thread::hardware_concurrency(), 8u, 1000u) * 5 / 4); } auto &impl = *std::min_element(impls_.begin(), impls_.end(), [](auto &a, auto &b) { return a.lock().use_count() < b.lock().use_count(); }); auto result = impl.lock(); if (!result) { result = std::make_shared(net_query_stats_); impl = result; } return result; } private: std::mutex mutex_; std::vector> impls_; std::shared_ptr net_query_stats_ = std::make_shared(); }; class ClientManager::Impl final { public: ClientId create_client() { auto impl = pool_.get(); auto client_id = impl->create(*receiver_); { auto lock = impls_mutex_.lock_write().move_as_ok(); impls_[client_id].impl = std::move(impl); } return client_id; } void send(ClientId client_id, RequestId request_id, td_api::object_ptr &&request) { auto lock = impls_mutex_.lock_read().move_as_ok(); if (!MultiImpl::is_valid_client_id(client_id)) { receiver_->add_response(client_id, request_id, td_api::make_object(400, "Invalid TDLib instance specified")); return; } auto it = impls_.find(client_id); if (it == impls_.end() || it->second.is_closed) { receiver_->add_response(client_id, request_id, td_api::make_object(500, "Request aborted")); return; } it->second.impl->send(client_id, request_id, std::move(request)); } Response receive(double timeout) { return receive(timeout, true, true); } Response receive(double timeout, bool include_responses, bool include_updates) { auto response = receiver_->receive(timeout, include_responses, include_updates); if (response.request_id == 0 && response.object != nullptr && response.object->get_id() == td::td_api::updateAuthorizationState::ID && static_cast(response.object.get()) ->authorization_state_->get_id() == td::td_api::authorizationStateClosed::ID) { auto lock = impls_mutex_.lock_write().move_as_ok(); close_impl(response.client_id); } if (response.object == nullptr && response.client_id != 0 && response.request_id == 0) { auto lock = impls_mutex_.lock_write().move_as_ok(); auto it = impls_.find(response.client_id); CHECK(it != impls_.end()); CHECK(it->second.is_closed); impls_.erase(it); response.client_id = 0; } return response; } void close_impl(ClientId client_id) { auto it = impls_.find(client_id); CHECK(it != impls_.end()); if (!it->second.is_closed) { it->second.is_closed = true; it->second.impl->close(client_id); } } Impl() = default; Impl(const Impl &) = delete; Impl &operator=(const Impl &) = delete; Impl(Impl &&) = delete; Impl &operator=(Impl &&) = delete; ~Impl() { for (auto &it : impls_) { close_impl(it.first); } while (!impls_.empty()) { receive(10, false, true); } } private: MultiImplPool pool_; RwMutex impls_mutex_; struct MultiImplInfo { std::shared_ptr impl; bool is_closed = false; }; std::unordered_map impls_; unique_ptr receiver_{make_unique()}; }; class Client::Impl final { public: Impl() { static MultiImplPool pool; multi_impl_ = pool.get(); receiver_ = make_unique(); td_id_ = multi_impl_->create(*receiver_); } void send(Request request) { if (request.id == 0 || request.function == nullptr) { LOG(ERROR) << "Drop wrong request " << request.id; return; } multi_impl_->send(td_id_, request.id, std::move(request.function)); } Response receive(double timeout) { return receive(timeout, true, true); } Response receive(double timeout, bool include_responses, bool include_updates) { auto response = receiver_->receive(timeout, include_responses, include_updates); Response old_response; old_response.id = response.request_id; old_response.object = std::move(response.object); return old_response; } Impl(const Impl &) = delete; Impl &operator=(const Impl &) = delete; Impl(Impl &&) = delete; Impl &operator=(Impl &&) = delete; ~Impl() { multi_impl_->close(td_id_); while (true) { auto response = receiver_->receive(10.0, false, true); if (response.object == nullptr && response.client_id != 0 && response.request_id == 0) { break; } } } private: std::shared_ptr multi_impl_; unique_ptr receiver_; int32 td_id_; }; #endif Client::Client() : impl_(std::make_unique()) { } void Client::send(Request &&request) { impl_->send(std::move(request)); } Client::Response Client::receive(double timeout) { return impl_->receive(timeout); } Client::Response Client::receive(double timeout, bool include_responses, bool include_updates) { return impl_->receive(timeout, include_responses, include_updates); } Client::Response Client::execute(Request &&request) { Response response; response.id = request.id; response.object = Td::static_request(std::move(request.function)); return response; } Client::~Client() = default; Client::Client(Client &&other) = default; Client &Client::operator=(Client &&other) = default; ClientManager::ClientManager() : impl_(std::make_unique()) { } ClientManager::ClientId ClientManager::create_client() { return impl_->create_client(); } void ClientManager::send(ClientId client_id, RequestId request_id, td_api::object_ptr &&request) { impl_->send(client_id, request_id, std::move(request)); } ClientManager::Response ClientManager::receive(double timeout) { return impl_->receive(timeout); } ClientManager::Response ClientManager::receive(double timeout, bool include_responses, bool include_updates) { return impl_->receive(timeout, include_responses, include_updates); } td_api::object_ptr ClientManager::execute(td_api::object_ptr &&request) { return Td::static_request(std::move(request)); } ClientManager::~ClientManager() = default; ClientManager::ClientManager(ClientManager &&other) = default; ClientManager &ClientManager::operator=(ClientManager &&other) = default; } // namespace td