Add strictly-typed class mtproto::MessageId.

This commit is contained in:
levlam 2023-09-21 17:52:33 +03:00
parent e47cea5904
commit b44e2ea3fc
18 changed files with 355 additions and 265 deletions

View File

@ -518,6 +518,7 @@ set(TDLIB_SOURCE
td/mtproto/HttpTransport.h td/mtproto/HttpTransport.h
td/mtproto/IStreamTransport.h td/mtproto/IStreamTransport.h
td/mtproto/KDF.h td/mtproto/KDF.h
td/mtproto/MessageId.h
td/mtproto/MtprotoQuery.h td/mtproto/MtprotoQuery.h
td/mtproto/NoCryptoStorer.h td/mtproto/NoCryptoStorer.h
td/mtproto/PacketInfo.h td/mtproto/PacketInfo.h

View File

@ -416,7 +416,7 @@ class IdDuplicateCheckerOld {
static td::string get_description() { static td::string get_description() {
return "Old"; return "Old";
} }
td::Status check(td::int64 message_id) { td::Status check(td::uint64 message_id) {
if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS) { if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS) {
auto oldest_message_id = *saved_message_ids_.begin(); auto oldest_message_id = *saved_message_ids_.begin();
if (message_id < oldest_message_id) { if (message_id < oldest_message_id) {
@ -437,7 +437,7 @@ class IdDuplicateCheckerOld {
private: private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<td::int64> saved_message_ids_; std::set<td::uint64> saved_message_ids_;
}; };
template <size_t MAX_SAVED_MESSAGE_IDS> template <size_t MAX_SAVED_MESSAGE_IDS>
@ -446,7 +446,7 @@ class IdDuplicateCheckerNew {
static td::string get_description() { static td::string get_description() {
return PSTRING() << "New" << MAX_SAVED_MESSAGE_IDS; return PSTRING() << "New" << MAX_SAVED_MESSAGE_IDS;
} }
td::Status check(td::int64 message_id) { td::Status check(td::uint64 message_id) {
auto insert_result = saved_message_ids_.insert(message_id); auto insert_result = saved_message_ids_.insert(message_id);
if (!insert_result.second) { if (!insert_result.second) {
return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id); return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id);
@ -464,7 +464,7 @@ class IdDuplicateCheckerNew {
} }
private: private:
std::set<td::int64> saved_message_ids_; std::set<td::uint64> saved_message_ids_;
}; };
class IdDuplicateCheckerNewOther { class IdDuplicateCheckerNewOther {
@ -472,7 +472,7 @@ class IdDuplicateCheckerNewOther {
static td::string get_description() { static td::string get_description() {
return "NewOther"; return "NewOther";
} }
td::Status check(td::int64 message_id) { td::Status check(td::uint64 message_id) {
if (!saved_message_ids_.insert(message_id).second) { if (!saved_message_ids_.insert(message_id).second) {
return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id); return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id);
} }
@ -490,7 +490,7 @@ class IdDuplicateCheckerNewOther {
private: private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<td::int64> saved_message_ids_; std::set<td::uint64> saved_message_ids_;
}; };
class IdDuplicateCheckerNewSimple { class IdDuplicateCheckerNewSimple {
@ -498,7 +498,7 @@ class IdDuplicateCheckerNewSimple {
static td::string get_description() { static td::string get_description() {
return "NewSimple"; return "NewSimple";
} }
td::Status check(td::int64 message_id) { td::Status check(td::uint64 message_id) {
auto insert_result = saved_message_ids_.insert(message_id); auto insert_result = saved_message_ids_.insert(message_id);
if (!insert_result.second) { if (!insert_result.second) {
return td::Status::Error(1, "Ignore already processed message"); return td::Status::Error(1, "Ignore already processed message");
@ -516,7 +516,7 @@ class IdDuplicateCheckerNewSimple {
private: private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<td::int64> saved_message_ids_; std::set<td::uint64> saved_message_ids_;
}; };
template <size_t max_size> template <size_t max_size>
@ -525,7 +525,7 @@ class IdDuplicateCheckerArray {
static td::string get_description() { static td::string get_description() {
return PSTRING() << "Array" << max_size; return PSTRING() << "Array" << max_size;
} }
td::Status check(td::int64 message_id) { td::Status check(td::uint64 message_id) {
if (end_pos_ == 2 * max_size) { if (end_pos_ == 2 * max_size) {
std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]); std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]);
end_pos_ = max_size; end_pos_ = max_size;
@ -550,7 +550,7 @@ class IdDuplicateCheckerArray {
} }
private: private:
std::array<td::int64, 2 * max_size> saved_message_ids_; std::array<td::uint64, 2 * max_size> saved_message_ids_;
std::size_t end_pos_ = 0; std::size_t end_pos_ = 0;
}; };

View File

@ -6,7 +6,6 @@
// //
#include "td/mtproto/AuthData.h" #include "td/mtproto/AuthData.h"
#include "td/utils/format.h"
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/Random.h" #include "td/utils/Random.h"
#include "td/utils/SliceBuilder.h" #include "td/utils/SliceBuilder.h"
@ -17,7 +16,8 @@
namespace td { namespace td {
namespace mtproto { namespace mtproto {
Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id) { Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos,
MessageId message_id) {
// In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if // In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if
// a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be // a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be
// ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is // ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is
@ -32,13 +32,12 @@ Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, s
return Status::OK(); return Status::OK();
} }
if (end_pos >= max_size && message_id < saved_message_ids[0]) { if (end_pos >= max_size && message_id < saved_message_ids[0]) {
return Status::Error(2, PSLICE() << "Ignore very old message " << format::as_hex(message_id) return Status::Error(
<< " older than the oldest known message " 2, PSLICE() << "Ignore very old " << message_id << " older than the oldest known " << saved_message_ids[0]);
<< format::as_hex(saved_message_ids[0]));
} }
auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id); auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id);
if (*it == message_id) { if (*it == message_id) {
return Status::Error(1, PSLICE() << "Ignore already processed message " << format::as_hex(message_id)); return Status::Error(1, PSLICE() << "Ignore already processed " << message_id);
} }
std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]); std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]);
*it = message_id; *it = message_id;
@ -105,7 +104,7 @@ std::vector<ServerSalt> AuthData::get_future_salts() const {
return res; return res;
} }
uint64 AuthData::next_message_id(double now) { MessageId AuthData::next_message_id(double now) {
double server_time = get_server_time(now); double server_time = get_server_time(now);
auto t = static_cast<uint64>(server_time * (static_cast<uint64>(1) << 32)); auto t = static_cast<uint64>(server_time * (static_cast<uint64>(1) << 32));
@ -113,31 +112,31 @@ uint64 AuthData::next_message_id(double now) {
// TODO(perf) do not do this for systems with good precision?.. // TODO(perf) do not do this for systems with good precision?..
auto rx = Random::secure_int32(); auto rx = Random::secure_int32();
auto to_xor = rx & ((1 << 22) - 1); auto to_xor = rx & ((1 << 22) - 1);
auto to_mul = ((rx >> 22) & 1023) + 1;
t ^= to_xor; t ^= to_xor;
auto result = t & static_cast<uint64>(-4); auto result = MessageId(t & static_cast<uint64>(-4));
if (last_message_id_ >= result) { if (last_message_id_ >= result) {
result = last_message_id_ + 8 * to_mul; auto to_mul = ((rx >> 22) & 1023) + 1;
result = MessageId(last_message_id_.get() + 8 * to_mul);
} }
LOG(DEBUG) << "Create message identifier " << format::as_hex(result) << " at " << now; LOG(DEBUG) << "Create identifier for " << result << " at " << now;
last_message_id_ = result; last_message_id_ = result;
return result; return result;
} }
bool AuthData::is_valid_outbound_msg_id(uint64 message_id, double now) const { bool AuthData::is_valid_outbound_msg_id(MessageId message_id, double now) const {
double server_time = get_server_time(now); double server_time = get_server_time(now);
auto id_time = static_cast<double>(message_id) / static_cast<double>(static_cast<uint64>(1) << 32); auto id_time = static_cast<double>(message_id.get()) / static_cast<double>(static_cast<uint64>(1) << 32);
return server_time - 150 < id_time && id_time < server_time + 30; return server_time - 150 < id_time && id_time < server_time + 30;
} }
bool AuthData::is_valid_inbound_msg_id(uint64 message_id, double now) const { bool AuthData::is_valid_inbound_msg_id(MessageId message_id, double now) const {
double server_time = get_server_time(now); double server_time = get_server_time(now);
auto id_time = static_cast<double>(message_id) / static_cast<double>(static_cast<uint64>(1) << 32); auto id_time = static_cast<double>(message_id.get()) / static_cast<double>(static_cast<uint64>(1) << 32);
return server_time - 300 < id_time && id_time < server_time + 30; return server_time - 300 < id_time && id_time < server_time + 30;
} }
Status AuthData::check_packet(uint64 session_id, uint64 message_id, double now, bool &time_difference_was_updated) { Status AuthData::check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated) {
// Client is to check that the session_id field in the decrypted message indeed equals to that of an active session // Client is to check that the session_id field in the decrypted message indeed equals to that of an active session
// created by the client. // created by the client.
if (get_session_id() != session_id) { if (get_session_id() != session_id) {
@ -147,22 +146,21 @@ Status AuthData::check_packet(uint64 session_id, uint64 message_id, double now,
// Client must check that msg_id has even parity for messages from client to server, and odd parity for messages // Client must check that msg_id has even parity for messages from client to server, and odd parity for messages
// from server to client. // from server to client.
if ((message_id & 1) == 0) { if ((message_id.get() & 1) == 0) {
return Status::Error(PSLICE() << "Receive invalid message identifier " << format::as_hex(message_id)); return Status::Error(PSLICE() << "Receive invalid " << message_id);
} }
TRY_STATUS(duplicate_checker_.check(message_id)); TRY_STATUS(duplicate_checker_.check(message_id));
LOG(DEBUG) << "Receive packet " << format::as_hex(message_id) << " from session " << format::as_hex(session_id) LOG(DEBUG) << "Receive packet in " << message_id << " from session " << session_id << " at " << now;
<< " at " << now; time_difference_was_updated = update_server_time_difference(static_cast<uint32>(message_id.get() >> 32) - now);
time_difference_was_updated = update_server_time_difference(static_cast<uint32>(message_id >> 32) - now);
// In addition, msg_id values that belong over 30 seconds in the future or over 300 seconds in the past are to be // In addition, msg_id values that belong over 30 seconds in the future or over 300 seconds in the past are to be
// ignored (recall that msg_id approximately equals unixtime * 2^32). This is especially important for the server. // ignored (recall that msg_id approximately equals unixtime * 2^32). This is especially important for the server.
// The client would also find this useful (to protect from a replay attack), but only if it is certain of its time // The client would also find this useful (to protect from a replay attack), but only if it is certain of its time
// (for example, if its time has been synchronized with that of the server). // (for example, if its time has been synchronized with that of the server).
if (server_time_difference_was_updated_ && !is_valid_inbound_msg_id(message_id, now)) { if (server_time_difference_was_updated_ && !is_valid_inbound_msg_id(message_id, now)) {
return Status::Error(PSLICE() << "Ignore too old or too new message " << format::as_hex(message_id)); return Status::Error(PSLICE() << "Ignore too old or too new " << message_id);
} }
return Status::OK(); return Status::OK();

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "td/mtproto/AuthKey.h" #include "td/mtproto/AuthKey.h"
#include "td/mtproto/MessageId.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
@ -37,17 +38,18 @@ void parse(ServerSalt &salt, ParserT &parser) {
salt.valid_until = parser.fetch_double(); salt.valid_until = parser.fetch_double();
} }
Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id); Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos,
MessageId message_id);
template <size_t max_size> template <size_t max_size>
class MessageIdDuplicateChecker { class MessageIdDuplicateChecker {
public: public:
Status check(uint64 message_id) { Status check(MessageId message_id) {
return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id); return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
} }
private: private:
std::array<uint64, 2 * max_size> saved_message_ids_; std::array<MessageId, 2 * max_size> saved_message_ids_;
size_t end_pos_ = 0; size_t end_pos_ = 0;
}; };
@ -232,19 +234,19 @@ class AuthData {
std::vector<ServerSalt> get_future_salts() const; std::vector<ServerSalt> get_future_salts() const;
uint64 next_message_id(double now); MessageId next_message_id(double now);
bool is_valid_outbound_msg_id(uint64 message_id, double now) const; bool is_valid_outbound_msg_id(MessageId message_id, double now) const;
bool is_valid_inbound_msg_id(uint64 message_id, double now) const; bool is_valid_inbound_msg_id(MessageId message_id, double now) const;
Status check_packet(uint64 session_id, uint64 message_id, double now, bool &time_difference_was_updated); Status check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated);
Status check_update(uint64 message_id) { Status check_update(MessageId message_id) {
return updates_duplicate_checker_.check(message_id); return updates_duplicate_checker_.check(message_id);
} }
Status recheck_update(uint64 message_id) { Status recheck_update(MessageId message_id) {
return updates_duplicate_rechecker_.check(message_id); return updates_duplicate_rechecker_.check(message_id);
} }
@ -275,7 +277,7 @@ class AuthData {
bool server_time_difference_was_updated_ = false; bool server_time_difference_was_updated_ = false;
double server_time_difference_ = 0; double server_time_difference_ = 0;
ServerSalt server_salt_; ServerSalt server_salt_;
uint64 last_message_id_ = 0; MessageId last_message_id_;
int32 seq_no_ = 0; int32 seq_no_ = 0;
string header_; string header_;
uint64 session_id_ = 0; uint64 session_id_ = 0;

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "td/mtproto/AuthData.h" #include "td/mtproto/AuthData.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/MtprotoQuery.h" #include "td/mtproto/MtprotoQuery.h"
#include "td/mtproto/PacketStorer.h" #include "td/mtproto/PacketStorer.h"
#include "td/mtproto/utils.h" #include "td/mtproto/utils.h"
@ -57,7 +58,7 @@ class ObjectImpl {
bool empty() const { bool empty() const {
return !not_empty_; return !not_empty_;
} }
uint64 get_message_id() const { MessageId get_message_id() const {
return message_id_; return message_id_;
} }
@ -65,7 +66,7 @@ class ObjectImpl {
bool not_empty_; bool not_empty_;
Object object_; Object object_;
ObjectStorer object_storer_; ObjectStorer object_storer_;
uint64 message_id_; MessageId message_id_;
int32 seq_no_; int32 seq_no_;
}; };
@ -96,7 +97,7 @@ class CancelVectorImpl {
bool not_empty() const { bool not_empty() const {
return !storers_.empty(); return !storers_.empty();
} }
uint64 get_message_id() const { MessageId get_message_id() const {
CHECK(storers_.size() == 1); CHECK(storers_.size() == 1);
return storers_[0].get_message_id(); return storers_[0].get_message_id();
} }
@ -107,7 +108,7 @@ class CancelVectorImpl {
class InvokeAfter { class InvokeAfter {
public: public:
explicit InvokeAfter(Span<uint64> message_ids) : message_ids_(message_ids) { explicit InvokeAfter(Span<MessageId> message_ids) : message_ids_(message_ids) {
} }
template <class StorerT> template <class StorerT>
void store(StorerT &storer) const { void store(StorerT &storer) const {
@ -116,7 +117,7 @@ class InvokeAfter {
} }
if (message_ids_.size() == 1) { if (message_ids_.size() == 1) {
storer.store_int(static_cast<int32>(0xcb9f372d)); storer.store_int(static_cast<int32>(0xcb9f372d));
storer.store_binary(message_ids_[0]); storer.store_binary(message_ids_[0].get());
return; return;
} }
// invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector<long> query:!X = X; // invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector<long> query:!X = X;
@ -124,12 +125,12 @@ class InvokeAfter {
storer.store_int(static_cast<int32>(0x1cb5c415)); storer.store_int(static_cast<int32>(0x1cb5c415));
storer.store_int(narrow_cast<int32>(message_ids_.size())); storer.store_int(narrow_cast<int32>(message_ids_.size()));
for (auto message_id : message_ids_) { for (auto message_id : message_ids_) {
storer.store_binary(message_id); storer.store_binary(message_id.get());
} }
} }
private: private:
Span<uint64> message_ids_; Span<MessageId> message_ids_;
}; };
class QueryImpl { class QueryImpl {
@ -206,8 +207,8 @@ class CryptoImpl {
CryptoImpl(const vector<MtprotoQuery> &to_send, Slice header, vector<int64> &&to_ack, int64 ping_id, int ping_timeout, CryptoImpl(const vector<MtprotoQuery> &to_send, Slice header, vector<int64> &&to_ack, int64 ping_id, int ping_timeout,
int max_delay, int max_after, int max_wait, int future_salt_n, vector<int64> get_info, int max_delay, int max_after, int max_wait, int future_salt_n, vector<int64> get_info,
vector<int64> resend, const vector<int64> &cancel, bool destroy_key, AuthData *auth_data, vector<int64> resend, const vector<int64> &cancel, bool destroy_key, AuthData *auth_data,
uint64 *container_message_id, uint64 *get_info_message_id, uint64 *resend_message_id, MessageId *container_message_id, MessageId *get_info_message_id, MessageId *resend_message_id,
uint64 *ping_message_id, uint64 *parent_message_id) MessageId *ping_message_id, MessageId *parent_message_id)
: query_storer_(to_send, header) : query_storer_(to_send, header)
, ack_empty_(to_ack.empty()) , ack_empty_(to_ack.empty())
, ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data) , ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data)
@ -362,7 +363,7 @@ class CryptoImpl {
Mixed Mixed
}; };
Type type_; Type type_;
uint64 message_id_; MessageId message_id_;
int32 seq_no_; int32 seq_no_;
}; };

View File

@ -8,6 +8,7 @@
#include "td/mtproto/AuthKey.h" #include "td/mtproto/AuthKey.h"
#include "td/mtproto/Handshake.h" #include "td/mtproto/Handshake.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/NoCryptoStorer.h" #include "td/mtproto/NoCryptoStorer.h"
#include "td/mtproto/PacketInfo.h" #include "td/mtproto/PacketInfo.h"
#include "td/mtproto/PacketStorer.h" #include "td/mtproto/PacketStorer.h"
@ -61,7 +62,7 @@ class HandshakeConnection final
unique_ptr<AuthKeyHandshakeContext> context_; unique_ptr<AuthKeyHandshakeContext> context_;
void send_no_crypto(const Storer &storer) final { void send_no_crypto(const Storer &storer) final {
raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(0, storer)); raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(MessageId(), storer));
} }
Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) final { Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) final {

70
td/mtproto/MessageId.h Normal file
View File

@ -0,0 +1,70 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/HashTableUtils.h"
#include "td/utils/StringBuilder.h"
#include <type_traits>
namespace td {
namespace mtproto {
class MessageId {
uint64 message_id_ = 0;
public:
MessageId() = default;
explicit constexpr MessageId(uint64 message_id) : message_id_(message_id) {
}
template <class T, typename = std::enable_if_t<std::is_convertible<T, int64>::value>>
MessageId(T message_id) = delete;
uint64 get() const {
return message_id_;
}
bool operator==(const MessageId &other) const {
return message_id_ == other.message_id_;
}
bool operator!=(const MessageId &other) const {
return message_id_ != other.message_id_;
}
friend bool operator<(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() < rhs.get();
}
friend bool operator>(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() > rhs.get();
}
friend bool operator<=(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() <= rhs.get();
}
friend bool operator>=(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() >= rhs.get();
}
};
struct MessageIdHash {
uint32 operator()(MessageId message_id) const {
return Hash<uint64>()(message_id.get());
}
};
inline StringBuilder &operator<<(StringBuilder &string_builder, MessageId message_id) {
return string_builder << "message " << format::as_hex(message_id.get());
}
} // namespace mtproto
} // namespace td

View File

@ -6,6 +6,8 @@
// //
#pragma once #pragma once
#include "td/mtproto/MessageId.h"
#include "td/utils/buffer.h" #include "td/utils/buffer.h"
#include "td/utils/common.h" #include "td/utils/common.h"
@ -13,11 +15,11 @@ namespace td {
namespace mtproto { namespace mtproto {
struct MtprotoQuery { struct MtprotoQuery {
uint64 message_id; MessageId message_id;
int32 seq_no; int32 seq_no;
BufferSlice packet; BufferSlice packet;
bool gzip_flag; bool gzip_flag;
vector<uint64> invoke_after_message_ids; vector<MessageId> invoke_after_message_ids;
bool use_quick_ack; bool use_quick_ack;
}; };

View File

@ -6,6 +6,8 @@
// //
#pragma once #pragma once
#include "td/mtproto/MessageId.h"
#include "td/utils/Random.h" #include "td/utils/Random.h"
#include "td/utils/StorerBase.h" #include "td/utils/StorerBase.h"
@ -14,7 +16,7 @@ namespace mtproto {
class NoCryptoImpl { class NoCryptoImpl {
public: public:
NoCryptoImpl(uint64 message_id, const Storer &data, bool need_pad = true) : message_id_(message_id), data_(data) { NoCryptoImpl(MessageId message_id, const Storer &data, bool need_pad = true) : message_id_(message_id), data_(data) {
if (need_pad) { if (need_pad) {
size_t pad_size = -static_cast<int>(data_.size()) & 15; size_t pad_size = -static_cast<int>(data_.size()) & 15;
pad_size += 16 * (static_cast<size_t>(Random::secure_int32()) % 16); pad_size += 16 * (static_cast<size_t>(Random::secure_int32()) % 16);
@ -25,14 +27,14 @@ class NoCryptoImpl {
template <class StorerT> template <class StorerT>
void do_store(StorerT &storer) const { void do_store(StorerT &storer) const {
storer.store_binary(message_id_); storer.store_binary(message_id_.get());
storer.store_binary(static_cast<int32>(data_.size() + pad_.size())); storer.store_binary(static_cast<int32>(data_.size() + pad_.size()));
storer.store_storer(data_); storer.store_storer(data_);
storer.store_slice(pad_); storer.store_slice(pad_);
} }
private: private:
uint64 message_id_; MessageId message_id_;
const Storer &data_; const Storer &data_;
std::string pad_; std::string pad_;
}; };

View File

@ -6,6 +6,8 @@
// //
#pragma once #pragma once
#include "td/mtproto/MessageId.h"
#include "td/utils/common.h" #include "td/utils/common.h"
namespace td { namespace td {
@ -18,7 +20,7 @@ struct PacketInfo {
uint64 salt{0}; uint64 salt{0};
uint64 session_id{0}; uint64 session_id{0};
uint64 message_id{0}; MessageId message_id;
int32 seq_no{0}; int32 seq_no{0};
int32 version{1}; int32 version{1};
bool no_crypto_flag{false}; bool no_crypto_flag{false};

View File

@ -8,6 +8,7 @@
#include "td/mtproto/AuthData.h" #include "td/mtproto/AuthData.h"
#include "td/mtproto/AuthKey.h" #include "td/mtproto/AuthKey.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/mtproto_api.h" #include "td/mtproto/mtproto_api.h"
#include "td/mtproto/NoCryptoStorer.h" #include "td/mtproto/NoCryptoStorer.h"
#include "td/mtproto/PacketInfo.h" #include "td/mtproto/PacketInfo.h"
@ -47,7 +48,8 @@ class PingConnectionReqPQ final
if (!was_ping_) { if (!was_ping_) {
UInt128 nonce; UInt128 nonce;
Random::secure_bytes(nonce.raw, sizeof(nonce)); Random::secure_bytes(nonce.raw, sizeof(nonce));
raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(1, create_storer(mtproto_api::req_pq_multi(nonce)))); raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(MessageId(static_cast<uint64>(1)),
create_storer(mtproto_api::req_pq_multi(nonce))));
was_ping_ = true; was_ping_ = true;
if (ping_count_ == 1) { if (ping_count_ == 1) {
start_time_ = Time::now(); start_time_ = Time::now();
@ -129,13 +131,13 @@ class PingConnectionPingPong final
void on_server_time_difference_updated(bool force) final { void on_server_time_difference_updated(bool force) final {
} }
void on_new_session_created(uint64 unique_id, uint64 first_message_id) final { void on_new_session_created(uint64 unique_id, MessageId first_message_id) final {
} }
void on_session_failed(Status status) final { void on_session_failed(Status status) final {
} }
void on_container_sent(uint64 container_message_id, vector<uint64> message_ids) final { void on_container_sent(MessageId container_message_id, vector<MessageId> message_ids) final {
} }
Status on_pong() final { Status on_pong() final {
@ -153,21 +155,22 @@ class PingConnectionPingPong final
return Status::OK(); return Status::OK();
} }
void on_message_ack(uint64 message_id) final { void on_message_ack(MessageId message_id) final {
} }
Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) final { Status on_message_result_ok(MessageId message_id, BufferSlice packet, size_t original_size) final {
LOG(ERROR) << "Unexpected message"; LOG(ERROR) << "Unexpected message";
return Status::OK(); return Status::OK();
} }
void on_message_result_error(uint64 message_id, int code, string message) final { void on_message_result_error(MessageId message_id, int code, string message) final {
} }
void on_message_failed(uint64 message_id, Status status) final { void on_message_failed(MessageId message_id, Status status) final {
} }
void on_message_info(uint64 message_id, int32 state, uint64 answer_id, int32 answer_size, int32 source) final { void on_message_info(MessageId message_id, int32 state, MessageId answer_message_id, int32 answer_size,
int32 source) final {
} }
Status on_destroy_auth_key() final { Status on_destroy_auth_key() final {

View File

@ -86,7 +86,7 @@ class RawConnectionDefault final : public RawConnection {
return packet_size; return packet_size;
} }
uint64 send_no_crypto(const Storer &storer) final { MessageId send_no_crypto(const Storer &storer) final {
PacketInfo packet_info; PacketInfo packet_info;
packet_info.no_crypto_flag = true; packet_info.no_crypto_flag = true;
auto packet = Transport::write(storer, AuthKey(), &packet_info, transport_->max_prepend_size(), auto packet = Transport::write(storer, AuthKey(), &packet_info, transport_->max_prepend_size(),
@ -315,7 +315,7 @@ class RawConnectionHttp final : public RawConnection {
return packet_size; return packet_size;
} }
uint64 send_no_crypto(const Storer &storer) final { MessageId send_no_crypto(const Storer &storer) final {
PacketInfo packet_info; PacketInfo packet_info;
packet_info.no_crypto_flag = true; packet_info.no_crypto_flag = true;
auto packet = Transport::write(storer, AuthKey(), &packet_info); auto packet = Transport::write(storer, AuthKey(), &packet_info);

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "td/mtproto/ConnectionManager.h" #include "td/mtproto/ConnectionManager.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/PacketInfo.h" #include "td/mtproto/PacketInfo.h"
#include "td/mtproto/TransportType.h" #include "td/mtproto/TransportType.h"
@ -50,7 +51,7 @@ class RawConnection {
virtual TransportType get_transport_type() const = 0; virtual TransportType get_transport_type() const = 0;
virtual size_t send_crypto(const Storer &storer, uint64 session_id, int64 salt, const AuthKey &auth_key, virtual size_t send_crypto(const Storer &storer, uint64 session_id, int64 salt, const AuthKey &auth_key,
uint64 quick_ack_token) = 0; uint64 quick_ack_token) = 0;
virtual uint64 send_no_crypto(const Storer &storer) = 0; virtual MessageId send_no_crypto(const Storer &storer) = 0;
virtual PollableFdInfo &get_poll_info() = 0; virtual PollableFdInfo &get_poll_info() = 0;
virtual StatsCallback *stats_callback() = 0; virtual StatsCallback *stats_callback() = 0;
@ -63,7 +64,7 @@ class RawConnection {
virtual ~Callback() = default; virtual ~Callback() = default;
virtual Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) = 0; virtual Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) = 0;
virtual Status on_quick_ack(uint64 quick_ack_token) { virtual Status on_quick_ack(uint64 quick_ack_token) {
return Status::Error("Quick acknowledgements are unsupported by the callback"); return Status::Error("Quick acknowledgements aren't supported by the callback");
} }
virtual Status before_write() { virtual Status before_write() {
return Status::OK(); return Status::OK();

View File

@ -172,7 +172,7 @@ namespace mtproto {
*/ */
inline StringBuilder &operator<<(StringBuilder &string_builder, const SessionConnection::MsgInfo &info) { inline StringBuilder &operator<<(StringBuilder &string_builder, const SessionConnection::MsgInfo &info) {
return string_builder << "[msg_id:" << format::as_hex(info.message_id) << "][seq_no:" << info.seq_no << ']'; return string_builder << "with " << info.message_id << " and seq_no " << info.seq_no;
} }
unique_ptr<RawConnection> SessionConnection::move_as_raw_connection() { unique_ptr<RawConnection> SessionConnection::move_as_raw_connection() {
@ -190,7 +190,7 @@ Status SessionConnection::parse_message(TlParser &parser, MsgInfo *info, Slice *
if (parser.get_error() != nullptr) { if (parser.get_error() != nullptr) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: " << parser.get_error()); return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: " << parser.get_error());
} }
info->message_id = parser.fetch_long_unsafe(); info->message_id = MessageId(static_cast<uint64>(parser.fetch_long_unsafe()));
if (crypto_flag) { if (crypto_flag) {
info->seq_no = parser.fetch_int_unsafe(); info->seq_no = parser.fetch_int_unsafe();
} }
@ -223,22 +223,22 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet)
if (parser.get_error()) { if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error()); return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error());
} }
VLOG(mtproto) << "Receive container " << format::as_hex(container_message_id_) << " of size " << size; VLOG(mtproto) << "Receive container " << container_message_id_ << " of size " << size;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
TRY_STATUS(parse_packet(parser)); TRY_STATUS(parse_packet(parser));
} }
return Status::OK(); return Status::OK();
} }
void SessionConnection::reset_server_time_difference(uint64 message_id) { void SessionConnection::reset_server_time_difference(MessageId message_id) {
VLOG(mtproto) << "Reset server time difference"; VLOG(mtproto) << "Reset server time difference";
auth_data_->reset_server_time_difference(static_cast<uint32>(message_id >> 32) - Time::now()); auth_data_->reset_server_time_difference(static_cast<uint32>(message_id.get() >> 32) - Time::now());
callback_->on_server_time_difference_updated(true); callback_->on_server_time_difference_updated(true);
} }
Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) { Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) {
TlParser parser(packet); TlParser parser(packet);
uint64 req_msg_id = parser.fetch_long(); uint64 req_msg_id = static_cast<uint64>(parser.fetch_long());
if (parser.get_error()) { if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error()); return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error());
} }
@ -246,9 +246,9 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
LOG(ERROR) << "Receive an update in rpc_result " << info; LOG(ERROR) << "Receive an update in rpc_result " << info;
return Status::Error("Receive an update in rpc_result"); return Status::Error("Receive an update in rpc_result");
} }
VLOG(mtproto) << "Receive result for request " << format::as_hex(req_msg_id) << " with " << info; VLOG(mtproto) << "Receive result for request with " << MessageId(req_msg_id) << ' ' << info;
if (info.message_id < req_msg_id - (static_cast<uint64>(15) << 32)) { if (info.message_id.get() < req_msg_id - (static_cast<uint64>(15) << 32)) {
reset_server_time_difference(info.message_id); reset_server_time_difference(info.message_id);
} }
@ -258,7 +258,7 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
if (parser.get_error()) { if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_error: " << parser.get_error()); return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_error: " << parser.get_error());
} }
callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, rpc_error.error_message_.str()); callback_->on_message_result_error(MessageId(req_msg_id), rpc_error.error_code_, rpc_error.error_message_.str());
return Status::OK(); return Status::OK();
} }
case mtproto_api::gzip_packed::ID: { case mtproto_api::gzip_packed::ID: {
@ -269,11 +269,11 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
// yep, gzip in rpc_result // yep, gzip in rpc_result
BufferSlice object = gzdecode(gzip.packed_data_); BufferSlice object = gzdecode(gzip.packed_data_);
// send header no more optimization // send header no more optimization
return callback_->on_message_result_ok(req_msg_id, std::move(object), info.size); return callback_->on_message_result_ok(MessageId(req_msg_id), std::move(object), info.size);
} }
default: default:
packet.remove_prefix(sizeof(req_msg_id)); packet.remove_prefix(sizeof(req_msg_id));
return callback_->on_message_result_ok(req_msg_id, as_buffer_slice(packet), info.size); return callback_->on_message_result_ok(MessageId(req_msg_id), as_buffer_slice(packet), info.size);
} }
} }
@ -284,17 +284,17 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) {
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) {
VLOG(mtproto) << "Receive destroy_auth_key_ok with " << info; VLOG(mtproto) << "Receive destroy_auth_key_ok " << info;
return on_destroy_auth_key(destroy_auth_key); return on_destroy_auth_key(destroy_auth_key);
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) {
VLOG(mtproto) << "Receive destroy_auth_key_none with " << info; VLOG(mtproto) << "Receive destroy_auth_key_none " << info;
return on_destroy_auth_key(destroy_auth_key); return on_destroy_auth_key(destroy_auth_key);
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) {
VLOG(mtproto) << "Receive destroy_auth_key_fail with " << info; VLOG(mtproto) << "Receive destroy_auth_key_fail " << info;
return on_destroy_auth_key(destroy_auth_key); return on_destroy_auth_key(destroy_auth_key);
} }
@ -304,14 +304,14 @@ Status SessionConnection::on_destroy_auth_key(const mtproto_api::DestroyAuthKeyR
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_session_created &new_session_created) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_session_created &new_session_created) {
auto first_message_id = static_cast<uint64>(new_session_created.first_msg_id_); auto first_message_id = MessageId(static_cast<uint64>(new_session_created.first_msg_id_));
VLOG(mtproto) << "Receive new_session_created with " << info << ": [first_msg_id:" << format::as_hex(first_message_id) VLOG(mtproto) << "Receive new_session_created " << info << ": [first " << first_message_id
<< "] [unique_id:" << format::as_hex(new_session_created.unique_id_) << ']'; << "] [unique_id:" << new_session_created.unique_id_ << ']';
auto it = service_queries_.find(first_message_id); auto it = service_queries_.find(first_message_id);
if (it != service_queries_.end()) { if (it != service_queries_.end()) {
first_message_id = it->second.container_message_id_; first_message_id = it->second.container_message_id_;
LOG(INFO) << "Update first_message_id to container's " << format::as_hex(first_message_id); LOG(INFO) << "Update first_message_id to container's " << first_message_id;
} }
callback_->on_new_session_created(new_session_created.unique_id_, first_message_id); callback_->on_new_session_created(new_session_created.unique_id_, first_message_id);
@ -320,7 +320,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_
Status SessionConnection::on_packet(const MsgInfo &info, Status SessionConnection::on_packet(const MsgInfo &info,
const mtproto_api::bad_msg_notification &bad_msg_notification) { const mtproto_api::bad_msg_notification &bad_msg_notification) {
MsgInfo bad_info{static_cast<uint64>(bad_msg_notification.bad_msg_id_), bad_msg_notification.bad_msg_seqno_, 0}; MsgInfo bad_info{MessageId(static_cast<uint64>(bad_msg_notification.bad_msg_id_)),
bad_msg_notification.bad_msg_seqno_, 0};
enum Code { enum Code {
MsgIdTooLow = 16, MsgIdTooLow = 16,
MsgIdTooHigh = 17, MsgIdTooHigh = 17,
@ -383,8 +384,8 @@ Status SessionConnection::on_packet(const MsgInfo &info,
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_server_salt &bad_server_salt) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_server_salt &bad_server_salt) {
MsgInfo bad_info{static_cast<uint64>(bad_server_salt.bad_msg_id_), bad_server_salt.bad_msg_seqno_, 0}; MsgInfo bad_info{MessageId(static_cast<uint64>(bad_server_salt.bad_msg_id_)), bad_server_salt.bad_msg_seqno_, 0};
VLOG(mtproto) << "Receive bad_server_salt with " << info << ": " << bad_info; VLOG(mtproto) << "Receive bad_server_salt " << info << ": " << bad_info;
auth_data_->set_server_salt(bad_server_salt.new_server_salt_, Time::now_cached()); auth_data_->set_server_salt(bad_server_salt.new_server_salt_, Time::now_cached());
callback_->on_server_salt_updated(); callback_->on_server_salt_updated();
@ -393,8 +394,9 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_ack &msgs_ack) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_ack &msgs_ack) {
VLOG(mtproto) << "Receive msgs_ack with " << info << ": " << msgs_ack.msg_ids_; auto message_ids = transform(msgs_ack.msg_ids_, [](int64 msg_id) { return MessageId(static_cast<uint64>(msg_id)); });
for (auto message_id : msgs_ack.msg_ids_) { VLOG(mtproto) << "Receive msgs_ack " << info << ": " << message_ids;
for (auto message_id : message_ids) {
callback_->on_message_ack(message_id); callback_->on_message_ack(message_id);
} }
return Status::OK(); return Status::OK();
@ -407,8 +409,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::gzip
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::pong &pong) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::pong &pong) {
VLOG(mtproto) << "Receive pong with " << info; VLOG(mtproto) << "Receive pong " << info;
if (info.message_id < static_cast<uint64>(pong.msg_id_) - (static_cast<uint64>(15) << 32)) { if (info.message_id.get() < static_cast<uint64>(pong.msg_id_) - (static_cast<uint64>(15) << 32)) {
reset_server_time_difference(info.message_id); reset_server_time_difference(info.message_id);
} }
last_pong_at_ = Time::now_cached(); last_pong_at_ = Time::now_cached();
@ -424,7 +426,7 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::futu
} }
auto now = Time::now_cached(); auto now = Time::now_cached();
auth_data_->set_future_salts(new_salts, now); auth_data_->set_future_salts(new_salts, now);
VLOG(mtproto) << "Receive future_salts with " << info << ": is_valid = " << auth_data_->is_server_salt_valid(now) VLOG(mtproto) << "Receive future_salts " << info << ": is_valid = " << auth_data_->is_server_salt_valid(now)
<< ", has_salt = " << auth_data_->has_salt(now) << ", has_salt = " << auth_data_->has_salt(now)
<< ", need_future_salts = " << auth_data_->need_future_salts(now); << ", need_future_salts = " << auth_data_->need_future_salts(now);
callback_->on_server_salt_updated(); callback_->on_server_salt_updated();
@ -438,14 +440,14 @@ Status SessionConnection::on_msgs_state_info(const vector<int64> &msg_ids, Slice
} }
size_t i = 0; size_t i = 0;
for (auto msg_id : msg_ids) { for (auto msg_id : msg_ids) {
callback_->on_message_info(static_cast<uint64>(msg_id), info[i], 0, 0, 1); callback_->on_message_info(MessageId(static_cast<uint64>(msg_id)), info[i], MessageId(), 0, 1);
i++; i++;
} }
return Status::OK(); return Status::OK();
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_state_info &msgs_state_info) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_state_info &msgs_state_info) {
auto message_id = static_cast<uint64>(msgs_state_info.req_msg_id_); auto message_id = MessageId(static_cast<uint64>(msgs_state_info.req_msg_id_));
auto it = service_queries_.find(message_id); auto it = service_queries_.find(message_id);
if (it == service_queries_.end()) { if (it == service_queries_.end()) {
return Status::Error("Unknown msgs_state_info"); return Status::Error("Unknown msgs_state_info");
@ -456,26 +458,28 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs
if (query.type_ != ServiceQuery::GetStateInfo) { if (query.type_ != ServiceQuery::GetStateInfo) {
return Status::Error("Receive msgs_state_info in response not to GetStateInfo"); return Status::Error("Receive msgs_state_info in response not to GetStateInfo");
} }
VLOG(mtproto) << "Receive msgs_state_info with " << info; VLOG(mtproto) << "Receive msgs_state_info " << info;
return on_msgs_state_info(query.msg_ids_, msgs_state_info.info_); return on_msgs_state_info(query.msg_ids_, msgs_state_info.info_);
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_all_info &msgs_all_info) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_all_info &msgs_all_info) {
VLOG(mtproto) << "Receive msgs_all_info with " << info; VLOG(mtproto) << "Receive msgs_all_info " << info;
return on_msgs_state_info(msgs_all_info.msg_ids_, msgs_all_info.info_); return on_msgs_state_info(msgs_all_info.msg_ids_, msgs_all_info.info_);
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) {
VLOG(mtproto) << "Receive msg_detailed_info with " << info; VLOG(mtproto) << "Receive msg_detailed_info " << info;
callback_->on_message_info(msg_detailed_info.msg_id_, msg_detailed_info.status_, msg_detailed_info.answer_msg_id_, callback_->on_message_info(MessageId(static_cast<uint64>(msg_detailed_info.msg_id_)), msg_detailed_info.status_,
msg_detailed_info.bytes_, 2); MessageId(static_cast<uint64>(msg_detailed_info.answer_msg_id_)), msg_detailed_info.bytes_,
2);
return Status::OK(); return Status::OK();
} }
Status SessionConnection::on_packet(const MsgInfo &info, Status SessionConnection::on_packet(const MsgInfo &info,
const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) { const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) {
VLOG(mtproto) << "Receive msg_new_detailed_info with " << info; VLOG(mtproto) << "Receive msg_new_detailed_info " << info;
callback_->on_message_info(0, 0, msg_new_detailed_info.answer_msg_id_, msg_new_detailed_info.bytes_, 0); callback_->on_message_info(MessageId(), 0, MessageId(static_cast<uint64>(msg_new_detailed_info.answer_msg_id_)),
msg_new_detailed_info.bytes_, 0);
return Status::OK(); return Status::OK();
} }
@ -517,9 +521,8 @@ Status SessionConnection::on_slice_packet(const MsgInfo &info, Slice packet) {
auto get_update_description = [&] { auto get_update_description = [&] {
return PSTRING() << "update from " << get_name() << " with auth key " << auth_data_->get_auth_key().id() return PSTRING() << "update from " << get_name() << " with auth key " << auth_data_->get_auth_key().id()
<< " active for " << (Time::now() - created_at_) << " seconds in container " << " active for " << (Time::now() - created_at_) << " seconds in container "
<< container_message_id_ << " from session " << auth_data_->get_session_id() << " with " << info << container_message_id_ << " from session " << auth_data_->get_session_id() << ' ' << info
<< ", main_message_id = " << format::as_hex(main_message_id_) << ", main " << main_message_id_ << " and original size = " << info.size;
<< " and original size = " << info.size;
}; };
// It is an update... I hope. // It is an update... I hope.
@ -560,8 +563,8 @@ Status SessionConnection::on_main_packet(const PacketInfo &packet_info, Slice pa
} }
VLOG(raw_mtproto) << "Receive packet of size " << packet.size() << ':' << format::as_hex_dump<4>(packet); VLOG(raw_mtproto) << "Receive packet of size " << packet.size() << ':' << format::as_hex_dump<4>(packet);
VLOG(mtproto) << "Receive packet with seq_no " << packet_info.seq_no << " and msg_id " VLOG(mtproto) << "Receive packet with " << packet_info.message_id << " and seq_no " << packet_info.seq_no
<< format::as_hex(packet_info.message_id) << " of size " << packet.size(); << " of size " << packet.size();
if (packet_info.no_crypto_flag) { if (packet_info.no_crypto_flag) {
return Status::Error("Unencrypted packet"); return Status::Error("Unencrypted packet");
@ -576,7 +579,7 @@ Status SessionConnection::on_main_packet(const PacketInfo &packet_info, Slice pa
return Status::OK(); return Status::OK();
} }
void SessionConnection::on_message_failed(uint64 message_id, Status status) { void SessionConnection::on_message_failed(MessageId message_id, Status status) {
callback_->on_message_failed(message_id, std::move(status)); callback_->on_message_failed(message_id, std::move(status));
sent_destroy_auth_key_ = false; sent_destroy_auth_key_ = false;
@ -584,8 +587,8 @@ void SessionConnection::on_message_failed(uint64 message_id, Status status) {
if (message_id == last_ping_message_id_ || message_id == last_ping_container_message_id_) { if (message_id == last_ping_message_id_ || message_id == last_ping_container_message_id_) {
// restart ping immediately // restart ping immediately
last_ping_at_ = 0; last_ping_at_ = 0;
last_ping_message_id_ = 0; last_ping_message_id_ = {};
last_ping_container_message_id_ = 0; last_ping_container_message_id_ = {};
} }
auto cit = container_to_service_message_id_.find(message_id); auto cit = container_to_service_message_id_.find(message_id);
@ -599,7 +602,7 @@ void SessionConnection::on_message_failed(uint64 message_id, Status status) {
} }
} }
void SessionConnection::on_message_failed_inner(uint64 message_id) { void SessionConnection::on_message_failed_inner(MessageId message_id) {
auto it = service_queries_.find(message_id); auto it = service_queries_.find(message_id);
if (it == service_queries_.end()) { if (it == service_queries_.end()) {
return; return;
@ -610,12 +613,12 @@ void SessionConnection::on_message_failed_inner(uint64 message_id) {
switch (query.type_) { switch (query.type_) {
case ServiceQuery::ResendAnswer: case ServiceQuery::ResendAnswer:
for (auto msg_id : query.msg_ids_) { for (auto msg_id : query.msg_ids_) {
resend_answer(static_cast<uint64>(msg_id)); resend_answer(MessageId(static_cast<uint64>(msg_id)));
} }
break; break;
case ServiceQuery::GetStateInfo: case ServiceQuery::GetStateInfo:
for (auto msg_id : query.msg_ids_) { for (auto msg_id : query.msg_ids_) {
get_state_info(static_cast<uint64>(msg_id)); get_state_info(MessageId(static_cast<uint64>(msg_id)));
} }
break; break;
default: default:
@ -726,7 +729,7 @@ Status SessionConnection::on_raw_packet(const PacketInfo &packet_info, BufferSli
} }
Status SessionConnection::on_quick_ack(uint64 quick_ack_token) { Status SessionConnection::on_quick_ack(uint64 quick_ack_token) {
callback_->on_message_ack(quick_ack_token); callback_->on_message_ack(MessageId(quick_ack_token));
return Status::OK(); return Status::OK();
} }
@ -773,8 +776,8 @@ void SessionConnection::set_online(bool online_flag, bool is_main) {
last_read_at_ = now; last_read_at_ = now;
} }
last_ping_at_ = 0; last_ping_at_ = 0;
last_ping_message_id_ = 0; last_ping_message_id_ = {};
last_ping_container_message_id_ = 0; last_ping_container_message_id_ = {};
} }
void SessionConnection::do_close(Status status) { void SessionConnection::do_close(Status status) {
@ -790,10 +793,10 @@ void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token
auth_data_->get_auth_key(), quick_ack_token); auth_data_->get_auth_key(), quick_ack_token);
} }
Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id, Result<MessageId> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, MessageId message_id,
vector<uint64> invoke_after_message_ids, bool use_quick_ack) { vector<MessageId> invoke_after_message_ids, bool use_quick_ack) {
CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait" CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait"
if (message_id == 0) { if (message_id == MessageId()) {
message_id = auth_data_->next_message_id(Time::now_cached()); message_id = auth_data_->next_message_id(Time::now_cached());
} }
auto seq_no = auth_data_->next_seq_no(true); auto seq_no = auth_data_->next_seq_no(true);
@ -802,28 +805,28 @@ Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag,
} }
to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, std::move(invoke_after_message_ids), to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, std::move(invoke_after_message_ids),
use_quick_ack}); use_quick_ack});
VLOG(mtproto) << "Invoke query with msg_id " << format::as_hex(message_id) << " and seq_no " << seq_no << " of size " VLOG(mtproto) << "Invoke query with " << message_id << " and seq_no " << seq_no << " of size "
<< to_send_.back().packet.size() << " after " << invoke_after_message_ids << to_send_.back().packet.size() << " after " << invoke_after_message_ids
<< (use_quick_ack ? " with quick ack" : ""); << (use_quick_ack ? " with quick ack" : "");
return message_id; return message_id;
} }
void SessionConnection::get_state_info(uint64 message_id) { void SessionConnection::get_state_info(MessageId message_id) {
if (to_get_state_info_message_ids_.empty()) { if (to_get_state_info_message_ids_.empty()) {
send_before(Time::now_cached()); send_before(Time::now_cached());
} }
to_get_state_info_message_ids_.push_back(message_id); to_get_state_info_message_ids_.push_back(message_id);
} }
void SessionConnection::resend_answer(uint64 message_id) { void SessionConnection::resend_answer(MessageId message_id) {
if (to_resend_answer_message_ids_.empty()) { if (to_resend_answer_message_ids_.empty()) {
send_before(Time::now_cached() + RESEND_ANSWER_DELAY); send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
} }
to_resend_answer_message_ids_.push_back(message_id); to_resend_answer_message_ids_.push_back(message_id);
} }
void SessionConnection::cancel_answer(uint64 message_id) { void SessionConnection::cancel_answer(MessageId message_id) {
if (to_cancel_answer_message_ids_.empty()) { if (to_cancel_answer_message_ids_.empty()) {
send_before(Time::now_cached() + RESEND_ANSWER_DELAY); send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
} }
@ -835,7 +838,7 @@ void SessionConnection::destroy_key() {
need_destroy_auth_key_ = true; need_destroy_auth_key_ = true;
} }
std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at) { std::pair<MessageId, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at) {
int64 temp_key = auth_data_->get_tmp_auth_key().id(); int64 temp_key = auth_data_->get_tmp_auth_key().id();
mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->get_session_id(), expires_at); mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->get_session_id(), expires_at);
@ -865,8 +868,8 @@ void SessionConnection::force_ack() {
} }
} }
void SessionConnection::send_ack(uint64 message_id) { void SessionConnection::send_ack(MessageId message_id) {
VLOG(mtproto) << "Send ack: [msg_id:" << format::as_hex(message_id) << "]"; VLOG(mtproto) << "Send ack for " << message_id;
if (to_ack_message_ids_.empty()) { if (to_ack_message_ids_.empty()) {
send_before(Time::now_cached() + ACK_DELAY); send_before(Time::now_cached() + ACK_DELAY);
} }
@ -881,7 +884,7 @@ void SessionConnection::send_ack(uint64 message_id) {
} }
} }
// don't send ping in poll mode. // don't send ping in poll mode
bool SessionConnection::may_ping() const { bool SessionConnection::may_ping() const {
return last_ping_at_ == 0 || (mode_ != Mode::HttpLongPoll && last_ping_at_ + ping_may_delay() < Time::now_cached()); return last_ping_at_ == 0 || (mode_ != Mode::HttpLongPoll && last_ping_at_ + ping_may_delay() < Time::now_cached());
} }
@ -893,7 +896,7 @@ bool SessionConnection::must_ping() const {
void SessionConnection::flush_packet() { void SessionConnection::flush_packet() {
bool has_salt = auth_data_->has_salt(Time::now_cached()); bool has_salt = auth_data_->has_salt(Time::now_cached());
// ping // ping
uint64 container_message_id = 0; MessageId container_message_id;
int64 ping_id = 0; int64 ping_id = 0;
if (has_salt && may_ping()) { if (has_salt && may_ping()) {
ping_id = ++cur_ping_id_; ping_id = ++cur_ping_id_;
@ -963,53 +966,54 @@ void SessionConnection::flush_packet() {
<< tag("cancel", to_cancel_answer_message_ids_.size()) << tag("destroy_key", destroy_auth_key) << tag("cancel", to_cancel_answer_message_ids_.size()) << tag("destroy_key", destroy_auth_key)
<< tag("auth_key_id", auth_data_->get_auth_key().id()); << tag("auth_key_id", auth_data_->get_auth_key().id());
auto cut_tail = [](vector<uint64> &v, size_t size, Slice name) { auto cut_tail = [](vector<MessageId> &message_ids, size_t size, Slice name) {
if (size >= v.size()) { if (size >= message_ids.size()) {
auto result = transform(v, [](uint64 x) { return static_cast<int64>(x); }); auto result = transform(message_ids, [](MessageId message_id) { return static_cast<int64>(message_id.get()); });
v.clear(); message_ids.clear();
return result; return result;
} }
LOG(WARNING) << "Too many message identifiers in container " << name << ": " << v.size() << " instead of " << size; LOG(WARNING) << "Too many message identifiers in container " << name << ": " << message_ids.size() << " instead of "
auto new_size = v.size() - size; << size;
auto new_size = message_ids.size() - size;
vector<int64> result(size); vector<int64> result(size);
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
result[i] = static_cast<int64>(v[i + new_size]); result[i] = static_cast<int64>(message_ids[i + new_size].get());
} }
v.resize(new_size); message_ids.resize(new_size);
return result; return result;
}; };
// no more than 8192 message identifiers per container.. // no more than 8192 message identifiers per container..
auto to_resend_answer = cut_tail(to_resend_answer_message_ids_, 8192, "resend_answer"); auto to_resend_answer = cut_tail(to_resend_answer_message_ids_, 8192, "resend_answer");
uint64 resend_answer_message_id = 0; MessageId resend_answer_message_id;
CHECK(queries.size() <= 1020); CHECK(queries.size() <= 1020);
auto to_cancel_answer = cut_tail(to_cancel_answer_message_ids_, 1020 - queries.size(), "cancel_answer"); auto to_cancel_answer = cut_tail(to_cancel_answer_message_ids_, 1020 - queries.size(), "cancel_answer");
auto to_get_state_info = cut_tail(to_get_state_info_message_ids_, 8192, "get_state_info"); auto to_get_state_info = cut_tail(to_get_state_info_message_ids_, 8192, "get_state_info");
uint64 get_state_info_message_id = 0; MessageId get_state_info_message_id;
auto to_ack = cut_tail(to_ack_message_ids_, 8192, "ack"); auto to_ack = cut_tail(to_ack_message_ids_, 8192, "ack");
uint64 ping_message_id = 0; MessageId ping_message_id;
bool use_quick_ack = bool use_quick_ack =
std::any_of(queries.begin(), queries.end(), [](const auto &query) { return query.use_quick_ack; }); std::any_of(queries.begin(), queries.end(), [](const auto &query) { return query.use_quick_ack; });
{ {
// LOG(ERROR) << (auth_data_->get_header().empty() ? '-' : '+'); // LOG(ERROR) << (auth_data_->get_header().empty() ? '-' : '+');
uint64 parent_message_id = 0; MessageId parent_message_id;
auto storer = PacketStorer<CryptoImpl>( auto storer = PacketStorer<CryptoImpl>(
queries, auth_data_->get_header(), std::move(to_ack), ping_id, static_cast<int>(ping_disconnect_delay() + 2.0), queries, auth_data_->get_header(), std::move(to_ack), ping_id, static_cast<int>(ping_disconnect_delay() + 2.0),
max_delay, max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, max_delay, max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer,
destroy_auth_key, auth_data_, &container_message_id, &get_state_info_message_id, &resend_answer_message_id, destroy_auth_key, auth_data_, &container_message_id, &get_state_info_message_id, &resend_answer_message_id,
&ping_message_id, &parent_message_id); &ping_message_id, &parent_message_id);
auto quick_ack_token = use_quick_ack ? parent_message_id : 0; auto quick_ack_token = use_quick_ack ? parent_message_id.get() : 0;
send_crypto(storer, quick_ack_token); send_crypto(storer, quick_ack_token);
} }
if (resend_answer_message_id) { if (resend_answer_message_id != MessageId()) {
service_queries_.emplace(resend_answer_message_id, ServiceQuery{ServiceQuery::ResendAnswer, container_message_id, service_queries_.emplace(resend_answer_message_id, ServiceQuery{ServiceQuery::ResendAnswer, container_message_id,
std::move(to_resend_answer)}); std::move(to_resend_answer)});
} }
if (get_state_info_message_id) { if (get_state_info_message_id != MessageId()) {
service_queries_.emplace(get_state_info_message_id, ServiceQuery{ServiceQuery::GetStateInfo, container_message_id, service_queries_.emplace(get_state_info_message_id, ServiceQuery{ServiceQuery::GetStateInfo, container_message_id,
std::move(to_get_state_info)}); std::move(to_get_state_info)});
} }
@ -1018,8 +1022,8 @@ void SessionConnection::flush_packet() {
last_ping_message_id_ = ping_message_id; last_ping_message_id_ = ping_message_id;
} }
if (container_message_id != 0) { if (container_message_id != MessageId()) {
auto message_ids = transform(queries, [](const MtprotoQuery &x) { return static_cast<uint64>(x.message_id); }); auto message_ids = transform(queries, [](const MtprotoQuery &x) { return x.message_id; });
// some acks may be lost here. Nobody will resend them if something goes wrong with query. // some acks may be lost here. Nobody will resend them if something goes wrong with query.
// It is mostly problem for server. We will just drop this answers in next connection // It is mostly problem for server. We will just drop this answers in next connection
@ -1028,10 +1032,10 @@ void SessionConnection::flush_packet() {
// So I will re-ask salt if have no answer in 60 second. // So I will re-ask salt if have no answer in 60 second.
callback_->on_container_sent(container_message_id, std::move(message_ids)); callback_->on_container_sent(container_message_id, std::move(message_ids));
if (resend_answer_message_id) { if (resend_answer_message_id != MessageId()) {
container_to_service_message_id_[container_message_id].push_back(resend_answer_message_id); container_to_service_message_id_[container_message_id].push_back(resend_answer_message_id);
} }
if (get_state_info_message_id) { if (get_state_info_message_id != MessageId()) {
container_to_service_message_id_[container_message_id].push_back(get_state_info_message_id); container_to_service_message_id_[container_message_id].push_back(get_state_info_message_id);
} }
} }

View File

@ -6,6 +6,7 @@
// //
#pragma once #pragma once
#include "td/mtproto/MessageId.h"
#include "td/mtproto/MtprotoQuery.h" #include "td/mtproto/MtprotoQuery.h"
#include "td/mtproto/PacketInfo.h" #include "td/mtproto/PacketInfo.h"
#include "td/mtproto/RawConnection.h" #include "td/mtproto/RawConnection.h"
@ -67,14 +68,14 @@ class SessionConnection final
unique_ptr<RawConnection> move_as_raw_connection(); unique_ptr<RawConnection> move_as_raw_connection();
// Interface // Interface
Result<uint64> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id = 0, Result<MessageId> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, MessageId message_id = {},
vector<uint64> invoke_after_message_ids = {}, vector<MessageId> invoke_after_message_ids = {},
bool use_quick_ack = false); bool use_quick_ack = false);
std::pair<uint64, BufferSlice> encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at); std::pair<MessageId, BufferSlice> encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at);
void get_state_info(uint64 message_id); void get_state_info(MessageId message_id);
void resend_answer(uint64 message_id); void resend_answer(MessageId message_id);
void cancel_answer(uint64 message_id); void cancel_answer(MessageId message_id);
void destroy_key(); void destroy_key();
void set_online(bool online_flag, bool is_main); void set_online(bool online_flag, bool is_main);
@ -95,19 +96,20 @@ class SessionConnection final
virtual void on_server_salt_updated() = 0; virtual void on_server_salt_updated() = 0;
virtual void on_server_time_difference_updated(bool force) = 0; virtual void on_server_time_difference_updated(bool force) = 0;
virtual void on_new_session_created(uint64 unique_id, uint64 first_message_id) = 0; virtual void on_new_session_created(uint64 unique_id, MessageId first_message_id) = 0;
virtual void on_session_failed(Status status) = 0; virtual void on_session_failed(Status status) = 0;
virtual void on_container_sent(uint64 container_message_id, vector<uint64> message_ids) = 0; virtual void on_container_sent(MessageId container_message_id, vector<MessageId> message_ids) = 0;
virtual Status on_pong() = 0; virtual Status on_pong() = 0;
virtual Status on_update(BufferSlice packet) = 0; virtual Status on_update(BufferSlice packet) = 0;
virtual void on_message_ack(uint64 message_id) = 0; virtual void on_message_ack(MessageId message_id) = 0;
virtual Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) = 0; virtual Status on_message_result_ok(MessageId message_id, BufferSlice packet, size_t original_size) = 0;
virtual void on_message_result_error(uint64 message_id, int code, string message) = 0; virtual void on_message_result_error(MessageId message_id, int code, string message) = 0;
virtual void on_message_failed(uint64 message_id, Status status) = 0; virtual void on_message_failed(MessageId message_id, Status status) = 0;
virtual void on_message_info(uint64 message_id, int32 state, uint64 answer_id, int32 answer_size, int32 source) = 0; virtual void on_message_info(MessageId message_id, int32 state, MessageId answer_message_id, int32 answer_size,
int32 source) = 0;
virtual Status on_destroy_auth_key() = 0; virtual Status on_destroy_auth_key() = 0;
}; };
@ -123,7 +125,7 @@ class SessionConnection final
static constexpr double RESEND_ANSWER_DELAY = 0.001; // 0.001s static constexpr double RESEND_ANSWER_DELAY = 0.001; // 0.001s
struct MsgInfo { struct MsgInfo {
uint64 message_id; MessageId message_id;
int32 seq_no; int32 seq_no;
size_t size; size_t size;
}; };
@ -161,21 +163,21 @@ class SessionConnection final
static constexpr int HTTP_MAX_DELAY = 30; // 0.03s static constexpr int HTTP_MAX_DELAY = 30; // 0.03s
vector<MtprotoQuery> to_send_; vector<MtprotoQuery> to_send_;
vector<uint64> to_ack_message_ids_; vector<MessageId> to_ack_message_ids_;
double force_send_at_ = 0; double force_send_at_ = 0;
struct ServiceQuery { struct ServiceQuery {
enum Type { GetStateInfo, ResendAnswer } type_; enum Type { GetStateInfo, ResendAnswer } type_;
uint64 container_message_id_; MessageId container_message_id_;
vector<int64> msg_ids_; vector<int64> msg_ids_;
}; };
vector<uint64> to_resend_answer_message_ids_; vector<MessageId> to_resend_answer_message_ids_;
vector<uint64> to_cancel_answer_message_ids_; vector<MessageId> to_cancel_answer_message_ids_;
vector<uint64> to_get_state_info_message_ids_; vector<MessageId> to_get_state_info_message_ids_;
FlatHashMap<uint64, ServiceQuery> service_queries_; FlatHashMap<MessageId, ServiceQuery, MessageIdHash> service_queries_;
// nobody cleans up this map. But it should be really small. // nobody cleans up this map. But it should be really small.
FlatHashMap<uint64, vector<uint64>> container_to_service_message_id_; FlatHashMap<MessageId, vector<MessageId>, MessageIdHash> container_to_service_message_id_;
double random_delay_ = 0; double random_delay_ = 0;
double last_read_at_ = 0; double last_read_at_ = 0;
@ -183,9 +185,9 @@ class SessionConnection final
double last_pong_at_ = 0; double last_pong_at_ = 0;
double real_last_read_at_ = 0; double real_last_read_at_ = 0;
double real_last_pong_at_ = 0; double real_last_pong_at_ = 0;
uint64 cur_ping_id_ = 0; int64 cur_ping_id_ = 0;
uint64 last_ping_message_id_ = 0; MessageId last_ping_message_id_;
uint64 last_ping_container_message_id_ = 0; MessageId last_ping_container_message_id_;
uint64 last_read_size_ = 0; uint64 last_read_size_ = 0;
uint64 last_write_size_ = 0; uint64 last_write_size_ = 0;
@ -200,8 +202,8 @@ class SessionConnection final
Mode mode_; Mode mode_;
bool connected_flag_ = false; bool connected_flag_ = false;
uint64 container_message_id_ = 0; MessageId container_message_id_;
uint64 main_message_id_ = 0; MessageId main_message_id_;
double created_at_ = 0; double created_at_ = 0;
unique_ptr<RawConnection> raw_connection_; unique_ptr<RawConnection> raw_connection_;
@ -218,7 +220,7 @@ class SessionConnection final
}; };
} }
void reset_server_time_difference(uint64 message_id); void reset_server_time_difference(MessageId message_id);
static Status parse_message(TlParser &parser, MsgInfo *info, Slice *packet, static Status parse_message(TlParser &parser, MsgInfo *info, Slice *packet,
bool crypto_flag = true) TD_WARN_UNUSED_RESULT; bool crypto_flag = true) TD_WARN_UNUSED_RESULT;
@ -254,12 +256,12 @@ class SessionConnection final
Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT; Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT;
Status on_main_packet(const PacketInfo &packet_info, Slice packet) TD_WARN_UNUSED_RESULT; Status on_main_packet(const PacketInfo &packet_info, Slice packet) TD_WARN_UNUSED_RESULT;
void on_message_failed(uint64 message_id, Status status); void on_message_failed(MessageId message_id, Status status);
void on_message_failed_inner(uint64 message_id); void on_message_failed_inner(MessageId message_id);
void do_close(Status status); void do_close(Status status);
void send_ack(uint64 message_id); void send_ack(MessageId message_id);
void send_crypto(const Storer &storer, uint64 quick_ack_token); void send_crypto(const Storer &storer, uint64 quick_ack_token);
void send_before(double tm); void send_before(double tm);
bool may_ping() const; bool may_ping() const;

View File

@ -8,6 +8,7 @@
#include "td/mtproto/AuthKey.h" #include "td/mtproto/AuthKey.h"
#include "td/mtproto/KDF.h" #include "td/mtproto/KDF.h"
#include "td/mtproto/MessageId.h"
#include "td/utils/as.h" #include "td/utils/as.h"
#include "td/utils/crypto.h" #include "td/utils/crypto.h"
@ -42,7 +43,7 @@ struct CryptoHeader {
// It is weird to generate message_id and seq_no while writing a packet. // It is weird to generate message_id and seq_no while writing a packet.
// //
// uint64 message_id; // uint64 msg_id;
// uint32 seq_no; // uint32 seq_no;
// uint32 message_data_length; // uint32 message_data_length;
uint8 data[0]; // use compiler extension uint8 data[0]; // use compiler extension
@ -68,7 +69,7 @@ struct CryptoHeader {
}; };
struct CryptoPrefix { struct CryptoPrefix {
uint64 message_id; uint64 msg_id;
uint32 seq_no; uint32 seq_no;
uint32 message_data_length; uint32 message_data_length;
}; };
@ -108,9 +109,9 @@ struct EndToEndPrefix {
struct NoCryptoHeader { struct NoCryptoHeader {
uint64 auth_key_id; uint64 auth_key_id;
// message_id is removed from CryptoHeader. Should be removed from here too. // msg_id is removed from CryptoHeader. Should be removed from here too.
// //
// uint64 message_id; // uint64 msg_id;
// uint32 message_data_length; // uint32 message_data_length;
uint8 data[0]; // use compiler extension uint8 data[0]; // use compiler extension
@ -309,7 +310,7 @@ Status Transport::read_crypto(MutableSlice message, const AuthKey &auth_key, Pac
packet_info->type = PacketInfo::Common; packet_info->type = PacketInfo::Common;
packet_info->salt = header->salt; packet_info->salt = header->salt;
packet_info->session_id = header->session_id; packet_info->session_id = header->session_id;
packet_info->message_id = prefix->message_id; packet_info->message_id = MessageId(prefix->msg_id);
packet_info->seq_no = prefix->seq_no; packet_info->seq_no = prefix->seq_no;
return Status::OK(); return Status::OK();
} }

View File

@ -534,7 +534,7 @@ void Session::hangup() {
} }
void Session::raw_event(const Event::Raw &event) { void Session::raw_event(const Event::Raw &event) {
auto message_id = event.u64; auto message_id = mtproto::MessageId(event.u64);
auto it = sent_queries_.find(message_id); auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) { if (it == sent_queries_.end()) {
return; return;
@ -552,7 +552,7 @@ void Session::raw_event(const Event::Raw &event) {
if (main_connection_.state_ == ConnectionInfo::State::Ready) { if (main_connection_.state_ == ConnectionInfo::State::Ready) {
main_connection_.connection_->cancel_answer(message_id); main_connection_.connection_->cancel_answer(message_id);
} else { } else {
to_cancel_.push_back(message_id); to_cancel_message_ids_.push_back(message_id);
} }
loop(); loop();
} }
@ -697,8 +697,8 @@ void Session::on_closed(Status status) {
current_info_->state_ = ConnectionInfo::State::Empty; current_info_->state_ = ConnectionInfo::State::Empty;
} }
void Session::on_new_session_created(uint64 unique_id, uint64 first_message_id) { void Session::on_new_session_created(uint64 unique_id, mtproto::MessageId first_message_id) {
LOG(INFO) << "New session " << unique_id << " created with first message_id " << format::as_hex(first_message_id); LOG(INFO) << "New session " << unique_id << " created with first " << first_message_id;
if (!use_pfs_ && !auth_data_.use_pfs()) { if (!use_pfs_ && !auth_data_.use_pfs()) {
last_success_timestamp_ = Time::now(); last_success_timestamp_ = Time::now();
} }
@ -712,9 +712,9 @@ void Session::on_new_session_created(uint64 unique_id, uint64 first_message_id)
auto first_query_it = sent_queries_.find(first_message_id); auto first_query_it = sent_queries_.find(first_message_id);
if (first_query_it != sent_queries_.end()) { if (first_query_it != sent_queries_.end()) {
first_message_id = first_query_it->second.container_message_id_; first_message_id = first_query_it->second.container_message_id_;
LOG(INFO) << "Update first_message_id to container's " << format::as_hex(first_message_id); LOG(INFO) << "Update first message to container's " << first_message_id;
} else { } else {
LOG(INFO) << "Failed to find query " << format::as_hex(first_message_id) << " from the new session"; LOG(INFO) << "Failed to find sent " << first_message_id << " from the new session";
} }
for (auto it = sent_queries_.begin(); it != sent_queries_.end();) { for (auto it = sent_queries_.begin(); it != sent_queries_.end();) {
Query *query_ptr = &it->second; Query *query_ptr = &it->second;
@ -741,10 +741,10 @@ void Session::on_session_failed(Status status) {
callback_->on_failed(); callback_->on_failed();
} }
void Session::on_container_sent(uint64 container_message_id, vector<uint64> message_ids) { void Session::on_container_sent(mtproto::MessageId container_message_id, vector<mtproto::MessageId> message_ids) {
CHECK(container_message_id != 0); CHECK(container_message_id != mtproto::MessageId());
td::remove_if(message_ids, [&](uint64 message_id) { td::remove_if(message_ids, [&](mtproto::MessageId message_id) {
auto it = sent_queries_.find(message_id); auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) { if (it == sent_queries_.end()) {
return true; // remove return true; // remove
@ -759,11 +759,11 @@ void Session::on_container_sent(uint64 container_message_id, vector<uint64> mess
sent_containers_.emplace(container_message_id, ContainerInfo{size, std::move(message_ids)}); sent_containers_.emplace(container_message_id, ContainerInfo{size, std::move(message_ids)});
} }
void Session::on_message_ack(uint64 message_id) { void Session::on_message_ack(mtproto::MessageId message_id) {
on_message_ack_impl(message_id, 1); on_message_ack_impl(message_id, 1);
} }
void Session::on_message_ack_impl(uint64 container_message_id, int32 type) { void Session::on_message_ack_impl(mtproto::MessageId container_message_id, int32 type) {
auto cit = sent_containers_.find(container_message_id); auto cit = sent_containers_.find(container_message_id);
if (cit != sent_containers_.end()) { if (cit != sent_containers_.end()) {
auto container_info = std::move(cit->second); auto container_info = std::move(cit->second);
@ -778,7 +778,7 @@ void Session::on_message_ack_impl(uint64 container_message_id, int32 type) {
on_message_ack_impl_inner(container_message_id, type, false); on_message_ack_impl_inner(container_message_id, type, false);
} }
void Session::on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_container) { void Session::on_message_ack_impl_inner(mtproto::MessageId message_id, int32 type, bool in_container) {
auto it = sent_queries_.find(message_id); auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) { if (it == sent_queries_.end()) {
return; return;
@ -796,7 +796,7 @@ void Session::on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_c
mark_as_known(it->first, &it->second); mark_as_known(it->first, &it->second);
} }
void Session::dec_container(uint64 container_message_id, Query *query) { void Session::dec_container(mtproto::MessageId container_message_id, Query *query) {
if (query->container_message_id_ == container_message_id) { if (query->container_message_id_ == container_message_id) {
// message was sent without any container // message was sent without any container
return; return;
@ -812,7 +812,7 @@ void Session::dec_container(uint64 container_message_id, Query *query) {
} }
} }
void Session::cleanup_container(uint64 container_message_id, Query *query) { void Session::cleanup_container(mtproto::MessageId container_message_id, Query *query) {
if (query->container_message_id_ == container_message_id) { if (query->container_message_id_ == container_message_id) {
// message was sent without any container // message was sent without any container
return; return;
@ -823,7 +823,7 @@ void Session::cleanup_container(uint64 container_message_id, Query *query) {
sent_containers_.erase(query->container_message_id_); sent_containers_.erase(query->container_message_id_);
} }
void Session::mark_as_known(uint64 message_id, Query *query) { void Session::mark_as_known(mtproto::MessageId message_id, Query *query) {
{ {
auto lock = query->net_query_->lock(); auto lock = query->net_query_->lock();
query->net_query_->get_data_unsafe().unknown_state_ = false; query->net_query_->get_data_unsafe().unknown_state_ = false;
@ -839,7 +839,7 @@ void Session::mark_as_known(uint64 message_id, Query *query) {
} }
} }
void Session::mark_as_unknown(uint64 message_id, Query *query) { void Session::mark_as_unknown(mtproto::MessageId message_id, Query *query) {
{ {
auto lock = query->net_query_->lock(); auto lock = query->net_query_->lock();
query->net_query_->get_data_unsafe().unknown_state_ = true; query->net_query_->get_data_unsafe().unknown_state_ = true;
@ -849,7 +849,7 @@ void Session::mark_as_unknown(uint64 message_id, Query *query) {
} }
VLOG(net_query) << "Mark as unknown " << query->net_query_; VLOG(net_query) << "Mark as unknown " << query->net_query_;
query->is_unknown_ = true; query->is_unknown_ = true;
CHECK(message_id != 0); CHECK(message_id != mtproto::MessageId());
unknown_queries_.insert(message_id); unknown_queries_.insert(message_id);
} }
@ -866,16 +866,16 @@ Status Session::on_update(BufferSlice packet) {
return Status::OK(); return Status::OK();
} }
Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) { Status Session::on_message_result_ok(mtproto::MessageId message_id, BufferSlice packet, size_t original_size) {
last_success_timestamp_ = Time::now(); last_success_timestamp_ = Time::now();
TlParser parser(packet.as_slice()); TlParser parser(packet.as_slice());
int32 response_id = parser.fetch_int(); int32 response_tl_id = parser.fetch_int();
auto it = sent_queries_.find(message_id); auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) { if (it == sent_queries_.end()) {
LOG(DEBUG) << "Drop result to " << tag("message_id", format::as_hex(message_id)) LOG(DEBUG) << "Drop result to " << message_id << tag("original_size", original_size)
<< tag("original_size", original_size) << tag("response_id", format::as_hex(response_id)); << tag("response_tl", format::as_hex(response_tl_id));
if (original_size > 16 * 1024) { if (original_size > 16 * 1024) {
dropped_size_ += original_size; dropped_size_ += original_size;
@ -896,9 +896,9 @@ Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size
if (!parser.get_error()) { if (!parser.get_error()) {
// Steal authorization information. // Steal authorization information.
// It is a dirty hack, yep. // It is a dirty hack, yep.
if (response_id == telegram_api::auth_authorization::ID || if (response_tl_id == telegram_api::auth_authorization::ID ||
response_id == telegram_api::auth_loginTokenSuccess::ID || response_tl_id == telegram_api::auth_loginTokenSuccess::ID ||
response_id == telegram_api::auth_sentCodeSuccess::ID) { response_tl_id == telegram_api::auth_sentCodeSuccess::ID) {
if (query_ptr->net_query_->tl_constructor() != telegram_api::auth_importAuthorization::ID) { if (query_ptr->net_query_->tl_constructor() != telegram_api::auth_importAuthorization::ID) {
G()->net_query_dispatcher().set_main_dc_id(raw_dc_id_); G()->net_query_dispatcher().set_main_dc_id(raw_dc_id_);
} }
@ -918,7 +918,7 @@ Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size
return Status::OK(); return Status::OK();
} }
void Session::on_message_result_error(uint64 message_id, int error_code, string message) { void Session::on_message_result_error(mtproto::MessageId message_id, int error_code, string message) {
if (!check_utf8(message)) { if (!check_utf8(message)) {
LOG(ERROR) << "Receive invalid error message \"" << message << '"'; LOG(ERROR) << "Receive invalid error message \"" << message << '"';
message = "INVALID_UTF8_ERROR_MESSAGE"; message = "INVALID_UTF8_ERROR_MESSAGE";
@ -970,7 +970,7 @@ void Session::on_message_result_error(uint64 message_id, int error_code, string
error_code = 500; error_code = 500;
} }
if (message_id == 0) { if (message_id == mtproto::MessageId()) {
LOG(ERROR) << "Receive an error without message_id"; LOG(ERROR) << "Receive an error without message_id";
return; return;
} }
@ -998,8 +998,8 @@ void Session::on_message_result_error(uint64 message_id, int error_code, string
sent_queries_.erase(it); sent_queries_.erase(it);
} }
void Session::on_message_failed_inner(uint64 message_id, bool in_container) { void Session::on_message_failed_inner(mtproto::MessageId message_id, bool in_container) {
LOG(INFO) << "Message inner failed " << message_id; LOG(INFO) << "Message inner failed for " << message_id;
auto it = sent_queries_.find(message_id); auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) { if (it == sent_queries_.end()) {
return; return;
@ -1016,8 +1016,8 @@ void Session::on_message_failed_inner(uint64 message_id, bool in_container) {
sent_queries_.erase(it); sent_queries_.erase(it);
} }
void Session::on_message_failed(uint64 message_id, Status status) { void Session::on_message_failed(mtproto::MessageId message_id, Status status) {
LOG(INFO) << "Message failed: " << tag("message_id", message_id) << tag("status", status); LOG(INFO) << "Failed to send " << message_id << ": " << status;
status.ignore(); status.ignore();
auto cit = sent_containers_.find(message_id); auto cit = sent_containers_.find(message_id);
@ -1034,8 +1034,8 @@ void Session::on_message_failed(uint64 message_id, Status status) {
on_message_failed_inner(message_id, false); on_message_failed_inner(message_id, false);
} }
void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_message_id, int32 answer_size, void Session::on_message_info(mtproto::MessageId message_id, int32 state, mtproto::MessageId answer_message_id,
int32 source) { int32 answer_size, int32 source) {
auto it = sent_queries_.find(message_id); auto it = sent_queries_.find(message_id);
if (it != sent_queries_.end()) { if (it != sent_queries_.end()) {
if (it->second.net_query_->update_is_ready()) { if (it->second.net_query_->update_is_ready()) {
@ -1049,7 +1049,7 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess
return; return;
} }
} }
if (message_id != 0) { if (message_id != mtproto::MessageId()) {
if (it == sent_queries_.end()) { if (it == sent_queries_.end()) {
return; return;
} }
@ -1060,15 +1060,16 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess
return on_message_failed(message_id, return on_message_failed(message_id,
Status::Error("Message wasn't received by the server and must be re-sent")); Status::Error("Message wasn't received by the server and must be re-sent"));
case 0: case 0:
if (answer_message_id == 0) { if (answer_message_id == mtproto::MessageId()) {
LOG(ERROR) << "Unexpected message_info.state == 0 " << tag("message_id", message_id) << tag("state", state) LOG(ERROR) << "Unexpected message_info.state == 0 for " << message_id << ": " << tag("state", state)
<< tag("answer_message_id", answer_message_id); << tag("answer", answer_message_id);
return on_message_failed(message_id, Status::Error("Unexpected message_info.state == 0")); return on_message_failed(message_id, Status::Error("Unexpected message_info.state == 0"));
} }
// fallthrough // fallthrough
case 4: case 4:
CHECK(0 <= source && source <= 3); CHECK(0 <= source && source <= 3);
on_message_ack_impl(message_id, (answer_message_id ? 2 : 0) | (((state | source) & ((1 << 28) - 1)) << 2)); on_message_ack_impl(message_id, (answer_message_id != mtproto::MessageId() ? 2 : 0) |
(((state | source) & ((1 << 28) - 1)) << 2));
break; break;
default: default:
LOG(ERROR) << "Invalid message info " << tag("state", state); LOG(ERROR) << "Invalid message info " << tag("state", state);
@ -1076,10 +1077,10 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess
} }
// ok, we are waiting for result of message_id. let's ask to resend it // ok, we are waiting for result of message_id. let's ask to resend it
if (answer_message_id != 0) { if (answer_message_id != mtproto::MessageId()) {
if (it != sent_queries_.end()) { if (it != sent_queries_.end()) {
VLOG_IF(net_query, message_id != 0) << "Resend answer " << tag("answer_message_id", answer_message_id) VLOG_IF(net_query, message_id != mtproto::MessageId())
<< tag("answer_size", answer_size) << it->second.net_query_; << "Resend answer " << answer_message_id << ": " << tag("answer_size", answer_size) << it->second.net_query_;
it->second.net_query_->debug(PSTRING() << get_name() << ": resend answer"); it->second.net_query_->debug(PSTRING() << get_name() << ": resend answer");
} }
current_info_->connection_->resend_answer(answer_message_id); current_info_->connection_->resend_answer(answer_message_id);
@ -1114,7 +1115,7 @@ void Session::add_query(NetQueryPtr &&net_query) {
pending_queries_.push(std::move(net_query)); pending_queries_.push(std::move(net_query));
} }
void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id) { void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, mtproto::MessageId message_id) {
CHECK(info->state_ == ConnectionInfo::State::Ready); CHECK(info->state_ == ConnectionInfo::State::Ready);
current_info_ = info; current_info_ = info;
@ -1123,10 +1124,10 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
} }
Span<NetQueryRef> invoke_after = net_query->invoke_after(); Span<NetQueryRef> invoke_after = net_query->invoke_after();
vector<uint64> invoke_after_message_ids; vector<mtproto::MessageId> invoke_after_message_ids;
for (auto &ref : invoke_after) { for (auto &ref : invoke_after) {
auto invoke_after_message_id = ref->message_id(); auto invoke_after_message_id = mtproto::MessageId(ref->message_id());
if (ref->session_id() != auth_data_.get_session_id() || invoke_after_message_id == 0) { if (ref->session_id() != auth_data_.get_session_id() || invoke_after_message_id == mtproto::MessageId()) {
net_query->set_error_resend_invoke_after(); net_query->set_error_resend_invoke_after();
return return_query(std::move(net_query)); return return_query(std::move(net_query));
} }
@ -1155,30 +1156,27 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
} }
message_id = r_message_id.ok(); message_id = r_message_id.ok();
} else { } else {
if (message_id == 0) { if (message_id == mtproto::MessageId()) {
message_id = auth_data_.next_message_id(now); message_id = auth_data_.next_message_id(now);
} }
} }
net_query->set_message_id(message_id); net_query->set_message_id(message_id.get());
VLOG(net_query) << "Send query to connection " << net_query VLOG(net_query) << "Send query to connection " << net_query << tag("invoke_after", invoke_after_message_ids);
<< tag("invoke_after", transform(invoke_after_message_ids, [](auto message_id) {
return PSTRING() << format::as_hex(message_id);
}));
{ {
auto lock = net_query->lock(); auto lock = net_query->lock();
net_query->get_data_unsafe().unknown_state_ = false; net_query->get_data_unsafe().unknown_state_ = false;
net_query->get_data_unsafe().ack_state_ = 0; net_query->get_data_unsafe().ack_state_ = 0;
} }
if (!net_query->cancel_slot_.empty()) { if (!net_query->cancel_slot_.empty()) {
LOG(DEBUG) << "Set event for net_query cancellation " << tag("message_id", format::as_hex(message_id)); LOG(DEBUG) << "Set event for net_query cancellation for " << message_id;
net_query->cancel_slot_.set_event(EventCreator::raw(actor_id(), message_id)); net_query->cancel_slot_.set_event(EventCreator::raw(actor_id(), message_id.get()));
} }
auto status = auto status =
sent_queries_.emplace(message_id, Query{message_id, std::move(net_query), main_connection_.connection_id_, now}); sent_queries_.emplace(message_id, Query{message_id, std::move(net_query), main_connection_.connection_id_, now});
LOG_CHECK(status.second) << message_id; LOG_CHECK(status.second) << message_id;
sent_queries_list_.put(status.first->second.get_list_node()); sent_queries_list_.put(status.first->second.get_list_node());
if (!status.second) { if (!status.second) {
LOG(FATAL) << "Duplicate message_id [message_id = " << message_id << "]"; LOG(FATAL) << "Duplicate " << message_id;
} }
if (immediately_fail_query) { if (immediately_fail_query) {
on_message_result_error(message_id, 401, "TEST_ERROR"); on_message_result_error(message_id, 401, "TEST_ERROR");
@ -1308,10 +1306,10 @@ void Session::connection_open_finish(ConnectionInfo *info,
for (auto &message_id : unknown_queries_) { for (auto &message_id : unknown_queries_) {
info->connection_->get_state_info(message_id); info->connection_->get_state_info(message_id);
} }
for (auto &message_id : to_cancel_) { for (auto &message_id : to_cancel_message_ids_) {
info->connection_->cancel_answer(message_id); info->connection_->cancel_answer(message_id);
} }
to_cancel_.clear(); to_cancel_message_ids_.clear();
} }
yield(); yield();
} }
@ -1378,7 +1376,7 @@ bool Session::connection_send_bind_key(ConnectionInfo *info) {
int64 perm_auth_key_id = auth_data_.get_main_auth_key().id(); int64 perm_auth_key_id = auth_data_.get_main_auth_key().id();
int64 nonce = Random::secure_int64(); int64 nonce = Random::secure_int64();
auto expires_at = static_cast<int32>(auth_data_.get_server_time(auth_data_.get_tmp_auth_key().expires_at())); auto expires_at = static_cast<int32>(auth_data_.get_server_time(auth_data_.get_tmp_auth_key().expires_at()));
uint64 message_id; mtproto::MessageId message_id;
BufferSlice encrypted; BufferSlice encrypted;
std::tie(message_id, encrypted) = info->connection_->encrypted_bind(perm_auth_key_id, nonce, expires_at); std::tie(message_id, encrypted) = info->connection_->encrypted_bind(perm_auth_key_id, nonce, expires_at);

View File

@ -14,6 +14,7 @@
#include "td/mtproto/AuthKey.h" #include "td/mtproto/AuthKey.h"
#include "td/mtproto/ConnectionManager.h" #include "td/mtproto/ConnectionManager.h"
#include "td/mtproto/Handshake.h" #include "td/mtproto/Handshake.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/SessionConnection.h" #include "td/mtproto/SessionConnection.h"
#include "td/actor/actor.h" #include "td/actor/actor.h"
@ -78,7 +79,7 @@ class Session final
private: private:
struct Query final : private ListNode { struct Query final : private ListNode {
uint64 container_message_id_; mtproto::MessageId container_message_id_;
NetQueryPtr net_query_; NetQueryPtr net_query_;
bool is_acknowledged_ = false; bool is_acknowledged_ = false;
@ -87,7 +88,7 @@ class Session final
const int8 connection_id_; const int8 connection_id_;
const double sent_at_; const double sent_at_;
Query(uint64 message_id, NetQueryPtr &&net_query, int8 connection_id, double sent_at) Query(mtproto::MessageId message_id, NetQueryPtr &&net_query, int8 connection_id, double sent_at)
: container_message_id_(message_id) : container_message_id_(message_id)
, net_query_(std::move(net_query)) , net_query_(std::move(net_query))
, connection_id_(connection_id) , connection_id_(connection_id)
@ -131,8 +132,8 @@ class Session final
double last_bind_success_timestamp_ = 0; // time when auth_key and Session definitely was valid and authorized double last_bind_success_timestamp_ = 0; // time when auth_key and Session definitely was valid and authorized
size_t dropped_size_ = 0; size_t dropped_size_ = 0;
FlatHashSet<uint64> unknown_queries_; FlatHashSet<mtproto::MessageId, mtproto::MessageIdHash> unknown_queries_;
vector<int64> to_cancel_; vector<mtproto::MessageId> to_cancel_message_ids_;
// Do not invalidate iterators of these two containers! // Do not invalidate iterators of these two containers!
// TODO: better data structures // TODO: better data structures
@ -145,7 +146,7 @@ class Session final
std::map<int8, VectorQueue<NetQueryPtr>, std::greater<>> queries_; std::map<int8, VectorQueue<NetQueryPtr>, std::greater<>> queries_;
}; };
PriorityQueue pending_queries_; PriorityQueue pending_queries_;
std::map<uint64, Query> sent_queries_; std::map<mtproto::MessageId, Query> sent_queries_;
std::deque<NetQueryPtr> pending_invoke_after_queries_; std::deque<NetQueryPtr> pending_invoke_after_queries_;
ListNode sent_queries_list_; ListNode sent_queries_list_;
@ -181,9 +182,9 @@ class Session final
struct ContainerInfo { struct ContainerInfo {
size_t ref_cnt; size_t ref_cnt;
vector<uint64> message_ids; vector<mtproto::MessageId> message_ids;
}; };
FlatHashMap<uint64, ContainerInfo> sent_containers_; FlatHashMap<mtproto::MessageId, ContainerInfo, mtproto::MessageIdHash> sent_containers_;
friend class GenAuthKeyActor; friend class GenAuthKeyActor;
struct HandshakeInfo { struct HandshakeInfo {
@ -214,33 +215,34 @@ class Session final
void on_server_salt_updated() final; void on_server_salt_updated() final;
void on_server_time_difference_updated(bool force) final; void on_server_time_difference_updated(bool force) final;
void on_new_session_created(uint64 unique_id, uint64 first_message_id) final; void on_new_session_created(uint64 unique_id, mtproto::MessageId first_message_id) final;
void on_session_failed(Status status) final; void on_session_failed(Status status) final;
void on_container_sent(uint64 container_message_id, vector<uint64> message_ids) final; void on_container_sent(mtproto::MessageId container_message_id, vector<mtproto::MessageId> message_ids) final;
Status on_update(BufferSlice packet) final; Status on_update(BufferSlice packet) final;
void on_message_ack(uint64 message_id) final; void on_message_ack(mtproto::MessageId message_id) final;
Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) final; Status on_message_result_ok(mtproto::MessageId message_id, BufferSlice packet, size_t original_size) final;
void on_message_result_error(uint64 message_id, int error_code, string message) final; void on_message_result_error(mtproto::MessageId message_id, int error_code, string message) final;
void on_message_failed(uint64 message_id, Status status) final; void on_message_failed(mtproto::MessageId message_id, Status status) final;
void on_message_info(uint64 message_id, int32 state, uint64 answer_message_id, int32 answer_size, int32 source) final; void on_message_info(mtproto::MessageId message_id, int32 state, mtproto::MessageId answer_message_id,
int32 answer_size, int32 source) final;
Status on_destroy_auth_key() final; Status on_destroy_auth_key() final;
void flush_pending_invoke_after_queries(); void flush_pending_invoke_after_queries();
bool has_queries() const; bool has_queries() const;
void dec_container(uint64 container_message_id, Query *query); void dec_container(mtproto::MessageId container_message_id, Query *query);
void cleanup_container(uint64 container_message_id, Query *query); void cleanup_container(mtproto::MessageId container_message_id, Query *query);
void mark_as_known(uint64 message_id, Query *query); void mark_as_known(mtproto::MessageId message_id, Query *query);
void mark_as_unknown(uint64 message_id, Query *query); void mark_as_unknown(mtproto::MessageId message_id, Query *query);
void on_message_ack_impl(uint64 container_message_id, int32 type); void on_message_ack_impl(mtproto::MessageId container_message_id, int32 type);
void on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_container); void on_message_ack_impl_inner(mtproto::MessageId message_id, int32 type, bool in_container);
void on_message_failed_inner(uint64 message_id, bool in_container); void on_message_failed_inner(mtproto::MessageId message_id, bool in_container);
// send NetQueryPtr to parent // send NetQueryPtr to parent
void return_query(NetQueryPtr &&query); void return_query(NetQueryPtr &&query);
@ -255,7 +257,7 @@ class Session final
void connection_online_update(double now, bool force); void connection_online_update(double now, bool force);
void connection_close(ConnectionInfo *info); void connection_close(ConnectionInfo *info);
void connection_flush(ConnectionInfo *info); void connection_flush(ConnectionInfo *info);
void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id = 0); void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, mtproto::MessageId message_id = {});
bool need_send_bind_key() const; bool need_send_bind_key() const;
bool need_send_query() const; bool need_send_query() const;
bool can_destroy_auth_key() const; bool can_destroy_auth_key() const;