// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once #include "td/telegram/net/DcId.h" #include "td/telegram/net/NetQueryCounter.h" #include "td/telegram/net/NetQueryStats.h" #include "td/actor/actor.h" #include "td/actor/SignalSlot.h" #include "td/utils/buffer.h" #include "td/utils/common.h" #include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/ObjectPool.h" #include "td/utils/Promise.h" #include "td/utils/Slice.h" #include "td/utils/Span.h" #include "td/utils/Status.h" #include "td/utils/StringBuilder.h" #include "td/utils/tl_parsers.h" #include "td/utils/TsList.h" #include <atomic> #include <utility> namespace td { extern int VERBOSITY_NAME(net_query); class ChainId; class NetQuery; using NetQueryPtr = ObjectPool<NetQuery>::OwnerPtr; using NetQueryRef = ObjectPool<NetQuery>::WeakPtr; class NetQueryCallback : public Actor { public: virtual void on_result(NetQueryPtr query); virtual void on_result_resendable(NetQueryPtr query, Promise<NetQueryPtr> promise); }; class NetQuery final : public TsListNode<NetQueryDebug> { enum class State : int8 { Empty, Query, OK, Error }; public: NetQuery() = default; enum class Type : int8 { Common, Upload, Download, DownloadSmall }; enum class AuthFlag : int8 { Off, On }; enum class GzipFlag : int8 { Off, On }; enum Error : int32 { Resend = 202, Canceled = 203, ResendInvokeAfter = 204 }; uint64 id() const { return id_; } DcId dc_id() const { return dc_id_; } Type type() const { return type_; } GzipFlag gzip_flag() const { return gzip_flag_; } AuthFlag auth_flag() const { return auth_flag_; } int32 tl_constructor() const { return tl_constructor_; } void resend(DcId new_dc_id) { VLOG(net_query) << "Resend " << *this; { auto guard = lock(); get_data_unsafe().resend_count_++; } dc_id_ = new_dc_id; status_ = Status::OK(); state_ = State::Query; } void resend() { resend(dc_id_); } const BufferSlice &query() const { return query_; } const BufferSlice &ok() const { CHECK(state_ == State::OK); return answer_; } const Status &error() const { CHECK(state_ == State::Error); return status_; } BufferSlice move_as_ok() { auto ok = std::move(answer_); clear(); return ok; } Status move_as_error() TD_WARN_UNUSED_RESULT { auto status = std::move(status_); clear(); return status; } void set_ok(BufferSlice slice) { VLOG(net_query) << "Receive answer " << *this; CHECK(state_ == State::Query); answer_ = std::move(slice); state_ = State::OK; } void on_net_write(size_t size); void on_net_read(size_t size); void set_error(Status status, string source = string()); void set_error_resend() { set_error_impl(Status::Error<Error::Resend>()); } void set_error_canceled() { set_error_impl(Status::Error<Error::Canceled>()); } void set_error_resend_invoke_after() { set_error_impl(Status::Error<Error::ResendInvokeAfter>()); } bool update_is_ready() { if (state_ == State::Query) { if (cancellation_token_.load(std::memory_order_relaxed) == 0 || cancel_slot_.was_signal()) { set_error_canceled(); return true; } return false; } return true; } bool is_ready() const { return state_ != State::Query; } bool is_error() const { return state_ == State::Error; } bool is_ok() const { return state_ == State::OK; } int32 ok_tl_constructor() const { return tl_magic(answer_); } uint64 session_id() const { return session_id_.load(std::memory_order_relaxed); } void set_session_id(uint64 session_id) { session_id_.store(session_id, std::memory_order_relaxed); } uint64 message_id() const { return message_id_; } void set_message_id(uint64 message_id) { message_id_ = message_id; cancel_slot_.clear_event(); } Span<NetQueryRef> invoke_after() const { return invoke_after_; } void set_invoke_after(std::vector<NetQueryRef> refs) { invoke_after_ = std::move(refs); } uint32 session_rand() const { if (in_sequence_dispacher_ && !chain_ids_.empty()) { return static_cast<uint32>(chain_ids_[0] >> 10); } return 0; } void cancel(int32 cancellation_token) { cancellation_token_.compare_exchange_strong(cancellation_token, 0, std::memory_order_relaxed); } void set_cancellation_token(int32 cancellation_token) { cancellation_token_.store(cancellation_token, std::memory_order_relaxed); } void clear() { if (!is_ready()) { auto guard = lock(); LOG(ERROR) << "Destroy not ready query " << *this << " " << tag("state", get_data_unsafe().state_); } // TODO: CHECK if net_query is lost here cancel_slot_.close(); *this = NetQuery(); } bool empty() const { return state_ == State::Empty || !nq_counter_ || may_be_lost_; } void stop_track() { nq_counter_ = NetQueryCounter(); remove(); } void debug_send_failed() { auto guard = lock(); get_data_unsafe().send_failed_count_++; } void debug(string state, bool may_be_lost = false); void set_callback(ActorShared<NetQueryCallback> callback) { callback_ = std::move(callback); } ActorShared<NetQueryCallback> move_callback() { return std::move(callback_); } void start_migrate(int32 sched_id) { using ::td::start_migrate; start_migrate(cancel_slot_, sched_id); } void finish_migrate() { using ::td::finish_migrate; finish_migrate(cancel_slot_); } int8 priority() const { return priority_; } void set_priority(int8 priority) { priority_ = priority; } Span<uint64> get_chain_ids() const { return chain_ids_; } void set_in_sequence_dispatcher(bool in_sequence_dispacher) { in_sequence_dispacher_ = in_sequence_dispacher; } bool in_sequence_dispatcher() const { return in_sequence_dispacher_; } private: State state_ = State::Empty; Type type_ = Type::Common; AuthFlag auth_flag_ = AuthFlag::Off; GzipFlag gzip_flag_ = GzipFlag::Off; DcId dc_id_; NetQueryCounter nq_counter_; Status status_; uint64 id_ = 0; BufferSlice query_; BufferSlice answer_; int32 tl_constructor_ = 0; vector<NetQueryRef> invoke_after_; vector<uint64> chain_ids_; bool in_sequence_dispacher_ = false; bool may_be_lost_ = false; int8 priority_{0}; template <class T> struct movable_atomic final : public std::atomic<T> { movable_atomic() = default; movable_atomic(T &&x) : std::atomic<T>(std::forward<T>(x)) { } movable_atomic(movable_atomic &&other) noexcept { this->store(other.load(std::memory_order_relaxed), std::memory_order_relaxed); } movable_atomic &operator=(movable_atomic &&other) noexcept { this->store(other.load(std::memory_order_relaxed), std::memory_order_relaxed); return *this; } movable_atomic(const movable_atomic &) = delete; movable_atomic &operator=(const movable_atomic &) = delete; ~movable_atomic() = default; }; movable_atomic<uint64> session_id_{0}; uint64 message_id_{0}; movable_atomic<int32> cancellation_token_{-1}; // == 0 if query is canceled ActorShared<NetQueryCallback> callback_; void set_error_impl(Status status, string source = string()) { VLOG(net_query) << "Receive error " << *this << " " << status; status_ = std::move(status); state_ = State::Error; source_ = std::move(source); } static int32 tl_magic(const BufferSlice &buffer_slice); public: int32 next_timeout_ = 1; // for NetQueryDelayer int32 total_timeout_ = 0; // for NetQueryDelayer/SequenceDispatcher int32 total_timeout_limit_ = 60; // for NetQueryDelayer/SequenceDispatcher and to be set by caller int32 last_timeout_ = 0; // for NetQueryDelayer/SequenceDispatcher string source_; // for NetQueryDelayer/SequenceDispatcher int32 dispatch_ttl_ = -1; // for NetQueryDispatcher and to be set by caller int32 file_type_ = -1; // to be set by caller Slot cancel_slot_; // for Session and to be set by caller Promise<> quick_ack_promise_; // for Session and to be set by caller bool need_resend_on_503_ = true; // for NetQueryDispatcher and to be set by caller NetQuery(uint64 id, BufferSlice &&query, DcId dc_id, Type type, AuthFlag auth_flag, GzipFlag gzip_flag, int32 tl_constructor, int32 total_timeout_limit, NetQueryStats *stats, vector<ChainId> chain_ids); }; inline StringBuilder &operator<<(StringBuilder &stream, const NetQuery &net_query) { stream << "[Query:"; stream << tag("id", net_query.id()); stream << tag("tl", format::as_hex(net_query.tl_constructor())); auto message_id = net_query.message_id(); if (message_id != 0) { stream << tag("msg_id", format::as_hex(message_id)); } if (net_query.is_error()) { stream << net_query.error(); } else if (net_query.is_ok()) { stream << tag("result_tl", format::as_hex(net_query.ok_tl_constructor())); } stream << ']'; return stream; } inline StringBuilder &operator<<(StringBuilder &stream, const NetQueryPtr &net_query_ptr) { if (net_query_ptr.empty()) { return stream << "[Query: null]"; } return stream << *net_query_ptr; } inline void cancel_query(NetQueryRef &ref) { if (ref.empty()) { return; } ref->cancel(ref.generation()); } template <class T> Result<typename T::ReturnType> fetch_result(const BufferSlice &message) { TlBufferParser parser(&message); auto result = T::fetch_result(parser); parser.fetch_end(); const char *error = parser.get_error(); if (error != nullptr) { LOG(ERROR) << "Can't parse: " << format::as_hex_dump<4>(message.as_slice()); return Status::Error(500, Slice(error)); } return std::move(result); } template <class T> Result<typename T::ReturnType> fetch_result(NetQueryPtr query) { CHECK(!query.empty()); if (query->is_error()) { return query->move_as_error(); } auto buffer = query->move_as_ok(); return fetch_result<T>(buffer); } template <class T> Result<typename T::ReturnType> fetch_result(Result<NetQueryPtr> r_query) { TRY_RESULT(query, std::move(r_query)); return fetch_result<T>(std::move(query)); } inline void NetQueryCallback::on_result(NetQueryPtr query) { on_result_resendable(std::move(query), Auto()); } inline void NetQueryCallback::on_result_resendable(NetQueryPtr query, Promise<NetQueryPtr> promise) { on_result(std::move(query)); } inline void start_migrate(NetQueryPtr &net_query, int32 sched_id) { net_query->start_migrate(sched_id); } inline void finish_migrate(NetQueryPtr &net_query) { net_query->finish_migrate(); } } // namespace td