From b44e2ea3fcab3451acd64f86303b64546a0d71d4 Mon Sep 17 00:00:00 2001 From: levlam Date: Thu, 21 Sep 2023 17:52:33 +0300 Subject: [PATCH] Add strictly-typed class mtproto::MessageId. --- CMakeLists.txt | 1 + benchmark/bench_misc.cpp | 20 ++-- td/mtproto/AuthData.cpp | 42 ++++---- td/mtproto/AuthData.h | 22 +++-- td/mtproto/CryptoStorer.h | 21 ++-- td/mtproto/HandshakeConnection.h | 3 +- td/mtproto/MessageId.h | 70 +++++++++++++ td/mtproto/MtprotoQuery.h | 6 +- td/mtproto/NoCryptoStorer.h | 8 +- td/mtproto/PacketInfo.h | 4 +- td/mtproto/PingConnection.cpp | 19 ++-- td/mtproto/RawConnection.cpp | 4 +- td/mtproto/RawConnection.h | 5 +- td/mtproto/SessionConnection.cpp | 164 ++++++++++++++++--------------- td/mtproto/SessionConnection.h | 64 ++++++------ td/mtproto/Transport.cpp | 11 ++- td/telegram/net/Session.cpp | 110 ++++++++++----------- td/telegram/net/Session.h | 46 ++++----- 18 files changed, 355 insertions(+), 265 deletions(-) create mode 100644 td/mtproto/MessageId.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 49c0b674c..4d2d624a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -518,6 +518,7 @@ set(TDLIB_SOURCE td/mtproto/HttpTransport.h td/mtproto/IStreamTransport.h td/mtproto/KDF.h + td/mtproto/MessageId.h td/mtproto/MtprotoQuery.h td/mtproto/NoCryptoStorer.h td/mtproto/PacketInfo.h diff --git a/benchmark/bench_misc.cpp b/benchmark/bench_misc.cpp index 7cb175699..ed400fb48 100644 --- a/benchmark/bench_misc.cpp +++ b/benchmark/bench_misc.cpp @@ -416,7 +416,7 @@ class IdDuplicateCheckerOld { static td::string get_description() { return "Old"; } - td::Status check(td::int64 message_id) { + td::Status check(td::uint64 message_id) { if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS) { auto oldest_message_id = *saved_message_ids_.begin(); if (message_id < oldest_message_id) { @@ -437,7 +437,7 @@ class IdDuplicateCheckerOld { private: static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; - std::set saved_message_ids_; + std::set saved_message_ids_; }; template @@ -446,7 +446,7 @@ class IdDuplicateCheckerNew { static td::string get_description() { return PSTRING() << "New" << MAX_SAVED_MESSAGE_IDS; } - td::Status check(td::int64 message_id) { + td::Status check(td::uint64 message_id) { auto insert_result = saved_message_ids_.insert(message_id); if (!insert_result.second) { return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id); @@ -464,7 +464,7 @@ class IdDuplicateCheckerNew { } private: - std::set saved_message_ids_; + std::set saved_message_ids_; }; class IdDuplicateCheckerNewOther { @@ -472,7 +472,7 @@ class IdDuplicateCheckerNewOther { static td::string get_description() { return "NewOther"; } - td::Status check(td::int64 message_id) { + td::Status check(td::uint64 message_id) { if (!saved_message_ids_.insert(message_id).second) { return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id); } @@ -490,7 +490,7 @@ class IdDuplicateCheckerNewOther { private: static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; - std::set saved_message_ids_; + std::set saved_message_ids_; }; class IdDuplicateCheckerNewSimple { @@ -498,7 +498,7 @@ class IdDuplicateCheckerNewSimple { static td::string get_description() { return "NewSimple"; } - td::Status check(td::int64 message_id) { + td::Status check(td::uint64 message_id) { auto insert_result = saved_message_ids_.insert(message_id); if (!insert_result.second) { return td::Status::Error(1, "Ignore already processed message"); @@ -516,7 +516,7 @@ class IdDuplicateCheckerNewSimple { private: static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; - std::set saved_message_ids_; + std::set saved_message_ids_; }; template @@ -525,7 +525,7 @@ class IdDuplicateCheckerArray { static td::string get_description() { return PSTRING() << "Array" << max_size; } - td::Status check(td::int64 message_id) { + td::Status check(td::uint64 message_id) { if (end_pos_ == 2 * max_size) { std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]); end_pos_ = max_size; @@ -550,7 +550,7 @@ class IdDuplicateCheckerArray { } private: - std::array saved_message_ids_; + std::array saved_message_ids_; std::size_t end_pos_ = 0; }; diff --git a/td/mtproto/AuthData.cpp b/td/mtproto/AuthData.cpp index 023023495..59fca3a1b 100644 --- a/td/mtproto/AuthData.cpp +++ b/td/mtproto/AuthData.cpp @@ -6,7 +6,6 @@ // #include "td/mtproto/AuthData.h" -#include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/Random.h" #include "td/utils/SliceBuilder.h" @@ -17,7 +16,8 @@ namespace td { namespace mtproto { -Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id) { +Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos, + MessageId message_id) { // In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if // a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be // ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is @@ -32,13 +32,12 @@ Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, s return Status::OK(); } if (end_pos >= max_size && message_id < saved_message_ids[0]) { - return Status::Error(2, PSLICE() << "Ignore very old message " << format::as_hex(message_id) - << " older than the oldest known message " - << format::as_hex(saved_message_ids[0])); + return Status::Error( + 2, PSLICE() << "Ignore very old " << message_id << " older than the oldest known " << saved_message_ids[0]); } auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id); if (*it == message_id) { - return Status::Error(1, PSLICE() << "Ignore already processed message " << format::as_hex(message_id)); + return Status::Error(1, PSLICE() << "Ignore already processed " << message_id); } std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]); *it = message_id; @@ -105,7 +104,7 @@ std::vector AuthData::get_future_salts() const { return res; } -uint64 AuthData::next_message_id(double now) { +MessageId AuthData::next_message_id(double now) { double server_time = get_server_time(now); auto t = static_cast(server_time * (static_cast(1) << 32)); @@ -113,31 +112,31 @@ uint64 AuthData::next_message_id(double now) { // TODO(perf) do not do this for systems with good precision?.. auto rx = Random::secure_int32(); auto to_xor = rx & ((1 << 22) - 1); - auto to_mul = ((rx >> 22) & 1023) + 1; t ^= to_xor; - auto result = t & static_cast(-4); + auto result = MessageId(t & static_cast(-4)); if (last_message_id_ >= result) { - result = last_message_id_ + 8 * to_mul; + auto to_mul = ((rx >> 22) & 1023) + 1; + result = MessageId(last_message_id_.get() + 8 * to_mul); } - LOG(DEBUG) << "Create message identifier " << format::as_hex(result) << " at " << now; + LOG(DEBUG) << "Create identifier for " << result << " at " << now; last_message_id_ = result; return result; } -bool AuthData::is_valid_outbound_msg_id(uint64 message_id, double now) const { +bool AuthData::is_valid_outbound_msg_id(MessageId message_id, double now) const { double server_time = get_server_time(now); - auto id_time = static_cast(message_id) / static_cast(static_cast(1) << 32); + auto id_time = static_cast(message_id.get()) / static_cast(static_cast(1) << 32); return server_time - 150 < id_time && id_time < server_time + 30; } -bool AuthData::is_valid_inbound_msg_id(uint64 message_id, double now) const { +bool AuthData::is_valid_inbound_msg_id(MessageId message_id, double now) const { double server_time = get_server_time(now); - auto id_time = static_cast(message_id) / static_cast(static_cast(1) << 32); + auto id_time = static_cast(message_id.get()) / static_cast(static_cast(1) << 32); return server_time - 300 < id_time && id_time < server_time + 30; } -Status AuthData::check_packet(uint64 session_id, uint64 message_id, double now, bool &time_difference_was_updated) { +Status AuthData::check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated) { // Client is to check that the session_id field in the decrypted message indeed equals to that of an active session // created by the client. if (get_session_id() != session_id) { @@ -147,22 +146,21 @@ Status AuthData::check_packet(uint64 session_id, uint64 message_id, double now, // Client must check that msg_id has even parity for messages from client to server, and odd parity for messages // from server to client. - if ((message_id & 1) == 0) { - return Status::Error(PSLICE() << "Receive invalid message identifier " << format::as_hex(message_id)); + if ((message_id.get() & 1) == 0) { + return Status::Error(PSLICE() << "Receive invalid " << message_id); } TRY_STATUS(duplicate_checker_.check(message_id)); - LOG(DEBUG) << "Receive packet " << format::as_hex(message_id) << " from session " << format::as_hex(session_id) - << " at " << now; - time_difference_was_updated = update_server_time_difference(static_cast(message_id >> 32) - now); + LOG(DEBUG) << "Receive packet in " << message_id << " from session " << session_id << " at " << now; + time_difference_was_updated = update_server_time_difference(static_cast(message_id.get() >> 32) - now); // In addition, msg_id values that belong over 30 seconds in the future or over 300 seconds in the past are to be // ignored (recall that msg_id approximately equals unixtime * 2^32). This is especially important for the server. // The client would also find this useful (to protect from a replay attack), but only if it is certain of its time // (for example, if its time has been synchronized with that of the server). if (server_time_difference_was_updated_ && !is_valid_inbound_msg_id(message_id, now)) { - return Status::Error(PSLICE() << "Ignore too old or too new message " << format::as_hex(message_id)); + return Status::Error(PSLICE() << "Ignore too old or too new " << message_id); } return Status::OK(); diff --git a/td/mtproto/AuthData.h b/td/mtproto/AuthData.h index bc2403508..f242607be 100644 --- a/td/mtproto/AuthData.h +++ b/td/mtproto/AuthData.h @@ -7,6 +7,7 @@ #pragma once #include "td/mtproto/AuthKey.h" +#include "td/mtproto/MessageId.h" #include "td/utils/common.h" #include "td/utils/Slice.h" @@ -37,17 +38,18 @@ void parse(ServerSalt &salt, ParserT &parser) { salt.valid_until = parser.fetch_double(); } -Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id); +Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos, + MessageId message_id); template class MessageIdDuplicateChecker { public: - Status check(uint64 message_id) { + Status check(MessageId message_id) { return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id); } private: - std::array saved_message_ids_; + std::array saved_message_ids_; size_t end_pos_ = 0; }; @@ -232,19 +234,19 @@ class AuthData { std::vector get_future_salts() const; - uint64 next_message_id(double now); + MessageId next_message_id(double now); - bool is_valid_outbound_msg_id(uint64 message_id, double now) const; + bool is_valid_outbound_msg_id(MessageId message_id, double now) const; - bool is_valid_inbound_msg_id(uint64 message_id, double now) const; + bool is_valid_inbound_msg_id(MessageId message_id, double now) const; - Status check_packet(uint64 session_id, uint64 message_id, double now, bool &time_difference_was_updated); + Status check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated); - Status check_update(uint64 message_id) { + Status check_update(MessageId message_id) { return updates_duplicate_checker_.check(message_id); } - Status recheck_update(uint64 message_id) { + Status recheck_update(MessageId message_id) { return updates_duplicate_rechecker_.check(message_id); } @@ -275,7 +277,7 @@ class AuthData { bool server_time_difference_was_updated_ = false; double server_time_difference_ = 0; ServerSalt server_salt_; - uint64 last_message_id_ = 0; + MessageId last_message_id_; int32 seq_no_ = 0; string header_; uint64 session_id_ = 0; diff --git a/td/mtproto/CryptoStorer.h b/td/mtproto/CryptoStorer.h index eaadd523f..ce9c5efd0 100644 --- a/td/mtproto/CryptoStorer.h +++ b/td/mtproto/CryptoStorer.h @@ -7,6 +7,7 @@ #pragma once #include "td/mtproto/AuthData.h" +#include "td/mtproto/MessageId.h" #include "td/mtproto/MtprotoQuery.h" #include "td/mtproto/PacketStorer.h" #include "td/mtproto/utils.h" @@ -57,7 +58,7 @@ class ObjectImpl { bool empty() const { return !not_empty_; } - uint64 get_message_id() const { + MessageId get_message_id() const { return message_id_; } @@ -65,7 +66,7 @@ class ObjectImpl { bool not_empty_; Object object_; ObjectStorer object_storer_; - uint64 message_id_; + MessageId message_id_; int32 seq_no_; }; @@ -96,7 +97,7 @@ class CancelVectorImpl { bool not_empty() const { return !storers_.empty(); } - uint64 get_message_id() const { + MessageId get_message_id() const { CHECK(storers_.size() == 1); return storers_[0].get_message_id(); } @@ -107,7 +108,7 @@ class CancelVectorImpl { class InvokeAfter { public: - explicit InvokeAfter(Span message_ids) : message_ids_(message_ids) { + explicit InvokeAfter(Span message_ids) : message_ids_(message_ids) { } template void store(StorerT &storer) const { @@ -116,7 +117,7 @@ class InvokeAfter { } if (message_ids_.size() == 1) { storer.store_int(static_cast(0xcb9f372d)); - storer.store_binary(message_ids_[0]); + storer.store_binary(message_ids_[0].get()); return; } // invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector query:!X = X; @@ -124,12 +125,12 @@ class InvokeAfter { storer.store_int(static_cast(0x1cb5c415)); storer.store_int(narrow_cast(message_ids_.size())); for (auto message_id : message_ids_) { - storer.store_binary(message_id); + storer.store_binary(message_id.get()); } } private: - Span message_ids_; + Span message_ids_; }; class QueryImpl { @@ -206,8 +207,8 @@ class CryptoImpl { CryptoImpl(const vector &to_send, Slice header, vector &&to_ack, int64 ping_id, int ping_timeout, int max_delay, int max_after, int max_wait, int future_salt_n, vector get_info, vector resend, const vector &cancel, bool destroy_key, AuthData *auth_data, - uint64 *container_message_id, uint64 *get_info_message_id, uint64 *resend_message_id, - uint64 *ping_message_id, uint64 *parent_message_id) + MessageId *container_message_id, MessageId *get_info_message_id, MessageId *resend_message_id, + MessageId *ping_message_id, MessageId *parent_message_id) : query_storer_(to_send, header) , ack_empty_(to_ack.empty()) , ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data) @@ -362,7 +363,7 @@ class CryptoImpl { Mixed }; Type type_; - uint64 message_id_; + MessageId message_id_; int32 seq_no_; }; diff --git a/td/mtproto/HandshakeConnection.h b/td/mtproto/HandshakeConnection.h index 8d7883ff7..5fda32d4a 100644 --- a/td/mtproto/HandshakeConnection.h +++ b/td/mtproto/HandshakeConnection.h @@ -8,6 +8,7 @@ #include "td/mtproto/AuthKey.h" #include "td/mtproto/Handshake.h" +#include "td/mtproto/MessageId.h" #include "td/mtproto/NoCryptoStorer.h" #include "td/mtproto/PacketInfo.h" #include "td/mtproto/PacketStorer.h" @@ -61,7 +62,7 @@ class HandshakeConnection final unique_ptr context_; void send_no_crypto(const Storer &storer) final { - raw_connection_->send_no_crypto(PacketStorer(0, storer)); + raw_connection_->send_no_crypto(PacketStorer(MessageId(), storer)); } Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) final { diff --git a/td/mtproto/MessageId.h b/td/mtproto/MessageId.h new file mode 100644 index 000000000..4b7615f69 --- /dev/null +++ b/td/mtproto/MessageId.h @@ -0,0 +1,70 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/utils/common.h" +#include "td/utils/format.h" +#include "td/utils/HashTableUtils.h" +#include "td/utils/StringBuilder.h" + +#include + +namespace td { +namespace mtproto { + +class MessageId { + uint64 message_id_ = 0; + + public: + MessageId() = default; + + explicit constexpr MessageId(uint64 message_id) : message_id_(message_id) { + } + template ::value>> + MessageId(T message_id) = delete; + + uint64 get() const { + return message_id_; + } + + bool operator==(const MessageId &other) const { + return message_id_ == other.message_id_; + } + + bool operator!=(const MessageId &other) const { + return message_id_ != other.message_id_; + } + + friend bool operator<(const MessageId &lhs, const MessageId &rhs) { + return lhs.get() < rhs.get(); + } + + friend bool operator>(const MessageId &lhs, const MessageId &rhs) { + return lhs.get() > rhs.get(); + } + + friend bool operator<=(const MessageId &lhs, const MessageId &rhs) { + return lhs.get() <= rhs.get(); + } + + friend bool operator>=(const MessageId &lhs, const MessageId &rhs) { + return lhs.get() >= rhs.get(); + } +}; + +struct MessageIdHash { + uint32 operator()(MessageId message_id) const { + return Hash()(message_id.get()); + } +}; + +inline StringBuilder &operator<<(StringBuilder &string_builder, MessageId message_id) { + return string_builder << "message " << format::as_hex(message_id.get()); +} + +} // namespace mtproto +} // namespace td diff --git a/td/mtproto/MtprotoQuery.h b/td/mtproto/MtprotoQuery.h index d0fad78f9..40d86a79b 100644 --- a/td/mtproto/MtprotoQuery.h +++ b/td/mtproto/MtprotoQuery.h @@ -6,6 +6,8 @@ // #pragma once +#include "td/mtproto/MessageId.h" + #include "td/utils/buffer.h" #include "td/utils/common.h" @@ -13,11 +15,11 @@ namespace td { namespace mtproto { struct MtprotoQuery { - uint64 message_id; + MessageId message_id; int32 seq_no; BufferSlice packet; bool gzip_flag; - vector invoke_after_message_ids; + vector invoke_after_message_ids; bool use_quick_ack; }; diff --git a/td/mtproto/NoCryptoStorer.h b/td/mtproto/NoCryptoStorer.h index 78da7b6ef..1edf8c653 100644 --- a/td/mtproto/NoCryptoStorer.h +++ b/td/mtproto/NoCryptoStorer.h @@ -6,6 +6,8 @@ // #pragma once +#include "td/mtproto/MessageId.h" + #include "td/utils/Random.h" #include "td/utils/StorerBase.h" @@ -14,7 +16,7 @@ namespace mtproto { class NoCryptoImpl { public: - NoCryptoImpl(uint64 message_id, const Storer &data, bool need_pad = true) : message_id_(message_id), data_(data) { + NoCryptoImpl(MessageId message_id, const Storer &data, bool need_pad = true) : message_id_(message_id), data_(data) { if (need_pad) { size_t pad_size = -static_cast(data_.size()) & 15; pad_size += 16 * (static_cast(Random::secure_int32()) % 16); @@ -25,14 +27,14 @@ class NoCryptoImpl { template void do_store(StorerT &storer) const { - storer.store_binary(message_id_); + storer.store_binary(message_id_.get()); storer.store_binary(static_cast(data_.size() + pad_.size())); storer.store_storer(data_); storer.store_slice(pad_); } private: - uint64 message_id_; + MessageId message_id_; const Storer &data_; std::string pad_; }; diff --git a/td/mtproto/PacketInfo.h b/td/mtproto/PacketInfo.h index 68bb41420..9a33662b1 100644 --- a/td/mtproto/PacketInfo.h +++ b/td/mtproto/PacketInfo.h @@ -6,6 +6,8 @@ // #pragma once +#include "td/mtproto/MessageId.h" + #include "td/utils/common.h" namespace td { @@ -18,7 +20,7 @@ struct PacketInfo { uint64 salt{0}; uint64 session_id{0}; - uint64 message_id{0}; + MessageId message_id; int32 seq_no{0}; int32 version{1}; bool no_crypto_flag{false}; diff --git a/td/mtproto/PingConnection.cpp b/td/mtproto/PingConnection.cpp index 45ff1bf85..28a1878bf 100644 --- a/td/mtproto/PingConnection.cpp +++ b/td/mtproto/PingConnection.cpp @@ -8,6 +8,7 @@ #include "td/mtproto/AuthData.h" #include "td/mtproto/AuthKey.h" +#include "td/mtproto/MessageId.h" #include "td/mtproto/mtproto_api.h" #include "td/mtproto/NoCryptoStorer.h" #include "td/mtproto/PacketInfo.h" @@ -47,7 +48,8 @@ class PingConnectionReqPQ final if (!was_ping_) { UInt128 nonce; Random::secure_bytes(nonce.raw, sizeof(nonce)); - raw_connection_->send_no_crypto(PacketStorer(1, create_storer(mtproto_api::req_pq_multi(nonce)))); + raw_connection_->send_no_crypto(PacketStorer(MessageId(static_cast(1)), + create_storer(mtproto_api::req_pq_multi(nonce)))); was_ping_ = true; if (ping_count_ == 1) { start_time_ = Time::now(); @@ -129,13 +131,13 @@ class PingConnectionPingPong final void on_server_time_difference_updated(bool force) final { } - void on_new_session_created(uint64 unique_id, uint64 first_message_id) final { + void on_new_session_created(uint64 unique_id, MessageId first_message_id) final { } void on_session_failed(Status status) final { } - void on_container_sent(uint64 container_message_id, vector message_ids) final { + void on_container_sent(MessageId container_message_id, vector message_ids) final { } Status on_pong() final { @@ -153,21 +155,22 @@ class PingConnectionPingPong final return Status::OK(); } - void on_message_ack(uint64 message_id) final { + void on_message_ack(MessageId message_id) final { } - Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) final { + Status on_message_result_ok(MessageId message_id, BufferSlice packet, size_t original_size) final { LOG(ERROR) << "Unexpected message"; return Status::OK(); } - void on_message_result_error(uint64 message_id, int code, string message) final { + void on_message_result_error(MessageId message_id, int code, string message) final { } - void on_message_failed(uint64 message_id, Status status) final { + void on_message_failed(MessageId message_id, Status status) final { } - void on_message_info(uint64 message_id, int32 state, uint64 answer_id, int32 answer_size, int32 source) final { + void on_message_info(MessageId message_id, int32 state, MessageId answer_message_id, int32 answer_size, + int32 source) final { } Status on_destroy_auth_key() final { diff --git a/td/mtproto/RawConnection.cpp b/td/mtproto/RawConnection.cpp index 14e62295e..08cacc742 100644 --- a/td/mtproto/RawConnection.cpp +++ b/td/mtproto/RawConnection.cpp @@ -86,7 +86,7 @@ class RawConnectionDefault final : public RawConnection { return packet_size; } - uint64 send_no_crypto(const Storer &storer) final { + MessageId send_no_crypto(const Storer &storer) final { PacketInfo packet_info; packet_info.no_crypto_flag = true; auto packet = Transport::write(storer, AuthKey(), &packet_info, transport_->max_prepend_size(), @@ -315,7 +315,7 @@ class RawConnectionHttp final : public RawConnection { return packet_size; } - uint64 send_no_crypto(const Storer &storer) final { + MessageId send_no_crypto(const Storer &storer) final { PacketInfo packet_info; packet_info.no_crypto_flag = true; auto packet = Transport::write(storer, AuthKey(), &packet_info); diff --git a/td/mtproto/RawConnection.h b/td/mtproto/RawConnection.h index 5ac928ee2..25db09261 100644 --- a/td/mtproto/RawConnection.h +++ b/td/mtproto/RawConnection.h @@ -7,6 +7,7 @@ #pragma once #include "td/mtproto/ConnectionManager.h" +#include "td/mtproto/MessageId.h" #include "td/mtproto/PacketInfo.h" #include "td/mtproto/TransportType.h" @@ -50,7 +51,7 @@ class RawConnection { virtual TransportType get_transport_type() const = 0; virtual size_t send_crypto(const Storer &storer, uint64 session_id, int64 salt, const AuthKey &auth_key, uint64 quick_ack_token) = 0; - virtual uint64 send_no_crypto(const Storer &storer) = 0; + virtual MessageId send_no_crypto(const Storer &storer) = 0; virtual PollableFdInfo &get_poll_info() = 0; virtual StatsCallback *stats_callback() = 0; @@ -63,7 +64,7 @@ class RawConnection { virtual ~Callback() = default; virtual Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) = 0; virtual Status on_quick_ack(uint64 quick_ack_token) { - return Status::Error("Quick acknowledgements are unsupported by the callback"); + return Status::Error("Quick acknowledgements aren't supported by the callback"); } virtual Status before_write() { return Status::OK(); diff --git a/td/mtproto/SessionConnection.cpp b/td/mtproto/SessionConnection.cpp index efb9dda80..15b80e7b2 100644 --- a/td/mtproto/SessionConnection.cpp +++ b/td/mtproto/SessionConnection.cpp @@ -172,7 +172,7 @@ namespace mtproto { */ inline StringBuilder &operator<<(StringBuilder &string_builder, const SessionConnection::MsgInfo &info) { - return string_builder << "[msg_id:" << format::as_hex(info.message_id) << "][seq_no:" << info.seq_no << ']'; + return string_builder << "with " << info.message_id << " and seq_no " << info.seq_no; } unique_ptr SessionConnection::move_as_raw_connection() { @@ -190,7 +190,7 @@ Status SessionConnection::parse_message(TlParser &parser, MsgInfo *info, Slice * if (parser.get_error() != nullptr) { return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: " << parser.get_error()); } - info->message_id = parser.fetch_long_unsafe(); + info->message_id = MessageId(static_cast(parser.fetch_long_unsafe())); if (crypto_flag) { info->seq_no = parser.fetch_int_unsafe(); } @@ -223,22 +223,22 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet) if (parser.get_error()) { return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error()); } - VLOG(mtproto) << "Receive container " << format::as_hex(container_message_id_) << " of size " << size; + VLOG(mtproto) << "Receive container " << container_message_id_ << " of size " << size; for (int i = 0; i < size; i++) { TRY_STATUS(parse_packet(parser)); } return Status::OK(); } -void SessionConnection::reset_server_time_difference(uint64 message_id) { +void SessionConnection::reset_server_time_difference(MessageId message_id) { VLOG(mtproto) << "Reset server time difference"; - auth_data_->reset_server_time_difference(static_cast(message_id >> 32) - Time::now()); + auth_data_->reset_server_time_difference(static_cast(message_id.get() >> 32) - Time::now()); callback_->on_server_time_difference_updated(true); } Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) { TlParser parser(packet); - uint64 req_msg_id = parser.fetch_long(); + uint64 req_msg_id = static_cast(parser.fetch_long()); if (parser.get_error()) { return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error()); } @@ -246,9 +246,9 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet LOG(ERROR) << "Receive an update in rpc_result " << info; return Status::Error("Receive an update in rpc_result"); } - VLOG(mtproto) << "Receive result for request " << format::as_hex(req_msg_id) << " with " << info; + VLOG(mtproto) << "Receive result for request with " << MessageId(req_msg_id) << ' ' << info; - if (info.message_id < req_msg_id - (static_cast(15) << 32)) { + if (info.message_id.get() < req_msg_id - (static_cast(15) << 32)) { reset_server_time_difference(info.message_id); } @@ -258,7 +258,7 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet if (parser.get_error()) { return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_error: " << parser.get_error()); } - callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, rpc_error.error_message_.str()); + callback_->on_message_result_error(MessageId(req_msg_id), rpc_error.error_code_, rpc_error.error_message_.str()); return Status::OK(); } case mtproto_api::gzip_packed::ID: { @@ -269,11 +269,11 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet // yep, gzip in rpc_result BufferSlice object = gzdecode(gzip.packed_data_); // send header no more optimization - return callback_->on_message_result_ok(req_msg_id, std::move(object), info.size); + return callback_->on_message_result_ok(MessageId(req_msg_id), std::move(object), info.size); } default: packet.remove_prefix(sizeof(req_msg_id)); - return callback_->on_message_result_ok(req_msg_id, as_buffer_slice(packet), info.size); + return callback_->on_message_result_ok(MessageId(req_msg_id), as_buffer_slice(packet), info.size); } } @@ -284,17 +284,17 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) { } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) { - VLOG(mtproto) << "Receive destroy_auth_key_ok with " << info; + VLOG(mtproto) << "Receive destroy_auth_key_ok " << info; return on_destroy_auth_key(destroy_auth_key); } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) { - VLOG(mtproto) << "Receive destroy_auth_key_none with " << info; + VLOG(mtproto) << "Receive destroy_auth_key_none " << info; return on_destroy_auth_key(destroy_auth_key); } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) { - VLOG(mtproto) << "Receive destroy_auth_key_fail with " << info; + VLOG(mtproto) << "Receive destroy_auth_key_fail " << info; return on_destroy_auth_key(destroy_auth_key); } @@ -304,14 +304,14 @@ Status SessionConnection::on_destroy_auth_key(const mtproto_api::DestroyAuthKeyR } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_session_created &new_session_created) { - auto first_message_id = static_cast(new_session_created.first_msg_id_); - VLOG(mtproto) << "Receive new_session_created with " << info << ": [first_msg_id:" << format::as_hex(first_message_id) - << "] [unique_id:" << format::as_hex(new_session_created.unique_id_) << ']'; + auto first_message_id = MessageId(static_cast(new_session_created.first_msg_id_)); + VLOG(mtproto) << "Receive new_session_created " << info << ": [first " << first_message_id + << "] [unique_id:" << new_session_created.unique_id_ << ']'; auto it = service_queries_.find(first_message_id); if (it != service_queries_.end()) { first_message_id = it->second.container_message_id_; - LOG(INFO) << "Update first_message_id to container's " << format::as_hex(first_message_id); + LOG(INFO) << "Update first_message_id to container's " << first_message_id; } callback_->on_new_session_created(new_session_created.unique_id_, first_message_id); @@ -320,7 +320,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_msg_notification &bad_msg_notification) { - MsgInfo bad_info{static_cast(bad_msg_notification.bad_msg_id_), bad_msg_notification.bad_msg_seqno_, 0}; + MsgInfo bad_info{MessageId(static_cast(bad_msg_notification.bad_msg_id_)), + bad_msg_notification.bad_msg_seqno_, 0}; enum Code { MsgIdTooLow = 16, MsgIdTooHigh = 17, @@ -383,8 +384,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_server_salt &bad_server_salt) { - MsgInfo bad_info{static_cast(bad_server_salt.bad_msg_id_), bad_server_salt.bad_msg_seqno_, 0}; - VLOG(mtproto) << "Receive bad_server_salt with " << info << ": " << bad_info; + MsgInfo bad_info{MessageId(static_cast(bad_server_salt.bad_msg_id_)), bad_server_salt.bad_msg_seqno_, 0}; + VLOG(mtproto) << "Receive bad_server_salt " << info << ": " << bad_info; auth_data_->set_server_salt(bad_server_salt.new_server_salt_, Time::now_cached()); callback_->on_server_salt_updated(); @@ -393,8 +394,9 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_ } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_ack &msgs_ack) { - VLOG(mtproto) << "Receive msgs_ack with " << info << ": " << msgs_ack.msg_ids_; - for (auto message_id : msgs_ack.msg_ids_) { + auto message_ids = transform(msgs_ack.msg_ids_, [](int64 msg_id) { return MessageId(static_cast(msg_id)); }); + VLOG(mtproto) << "Receive msgs_ack " << info << ": " << message_ids; + for (auto message_id : message_ids) { callback_->on_message_ack(message_id); } return Status::OK(); @@ -407,8 +409,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::gzip } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::pong &pong) { - VLOG(mtproto) << "Receive pong with " << info; - if (info.message_id < static_cast(pong.msg_id_) - (static_cast(15) << 32)) { + VLOG(mtproto) << "Receive pong " << info; + if (info.message_id.get() < static_cast(pong.msg_id_) - (static_cast(15) << 32)) { reset_server_time_difference(info.message_id); } last_pong_at_ = Time::now_cached(); @@ -424,7 +426,7 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::futu } auto now = Time::now_cached(); auth_data_->set_future_salts(new_salts, now); - VLOG(mtproto) << "Receive future_salts with " << info << ": is_valid = " << auth_data_->is_server_salt_valid(now) + VLOG(mtproto) << "Receive future_salts " << info << ": is_valid = " << auth_data_->is_server_salt_valid(now) << ", has_salt = " << auth_data_->has_salt(now) << ", need_future_salts = " << auth_data_->need_future_salts(now); callback_->on_server_salt_updated(); @@ -438,14 +440,14 @@ Status SessionConnection::on_msgs_state_info(const vector &msg_ids, Slice } size_t i = 0; for (auto msg_id : msg_ids) { - callback_->on_message_info(static_cast(msg_id), info[i], 0, 0, 1); + callback_->on_message_info(MessageId(static_cast(msg_id)), info[i], MessageId(), 0, 1); i++; } return Status::OK(); } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_state_info &msgs_state_info) { - auto message_id = static_cast(msgs_state_info.req_msg_id_); + auto message_id = MessageId(static_cast(msgs_state_info.req_msg_id_)); auto it = service_queries_.find(message_id); if (it == service_queries_.end()) { return Status::Error("Unknown msgs_state_info"); @@ -456,26 +458,28 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs if (query.type_ != ServiceQuery::GetStateInfo) { return Status::Error("Receive msgs_state_info in response not to GetStateInfo"); } - VLOG(mtproto) << "Receive msgs_state_info with " << info; + VLOG(mtproto) << "Receive msgs_state_info " << info; return on_msgs_state_info(query.msg_ids_, msgs_state_info.info_); } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_all_info &msgs_all_info) { - VLOG(mtproto) << "Receive msgs_all_info with " << info; + VLOG(mtproto) << "Receive msgs_all_info " << info; return on_msgs_state_info(msgs_all_info.msg_ids_, msgs_all_info.info_); } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) { - VLOG(mtproto) << "Receive msg_detailed_info with " << info; - callback_->on_message_info(msg_detailed_info.msg_id_, msg_detailed_info.status_, msg_detailed_info.answer_msg_id_, - msg_detailed_info.bytes_, 2); + VLOG(mtproto) << "Receive msg_detailed_info " << info; + callback_->on_message_info(MessageId(static_cast(msg_detailed_info.msg_id_)), msg_detailed_info.status_, + MessageId(static_cast(msg_detailed_info.answer_msg_id_)), msg_detailed_info.bytes_, + 2); return Status::OK(); } Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) { - VLOG(mtproto) << "Receive msg_new_detailed_info with " << info; - callback_->on_message_info(0, 0, msg_new_detailed_info.answer_msg_id_, msg_new_detailed_info.bytes_, 0); + VLOG(mtproto) << "Receive msg_new_detailed_info " << info; + callback_->on_message_info(MessageId(), 0, MessageId(static_cast(msg_new_detailed_info.answer_msg_id_)), + msg_new_detailed_info.bytes_, 0); return Status::OK(); } @@ -517,9 +521,8 @@ Status SessionConnection::on_slice_packet(const MsgInfo &info, Slice packet) { auto get_update_description = [&] { return PSTRING() << "update from " << get_name() << " with auth key " << auth_data_->get_auth_key().id() << " active for " << (Time::now() - created_at_) << " seconds in container " - << container_message_id_ << " from session " << auth_data_->get_session_id() << " with " << info - << ", main_message_id = " << format::as_hex(main_message_id_) - << " and original size = " << info.size; + << container_message_id_ << " from session " << auth_data_->get_session_id() << ' ' << info + << ", main " << main_message_id_ << " and original size = " << info.size; }; // It is an update... I hope. @@ -560,8 +563,8 @@ Status SessionConnection::on_main_packet(const PacketInfo &packet_info, Slice pa } VLOG(raw_mtproto) << "Receive packet of size " << packet.size() << ':' << format::as_hex_dump<4>(packet); - VLOG(mtproto) << "Receive packet with seq_no " << packet_info.seq_no << " and msg_id " - << format::as_hex(packet_info.message_id) << " of size " << packet.size(); + VLOG(mtproto) << "Receive packet with " << packet_info.message_id << " and seq_no " << packet_info.seq_no + << " of size " << packet.size(); if (packet_info.no_crypto_flag) { return Status::Error("Unencrypted packet"); @@ -576,7 +579,7 @@ Status SessionConnection::on_main_packet(const PacketInfo &packet_info, Slice pa return Status::OK(); } -void SessionConnection::on_message_failed(uint64 message_id, Status status) { +void SessionConnection::on_message_failed(MessageId message_id, Status status) { callback_->on_message_failed(message_id, std::move(status)); sent_destroy_auth_key_ = false; @@ -584,8 +587,8 @@ void SessionConnection::on_message_failed(uint64 message_id, Status status) { if (message_id == last_ping_message_id_ || message_id == last_ping_container_message_id_) { // restart ping immediately last_ping_at_ = 0; - last_ping_message_id_ = 0; - last_ping_container_message_id_ = 0; + last_ping_message_id_ = {}; + last_ping_container_message_id_ = {}; } auto cit = container_to_service_message_id_.find(message_id); @@ -599,7 +602,7 @@ void SessionConnection::on_message_failed(uint64 message_id, Status status) { } } -void SessionConnection::on_message_failed_inner(uint64 message_id) { +void SessionConnection::on_message_failed_inner(MessageId message_id) { auto it = service_queries_.find(message_id); if (it == service_queries_.end()) { return; @@ -610,12 +613,12 @@ void SessionConnection::on_message_failed_inner(uint64 message_id) { switch (query.type_) { case ServiceQuery::ResendAnswer: for (auto msg_id : query.msg_ids_) { - resend_answer(static_cast(msg_id)); + resend_answer(MessageId(static_cast(msg_id))); } break; case ServiceQuery::GetStateInfo: for (auto msg_id : query.msg_ids_) { - get_state_info(static_cast(msg_id)); + get_state_info(MessageId(static_cast(msg_id))); } break; default: @@ -726,7 +729,7 @@ Status SessionConnection::on_raw_packet(const PacketInfo &packet_info, BufferSli } Status SessionConnection::on_quick_ack(uint64 quick_ack_token) { - callback_->on_message_ack(quick_ack_token); + callback_->on_message_ack(MessageId(quick_ack_token)); return Status::OK(); } @@ -773,8 +776,8 @@ void SessionConnection::set_online(bool online_flag, bool is_main) { last_read_at_ = now; } last_ping_at_ = 0; - last_ping_message_id_ = 0; - last_ping_container_message_id_ = 0; + last_ping_message_id_ = {}; + last_ping_container_message_id_ = {}; } void SessionConnection::do_close(Status status) { @@ -790,10 +793,10 @@ void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token auth_data_->get_auth_key(), quick_ack_token); } -Result SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id, - vector invoke_after_message_ids, bool use_quick_ack) { +Result SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, MessageId message_id, + vector invoke_after_message_ids, bool use_quick_ack) { CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait" - if (message_id == 0) { + if (message_id == MessageId()) { message_id = auth_data_->next_message_id(Time::now_cached()); } auto seq_no = auth_data_->next_seq_no(true); @@ -802,28 +805,28 @@ Result SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, } to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, std::move(invoke_after_message_ids), use_quick_ack}); - VLOG(mtproto) << "Invoke query with msg_id " << format::as_hex(message_id) << " and seq_no " << seq_no << " of size " + VLOG(mtproto) << "Invoke query with " << message_id << " and seq_no " << seq_no << " of size " << to_send_.back().packet.size() << " after " << invoke_after_message_ids << (use_quick_ack ? " with quick ack" : ""); return message_id; } -void SessionConnection::get_state_info(uint64 message_id) { +void SessionConnection::get_state_info(MessageId message_id) { if (to_get_state_info_message_ids_.empty()) { send_before(Time::now_cached()); } to_get_state_info_message_ids_.push_back(message_id); } -void SessionConnection::resend_answer(uint64 message_id) { +void SessionConnection::resend_answer(MessageId message_id) { if (to_resend_answer_message_ids_.empty()) { send_before(Time::now_cached() + RESEND_ANSWER_DELAY); } to_resend_answer_message_ids_.push_back(message_id); } -void SessionConnection::cancel_answer(uint64 message_id) { +void SessionConnection::cancel_answer(MessageId message_id) { if (to_cancel_answer_message_ids_.empty()) { send_before(Time::now_cached() + RESEND_ANSWER_DELAY); } @@ -835,7 +838,7 @@ void SessionConnection::destroy_key() { need_destroy_auth_key_ = true; } -std::pair SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at) { +std::pair SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at) { int64 temp_key = auth_data_->get_tmp_auth_key().id(); mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->get_session_id(), expires_at); @@ -865,8 +868,8 @@ void SessionConnection::force_ack() { } } -void SessionConnection::send_ack(uint64 message_id) { - VLOG(mtproto) << "Send ack: [msg_id:" << format::as_hex(message_id) << "]"; +void SessionConnection::send_ack(MessageId message_id) { + VLOG(mtproto) << "Send ack for " << message_id; if (to_ack_message_ids_.empty()) { send_before(Time::now_cached() + ACK_DELAY); } @@ -881,7 +884,7 @@ void SessionConnection::send_ack(uint64 message_id) { } } -// don't send ping in poll mode. +// don't send ping in poll mode bool SessionConnection::may_ping() const { return last_ping_at_ == 0 || (mode_ != Mode::HttpLongPoll && last_ping_at_ + ping_may_delay() < Time::now_cached()); } @@ -893,7 +896,7 @@ bool SessionConnection::must_ping() const { void SessionConnection::flush_packet() { bool has_salt = auth_data_->has_salt(Time::now_cached()); // ping - uint64 container_message_id = 0; + MessageId container_message_id; int64 ping_id = 0; if (has_salt && may_ping()) { ping_id = ++cur_ping_id_; @@ -963,53 +966,54 @@ void SessionConnection::flush_packet() { << tag("cancel", to_cancel_answer_message_ids_.size()) << tag("destroy_key", destroy_auth_key) << tag("auth_key_id", auth_data_->get_auth_key().id()); - auto cut_tail = [](vector &v, size_t size, Slice name) { - if (size >= v.size()) { - auto result = transform(v, [](uint64 x) { return static_cast(x); }); - v.clear(); + auto cut_tail = [](vector &message_ids, size_t size, Slice name) { + if (size >= message_ids.size()) { + auto result = transform(message_ids, [](MessageId message_id) { return static_cast(message_id.get()); }); + message_ids.clear(); return result; } - LOG(WARNING) << "Too many message identifiers in container " << name << ": " << v.size() << " instead of " << size; - auto new_size = v.size() - size; + LOG(WARNING) << "Too many message identifiers in container " << name << ": " << message_ids.size() << " instead of " + << size; + auto new_size = message_ids.size() - size; vector result(size); for (size_t i = 0; i < size; i++) { - result[i] = static_cast(v[i + new_size]); + result[i] = static_cast(message_ids[i + new_size].get()); } - v.resize(new_size); + message_ids.resize(new_size); return result; }; // no more than 8192 message identifiers per container.. auto to_resend_answer = cut_tail(to_resend_answer_message_ids_, 8192, "resend_answer"); - uint64 resend_answer_message_id = 0; + MessageId resend_answer_message_id; CHECK(queries.size() <= 1020); auto to_cancel_answer = cut_tail(to_cancel_answer_message_ids_, 1020 - queries.size(), "cancel_answer"); auto to_get_state_info = cut_tail(to_get_state_info_message_ids_, 8192, "get_state_info"); - uint64 get_state_info_message_id = 0; + MessageId get_state_info_message_id; auto to_ack = cut_tail(to_ack_message_ids_, 8192, "ack"); - uint64 ping_message_id = 0; + MessageId ping_message_id; bool use_quick_ack = std::any_of(queries.begin(), queries.end(), [](const auto &query) { return query.use_quick_ack; }); { // LOG(ERROR) << (auth_data_->get_header().empty() ? '-' : '+'); - uint64 parent_message_id = 0; + MessageId parent_message_id; auto storer = PacketStorer( queries, auth_data_->get_header(), std::move(to_ack), ping_id, static_cast(ping_disconnect_delay() + 2.0), max_delay, max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, destroy_auth_key, auth_data_, &container_message_id, &get_state_info_message_id, &resend_answer_message_id, &ping_message_id, &parent_message_id); - auto quick_ack_token = use_quick_ack ? parent_message_id : 0; + auto quick_ack_token = use_quick_ack ? parent_message_id.get() : 0; send_crypto(storer, quick_ack_token); } - if (resend_answer_message_id) { + if (resend_answer_message_id != MessageId()) { service_queries_.emplace(resend_answer_message_id, ServiceQuery{ServiceQuery::ResendAnswer, container_message_id, std::move(to_resend_answer)}); } - if (get_state_info_message_id) { + if (get_state_info_message_id != MessageId()) { service_queries_.emplace(get_state_info_message_id, ServiceQuery{ServiceQuery::GetStateInfo, container_message_id, std::move(to_get_state_info)}); } @@ -1018,8 +1022,8 @@ void SessionConnection::flush_packet() { last_ping_message_id_ = ping_message_id; } - if (container_message_id != 0) { - auto message_ids = transform(queries, [](const MtprotoQuery &x) { return static_cast(x.message_id); }); + if (container_message_id != MessageId()) { + auto message_ids = transform(queries, [](const MtprotoQuery &x) { return x.message_id; }); // some acks may be lost here. Nobody will resend them if something goes wrong with query. // It is mostly problem for server. We will just drop this answers in next connection @@ -1028,10 +1032,10 @@ void SessionConnection::flush_packet() { // So I will re-ask salt if have no answer in 60 second. callback_->on_container_sent(container_message_id, std::move(message_ids)); - if (resend_answer_message_id) { + if (resend_answer_message_id != MessageId()) { container_to_service_message_id_[container_message_id].push_back(resend_answer_message_id); } - if (get_state_info_message_id) { + if (get_state_info_message_id != MessageId()) { container_to_service_message_id_[container_message_id].push_back(get_state_info_message_id); } } diff --git a/td/mtproto/SessionConnection.h b/td/mtproto/SessionConnection.h index 0e3575929..669e38cc7 100644 --- a/td/mtproto/SessionConnection.h +++ b/td/mtproto/SessionConnection.h @@ -6,6 +6,7 @@ // #pragma once +#include "td/mtproto/MessageId.h" #include "td/mtproto/MtprotoQuery.h" #include "td/mtproto/PacketInfo.h" #include "td/mtproto/RawConnection.h" @@ -67,14 +68,14 @@ class SessionConnection final unique_ptr move_as_raw_connection(); // Interface - Result TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id = 0, - vector invoke_after_message_ids = {}, - bool use_quick_ack = false); - std::pair encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at); + Result TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, MessageId message_id = {}, + vector invoke_after_message_ids = {}, + bool use_quick_ack = false); + std::pair encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at); - void get_state_info(uint64 message_id); - void resend_answer(uint64 message_id); - void cancel_answer(uint64 message_id); + void get_state_info(MessageId message_id); + void resend_answer(MessageId message_id); + void cancel_answer(MessageId message_id); void destroy_key(); void set_online(bool online_flag, bool is_main); @@ -95,19 +96,20 @@ class SessionConnection final virtual void on_server_salt_updated() = 0; virtual void on_server_time_difference_updated(bool force) = 0; - virtual void on_new_session_created(uint64 unique_id, uint64 first_message_id) = 0; + virtual void on_new_session_created(uint64 unique_id, MessageId first_message_id) = 0; virtual void on_session_failed(Status status) = 0; - virtual void on_container_sent(uint64 container_message_id, vector message_ids) = 0; + virtual void on_container_sent(MessageId container_message_id, vector message_ids) = 0; virtual Status on_pong() = 0; virtual Status on_update(BufferSlice packet) = 0; - virtual void on_message_ack(uint64 message_id) = 0; - virtual Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) = 0; - virtual void on_message_result_error(uint64 message_id, int code, string message) = 0; - virtual void on_message_failed(uint64 message_id, Status status) = 0; - virtual void on_message_info(uint64 message_id, int32 state, uint64 answer_id, int32 answer_size, int32 source) = 0; + virtual void on_message_ack(MessageId message_id) = 0; + virtual Status on_message_result_ok(MessageId message_id, BufferSlice packet, size_t original_size) = 0; + virtual void on_message_result_error(MessageId message_id, int code, string message) = 0; + virtual void on_message_failed(MessageId message_id, Status status) = 0; + virtual void on_message_info(MessageId message_id, int32 state, MessageId answer_message_id, int32 answer_size, + int32 source) = 0; virtual Status on_destroy_auth_key() = 0; }; @@ -123,7 +125,7 @@ class SessionConnection final static constexpr double RESEND_ANSWER_DELAY = 0.001; // 0.001s struct MsgInfo { - uint64 message_id; + MessageId message_id; int32 seq_no; size_t size; }; @@ -161,21 +163,21 @@ class SessionConnection final static constexpr int HTTP_MAX_DELAY = 30; // 0.03s vector to_send_; - vector to_ack_message_ids_; + vector to_ack_message_ids_; double force_send_at_ = 0; struct ServiceQuery { enum Type { GetStateInfo, ResendAnswer } type_; - uint64 container_message_id_; + MessageId container_message_id_; vector msg_ids_; }; - vector to_resend_answer_message_ids_; - vector to_cancel_answer_message_ids_; - vector to_get_state_info_message_ids_; - FlatHashMap service_queries_; + vector to_resend_answer_message_ids_; + vector to_cancel_answer_message_ids_; + vector to_get_state_info_message_ids_; + FlatHashMap service_queries_; // nobody cleans up this map. But it should be really small. - FlatHashMap> container_to_service_message_id_; + FlatHashMap, MessageIdHash> container_to_service_message_id_; double random_delay_ = 0; double last_read_at_ = 0; @@ -183,9 +185,9 @@ class SessionConnection final double last_pong_at_ = 0; double real_last_read_at_ = 0; double real_last_pong_at_ = 0; - uint64 cur_ping_id_ = 0; - uint64 last_ping_message_id_ = 0; - uint64 last_ping_container_message_id_ = 0; + int64 cur_ping_id_ = 0; + MessageId last_ping_message_id_; + MessageId last_ping_container_message_id_; uint64 last_read_size_ = 0; uint64 last_write_size_ = 0; @@ -200,8 +202,8 @@ class SessionConnection final Mode mode_; bool connected_flag_ = false; - uint64 container_message_id_ = 0; - uint64 main_message_id_ = 0; + MessageId container_message_id_; + MessageId main_message_id_; double created_at_ = 0; unique_ptr raw_connection_; @@ -218,7 +220,7 @@ class SessionConnection final }; } - void reset_server_time_difference(uint64 message_id); + void reset_server_time_difference(MessageId message_id); static Status parse_message(TlParser &parser, MsgInfo *info, Slice *packet, bool crypto_flag = true) TD_WARN_UNUSED_RESULT; @@ -254,12 +256,12 @@ class SessionConnection final Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT; Status on_main_packet(const PacketInfo &packet_info, Slice packet) TD_WARN_UNUSED_RESULT; - void on_message_failed(uint64 message_id, Status status); - void on_message_failed_inner(uint64 message_id); + void on_message_failed(MessageId message_id, Status status); + void on_message_failed_inner(MessageId message_id); void do_close(Status status); - void send_ack(uint64 message_id); + void send_ack(MessageId message_id); void send_crypto(const Storer &storer, uint64 quick_ack_token); void send_before(double tm); bool may_ping() const; diff --git a/td/mtproto/Transport.cpp b/td/mtproto/Transport.cpp index c3ef707e9..2e47a42d9 100644 --- a/td/mtproto/Transport.cpp +++ b/td/mtproto/Transport.cpp @@ -8,6 +8,7 @@ #include "td/mtproto/AuthKey.h" #include "td/mtproto/KDF.h" +#include "td/mtproto/MessageId.h" #include "td/utils/as.h" #include "td/utils/crypto.h" @@ -42,7 +43,7 @@ struct CryptoHeader { // It is weird to generate message_id and seq_no while writing a packet. // - // uint64 message_id; + // uint64 msg_id; // uint32 seq_no; // uint32 message_data_length; uint8 data[0]; // use compiler extension @@ -68,7 +69,7 @@ struct CryptoHeader { }; struct CryptoPrefix { - uint64 message_id; + uint64 msg_id; uint32 seq_no; uint32 message_data_length; }; @@ -108,9 +109,9 @@ struct EndToEndPrefix { struct NoCryptoHeader { uint64 auth_key_id; - // message_id is removed from CryptoHeader. Should be removed from here too. + // msg_id is removed from CryptoHeader. Should be removed from here too. // - // uint64 message_id; + // uint64 msg_id; // uint32 message_data_length; uint8 data[0]; // use compiler extension @@ -309,7 +310,7 @@ Status Transport::read_crypto(MutableSlice message, const AuthKey &auth_key, Pac packet_info->type = PacketInfo::Common; packet_info->salt = header->salt; packet_info->session_id = header->session_id; - packet_info->message_id = prefix->message_id; + packet_info->message_id = MessageId(prefix->msg_id); packet_info->seq_no = prefix->seq_no; return Status::OK(); } diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index 73344f1f9..dc483cb8a 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -534,7 +534,7 @@ void Session::hangup() { } void Session::raw_event(const Event::Raw &event) { - auto message_id = event.u64; + auto message_id = mtproto::MessageId(event.u64); auto it = sent_queries_.find(message_id); if (it == sent_queries_.end()) { return; @@ -552,7 +552,7 @@ void Session::raw_event(const Event::Raw &event) { if (main_connection_.state_ == ConnectionInfo::State::Ready) { main_connection_.connection_->cancel_answer(message_id); } else { - to_cancel_.push_back(message_id); + to_cancel_message_ids_.push_back(message_id); } loop(); } @@ -697,8 +697,8 @@ void Session::on_closed(Status status) { current_info_->state_ = ConnectionInfo::State::Empty; } -void Session::on_new_session_created(uint64 unique_id, uint64 first_message_id) { - LOG(INFO) << "New session " << unique_id << " created with first message_id " << format::as_hex(first_message_id); +void Session::on_new_session_created(uint64 unique_id, mtproto::MessageId first_message_id) { + LOG(INFO) << "New session " << unique_id << " created with first " << first_message_id; if (!use_pfs_ && !auth_data_.use_pfs()) { last_success_timestamp_ = Time::now(); } @@ -712,9 +712,9 @@ void Session::on_new_session_created(uint64 unique_id, uint64 first_message_id) auto first_query_it = sent_queries_.find(first_message_id); if (first_query_it != sent_queries_.end()) { first_message_id = first_query_it->second.container_message_id_; - LOG(INFO) << "Update first_message_id to container's " << format::as_hex(first_message_id); + LOG(INFO) << "Update first message to container's " << first_message_id; } else { - LOG(INFO) << "Failed to find query " << format::as_hex(first_message_id) << " from the new session"; + LOG(INFO) << "Failed to find sent " << first_message_id << " from the new session"; } for (auto it = sent_queries_.begin(); it != sent_queries_.end();) { Query *query_ptr = &it->second; @@ -741,10 +741,10 @@ void Session::on_session_failed(Status status) { callback_->on_failed(); } -void Session::on_container_sent(uint64 container_message_id, vector message_ids) { - CHECK(container_message_id != 0); +void Session::on_container_sent(mtproto::MessageId container_message_id, vector message_ids) { + CHECK(container_message_id != mtproto::MessageId()); - td::remove_if(message_ids, [&](uint64 message_id) { + td::remove_if(message_ids, [&](mtproto::MessageId message_id) { auto it = sent_queries_.find(message_id); if (it == sent_queries_.end()) { return true; // remove @@ -759,11 +759,11 @@ void Session::on_container_sent(uint64 container_message_id, vector mess sent_containers_.emplace(container_message_id, ContainerInfo{size, std::move(message_ids)}); } -void Session::on_message_ack(uint64 message_id) { +void Session::on_message_ack(mtproto::MessageId message_id) { on_message_ack_impl(message_id, 1); } -void Session::on_message_ack_impl(uint64 container_message_id, int32 type) { +void Session::on_message_ack_impl(mtproto::MessageId container_message_id, int32 type) { auto cit = sent_containers_.find(container_message_id); if (cit != sent_containers_.end()) { auto container_info = std::move(cit->second); @@ -778,7 +778,7 @@ void Session::on_message_ack_impl(uint64 container_message_id, int32 type) { on_message_ack_impl_inner(container_message_id, type, false); } -void Session::on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_container) { +void Session::on_message_ack_impl_inner(mtproto::MessageId message_id, int32 type, bool in_container) { auto it = sent_queries_.find(message_id); if (it == sent_queries_.end()) { return; @@ -796,7 +796,7 @@ void Session::on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_c mark_as_known(it->first, &it->second); } -void Session::dec_container(uint64 container_message_id, Query *query) { +void Session::dec_container(mtproto::MessageId container_message_id, Query *query) { if (query->container_message_id_ == container_message_id) { // message was sent without any container return; @@ -812,7 +812,7 @@ void Session::dec_container(uint64 container_message_id, Query *query) { } } -void Session::cleanup_container(uint64 container_message_id, Query *query) { +void Session::cleanup_container(mtproto::MessageId container_message_id, Query *query) { if (query->container_message_id_ == container_message_id) { // message was sent without any container return; @@ -823,7 +823,7 @@ void Session::cleanup_container(uint64 container_message_id, Query *query) { sent_containers_.erase(query->container_message_id_); } -void Session::mark_as_known(uint64 message_id, Query *query) { +void Session::mark_as_known(mtproto::MessageId message_id, Query *query) { { auto lock = query->net_query_->lock(); query->net_query_->get_data_unsafe().unknown_state_ = false; @@ -839,7 +839,7 @@ void Session::mark_as_known(uint64 message_id, Query *query) { } } -void Session::mark_as_unknown(uint64 message_id, Query *query) { +void Session::mark_as_unknown(mtproto::MessageId message_id, Query *query) { { auto lock = query->net_query_->lock(); query->net_query_->get_data_unsafe().unknown_state_ = true; @@ -849,7 +849,7 @@ void Session::mark_as_unknown(uint64 message_id, Query *query) { } VLOG(net_query) << "Mark as unknown " << query->net_query_; query->is_unknown_ = true; - CHECK(message_id != 0); + CHECK(message_id != mtproto::MessageId()); unknown_queries_.insert(message_id); } @@ -866,16 +866,16 @@ Status Session::on_update(BufferSlice packet) { return Status::OK(); } -Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) { +Status Session::on_message_result_ok(mtproto::MessageId message_id, BufferSlice packet, size_t original_size) { last_success_timestamp_ = Time::now(); TlParser parser(packet.as_slice()); - int32 response_id = parser.fetch_int(); + int32 response_tl_id = parser.fetch_int(); auto it = sent_queries_.find(message_id); if (it == sent_queries_.end()) { - LOG(DEBUG) << "Drop result to " << tag("message_id", format::as_hex(message_id)) - << tag("original_size", original_size) << tag("response_id", format::as_hex(response_id)); + LOG(DEBUG) << "Drop result to " << message_id << tag("original_size", original_size) + << tag("response_tl", format::as_hex(response_tl_id)); if (original_size > 16 * 1024) { dropped_size_ += original_size; @@ -896,9 +896,9 @@ Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size if (!parser.get_error()) { // Steal authorization information. // It is a dirty hack, yep. - if (response_id == telegram_api::auth_authorization::ID || - response_id == telegram_api::auth_loginTokenSuccess::ID || - response_id == telegram_api::auth_sentCodeSuccess::ID) { + if (response_tl_id == telegram_api::auth_authorization::ID || + response_tl_id == telegram_api::auth_loginTokenSuccess::ID || + response_tl_id == telegram_api::auth_sentCodeSuccess::ID) { if (query_ptr->net_query_->tl_constructor() != telegram_api::auth_importAuthorization::ID) { G()->net_query_dispatcher().set_main_dc_id(raw_dc_id_); } @@ -918,7 +918,7 @@ Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size return Status::OK(); } -void Session::on_message_result_error(uint64 message_id, int error_code, string message) { +void Session::on_message_result_error(mtproto::MessageId message_id, int error_code, string message) { if (!check_utf8(message)) { LOG(ERROR) << "Receive invalid error message \"" << message << '"'; message = "INVALID_UTF8_ERROR_MESSAGE"; @@ -970,7 +970,7 @@ void Session::on_message_result_error(uint64 message_id, int error_code, string error_code = 500; } - if (message_id == 0) { + if (message_id == mtproto::MessageId()) { LOG(ERROR) << "Receive an error without message_id"; return; } @@ -998,8 +998,8 @@ void Session::on_message_result_error(uint64 message_id, int error_code, string sent_queries_.erase(it); } -void Session::on_message_failed_inner(uint64 message_id, bool in_container) { - LOG(INFO) << "Message inner failed " << message_id; +void Session::on_message_failed_inner(mtproto::MessageId message_id, bool in_container) { + LOG(INFO) << "Message inner failed for " << message_id; auto it = sent_queries_.find(message_id); if (it == sent_queries_.end()) { return; @@ -1016,8 +1016,8 @@ void Session::on_message_failed_inner(uint64 message_id, bool in_container) { sent_queries_.erase(it); } -void Session::on_message_failed(uint64 message_id, Status status) { - LOG(INFO) << "Message failed: " << tag("message_id", message_id) << tag("status", status); +void Session::on_message_failed(mtproto::MessageId message_id, Status status) { + LOG(INFO) << "Failed to send " << message_id << ": " << status; status.ignore(); auto cit = sent_containers_.find(message_id); @@ -1034,8 +1034,8 @@ void Session::on_message_failed(uint64 message_id, Status status) { on_message_failed_inner(message_id, false); } -void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_message_id, int32 answer_size, - int32 source) { +void Session::on_message_info(mtproto::MessageId message_id, int32 state, mtproto::MessageId answer_message_id, + int32 answer_size, int32 source) { auto it = sent_queries_.find(message_id); if (it != sent_queries_.end()) { if (it->second.net_query_->update_is_ready()) { @@ -1049,7 +1049,7 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess return; } } - if (message_id != 0) { + if (message_id != mtproto::MessageId()) { if (it == sent_queries_.end()) { return; } @@ -1060,15 +1060,16 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess return on_message_failed(message_id, Status::Error("Message wasn't received by the server and must be re-sent")); case 0: - if (answer_message_id == 0) { - LOG(ERROR) << "Unexpected message_info.state == 0 " << tag("message_id", message_id) << tag("state", state) - << tag("answer_message_id", answer_message_id); + if (answer_message_id == mtproto::MessageId()) { + LOG(ERROR) << "Unexpected message_info.state == 0 for " << message_id << ": " << tag("state", state) + << tag("answer", answer_message_id); return on_message_failed(message_id, Status::Error("Unexpected message_info.state == 0")); } // fallthrough case 4: CHECK(0 <= source && source <= 3); - on_message_ack_impl(message_id, (answer_message_id ? 2 : 0) | (((state | source) & ((1 << 28) - 1)) << 2)); + on_message_ack_impl(message_id, (answer_message_id != mtproto::MessageId() ? 2 : 0) | + (((state | source) & ((1 << 28) - 1)) << 2)); break; default: LOG(ERROR) << "Invalid message info " << tag("state", state); @@ -1076,10 +1077,10 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess } // ok, we are waiting for result of message_id. let's ask to resend it - if (answer_message_id != 0) { + if (answer_message_id != mtproto::MessageId()) { if (it != sent_queries_.end()) { - VLOG_IF(net_query, message_id != 0) << "Resend answer " << tag("answer_message_id", answer_message_id) - << tag("answer_size", answer_size) << it->second.net_query_; + VLOG_IF(net_query, message_id != mtproto::MessageId()) + << "Resend answer " << answer_message_id << ": " << tag("answer_size", answer_size) << it->second.net_query_; it->second.net_query_->debug(PSTRING() << get_name() << ": resend answer"); } current_info_->connection_->resend_answer(answer_message_id); @@ -1114,7 +1115,7 @@ void Session::add_query(NetQueryPtr &&net_query) { pending_queries_.push(std::move(net_query)); } -void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id) { +void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, mtproto::MessageId message_id) { CHECK(info->state_ == ConnectionInfo::State::Ready); current_info_ = info; @@ -1123,10 +1124,10 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer } Span invoke_after = net_query->invoke_after(); - vector invoke_after_message_ids; + vector invoke_after_message_ids; for (auto &ref : invoke_after) { - auto invoke_after_message_id = ref->message_id(); - if (ref->session_id() != auth_data_.get_session_id() || invoke_after_message_id == 0) { + auto invoke_after_message_id = mtproto::MessageId(ref->message_id()); + if (ref->session_id() != auth_data_.get_session_id() || invoke_after_message_id == mtproto::MessageId()) { net_query->set_error_resend_invoke_after(); return return_query(std::move(net_query)); } @@ -1155,30 +1156,27 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer } message_id = r_message_id.ok(); } else { - if (message_id == 0) { + if (message_id == mtproto::MessageId()) { message_id = auth_data_.next_message_id(now); } } - net_query->set_message_id(message_id); - VLOG(net_query) << "Send query to connection " << net_query - << tag("invoke_after", transform(invoke_after_message_ids, [](auto message_id) { - return PSTRING() << format::as_hex(message_id); - })); + net_query->set_message_id(message_id.get()); + VLOG(net_query) << "Send query to connection " << net_query << tag("invoke_after", invoke_after_message_ids); { auto lock = net_query->lock(); net_query->get_data_unsafe().unknown_state_ = false; net_query->get_data_unsafe().ack_state_ = 0; } if (!net_query->cancel_slot_.empty()) { - LOG(DEBUG) << "Set event for net_query cancellation " << tag("message_id", format::as_hex(message_id)); - net_query->cancel_slot_.set_event(EventCreator::raw(actor_id(), message_id)); + LOG(DEBUG) << "Set event for net_query cancellation for " << message_id; + net_query->cancel_slot_.set_event(EventCreator::raw(actor_id(), message_id.get())); } auto status = sent_queries_.emplace(message_id, Query{message_id, std::move(net_query), main_connection_.connection_id_, now}); LOG_CHECK(status.second) << message_id; sent_queries_list_.put(status.first->second.get_list_node()); if (!status.second) { - LOG(FATAL) << "Duplicate message_id [message_id = " << message_id << "]"; + LOG(FATAL) << "Duplicate " << message_id; } if (immediately_fail_query) { on_message_result_error(message_id, 401, "TEST_ERROR"); @@ -1308,10 +1306,10 @@ void Session::connection_open_finish(ConnectionInfo *info, for (auto &message_id : unknown_queries_) { info->connection_->get_state_info(message_id); } - for (auto &message_id : to_cancel_) { + for (auto &message_id : to_cancel_message_ids_) { info->connection_->cancel_answer(message_id); } - to_cancel_.clear(); + to_cancel_message_ids_.clear(); } yield(); } @@ -1378,7 +1376,7 @@ bool Session::connection_send_bind_key(ConnectionInfo *info) { int64 perm_auth_key_id = auth_data_.get_main_auth_key().id(); int64 nonce = Random::secure_int64(); auto expires_at = static_cast(auth_data_.get_server_time(auth_data_.get_tmp_auth_key().expires_at())); - uint64 message_id; + mtproto::MessageId message_id; BufferSlice encrypted; std::tie(message_id, encrypted) = info->connection_->encrypted_bind(perm_auth_key_id, nonce, expires_at); diff --git a/td/telegram/net/Session.h b/td/telegram/net/Session.h index 4bf381e68..12d477891 100644 --- a/td/telegram/net/Session.h +++ b/td/telegram/net/Session.h @@ -14,6 +14,7 @@ #include "td/mtproto/AuthKey.h" #include "td/mtproto/ConnectionManager.h" #include "td/mtproto/Handshake.h" +#include "td/mtproto/MessageId.h" #include "td/mtproto/SessionConnection.h" #include "td/actor/actor.h" @@ -78,7 +79,7 @@ class Session final private: struct Query final : private ListNode { - uint64 container_message_id_; + mtproto::MessageId container_message_id_; NetQueryPtr net_query_; bool is_acknowledged_ = false; @@ -87,7 +88,7 @@ class Session final const int8 connection_id_; const double sent_at_; - Query(uint64 message_id, NetQueryPtr &&net_query, int8 connection_id, double sent_at) + Query(mtproto::MessageId message_id, NetQueryPtr &&net_query, int8 connection_id, double sent_at) : container_message_id_(message_id) , net_query_(std::move(net_query)) , connection_id_(connection_id) @@ -131,8 +132,8 @@ class Session final double last_bind_success_timestamp_ = 0; // time when auth_key and Session definitely was valid and authorized size_t dropped_size_ = 0; - FlatHashSet unknown_queries_; - vector to_cancel_; + FlatHashSet unknown_queries_; + vector to_cancel_message_ids_; // Do not invalidate iterators of these two containers! // TODO: better data structures @@ -145,7 +146,7 @@ class Session final std::map, std::greater<>> queries_; }; PriorityQueue pending_queries_; - std::map sent_queries_; + std::map sent_queries_; std::deque pending_invoke_after_queries_; ListNode sent_queries_list_; @@ -181,9 +182,9 @@ class Session final struct ContainerInfo { size_t ref_cnt; - vector message_ids; + vector message_ids; }; - FlatHashMap sent_containers_; + FlatHashMap sent_containers_; friend class GenAuthKeyActor; struct HandshakeInfo { @@ -214,33 +215,34 @@ class Session final void on_server_salt_updated() final; void on_server_time_difference_updated(bool force) final; - void on_new_session_created(uint64 unique_id, uint64 first_message_id) final; + void on_new_session_created(uint64 unique_id, mtproto::MessageId first_message_id) final; void on_session_failed(Status status) final; - void on_container_sent(uint64 container_message_id, vector message_ids) final; + void on_container_sent(mtproto::MessageId container_message_id, vector message_ids) final; Status on_update(BufferSlice packet) final; - void on_message_ack(uint64 message_id) final; - Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) final; - void on_message_result_error(uint64 message_id, int error_code, string message) final; - void on_message_failed(uint64 message_id, Status status) final; + void on_message_ack(mtproto::MessageId message_id) final; + Status on_message_result_ok(mtproto::MessageId message_id, BufferSlice packet, size_t original_size) final; + void on_message_result_error(mtproto::MessageId message_id, int error_code, string message) final; + void on_message_failed(mtproto::MessageId message_id, Status status) final; - void on_message_info(uint64 message_id, int32 state, uint64 answer_message_id, int32 answer_size, int32 source) final; + void on_message_info(mtproto::MessageId message_id, int32 state, mtproto::MessageId answer_message_id, + int32 answer_size, int32 source) final; Status on_destroy_auth_key() final; void flush_pending_invoke_after_queries(); bool has_queries() const; - void dec_container(uint64 container_message_id, Query *query); - void cleanup_container(uint64 container_message_id, Query *query); - void mark_as_known(uint64 message_id, Query *query); - void mark_as_unknown(uint64 message_id, Query *query); + void dec_container(mtproto::MessageId container_message_id, Query *query); + void cleanup_container(mtproto::MessageId container_message_id, Query *query); + void mark_as_known(mtproto::MessageId message_id, Query *query); + void mark_as_unknown(mtproto::MessageId message_id, Query *query); - void on_message_ack_impl(uint64 container_message_id, int32 type); - void on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_container); - void on_message_failed_inner(uint64 message_id, bool in_container); + void on_message_ack_impl(mtproto::MessageId container_message_id, int32 type); + void on_message_ack_impl_inner(mtproto::MessageId message_id, int32 type, bool in_container); + void on_message_failed_inner(mtproto::MessageId message_id, bool in_container); // send NetQueryPtr to parent void return_query(NetQueryPtr &&query); @@ -255,7 +257,7 @@ class Session final void connection_online_update(double now, bool force); void connection_close(ConnectionInfo *info); void connection_flush(ConnectionInfo *info); - void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id = 0); + void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, mtproto::MessageId message_id = {}); bool need_send_bind_key() const; bool need_send_query() const; bool can_destroy_auth_key() const;