Add strictly-typed class mtproto::MessageId.

This commit is contained in:
levlam 2023-09-21 17:52:33 +03:00
parent e47cea5904
commit b44e2ea3fc
18 changed files with 355 additions and 265 deletions

View File

@ -518,6 +518,7 @@ set(TDLIB_SOURCE
td/mtproto/HttpTransport.h
td/mtproto/IStreamTransport.h
td/mtproto/KDF.h
td/mtproto/MessageId.h
td/mtproto/MtprotoQuery.h
td/mtproto/NoCryptoStorer.h
td/mtproto/PacketInfo.h

View File

@ -416,7 +416,7 @@ class IdDuplicateCheckerOld {
static td::string get_description() {
return "Old";
}
td::Status check(td::int64 message_id) {
td::Status check(td::uint64 message_id) {
if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS) {
auto oldest_message_id = *saved_message_ids_.begin();
if (message_id < oldest_message_id) {
@ -437,7 +437,7 @@ class IdDuplicateCheckerOld {
private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<td::int64> saved_message_ids_;
std::set<td::uint64> saved_message_ids_;
};
template <size_t MAX_SAVED_MESSAGE_IDS>
@ -446,7 +446,7 @@ class IdDuplicateCheckerNew {
static td::string get_description() {
return PSTRING() << "New" << MAX_SAVED_MESSAGE_IDS;
}
td::Status check(td::int64 message_id) {
td::Status check(td::uint64 message_id) {
auto insert_result = saved_message_ids_.insert(message_id);
if (!insert_result.second) {
return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id);
@ -464,7 +464,7 @@ class IdDuplicateCheckerNew {
}
private:
std::set<td::int64> saved_message_ids_;
std::set<td::uint64> saved_message_ids_;
};
class IdDuplicateCheckerNewOther {
@ -472,7 +472,7 @@ class IdDuplicateCheckerNewOther {
static td::string get_description() {
return "NewOther";
}
td::Status check(td::int64 message_id) {
td::Status check(td::uint64 message_id) {
if (!saved_message_ids_.insert(message_id).second) {
return td::Status::Error(1, PSLICE() << "Ignore already processed message " << message_id);
}
@ -490,7 +490,7 @@ class IdDuplicateCheckerNewOther {
private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<td::int64> saved_message_ids_;
std::set<td::uint64> saved_message_ids_;
};
class IdDuplicateCheckerNewSimple {
@ -498,7 +498,7 @@ class IdDuplicateCheckerNewSimple {
static td::string get_description() {
return "NewSimple";
}
td::Status check(td::int64 message_id) {
td::Status check(td::uint64 message_id) {
auto insert_result = saved_message_ids_.insert(message_id);
if (!insert_result.second) {
return td::Status::Error(1, "Ignore already processed message");
@ -516,7 +516,7 @@ class IdDuplicateCheckerNewSimple {
private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<td::int64> saved_message_ids_;
std::set<td::uint64> saved_message_ids_;
};
template <size_t max_size>
@ -525,7 +525,7 @@ class IdDuplicateCheckerArray {
static td::string get_description() {
return PSTRING() << "Array" << max_size;
}
td::Status check(td::int64 message_id) {
td::Status check(td::uint64 message_id) {
if (end_pos_ == 2 * max_size) {
std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]);
end_pos_ = max_size;
@ -550,7 +550,7 @@ class IdDuplicateCheckerArray {
}
private:
std::array<td::int64, 2 * max_size> saved_message_ids_;
std::array<td::uint64, 2 * max_size> saved_message_ids_;
std::size_t end_pos_ = 0;
};

View File

@ -6,7 +6,6 @@
//
#include "td/mtproto/AuthData.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/Random.h"
#include "td/utils/SliceBuilder.h"
@ -17,7 +16,8 @@
namespace td {
namespace mtproto {
Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id) {
Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos,
MessageId message_id) {
// In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if
// a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be
// ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is
@ -32,13 +32,12 @@ Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, s
return Status::OK();
}
if (end_pos >= max_size && message_id < saved_message_ids[0]) {
return Status::Error(2, PSLICE() << "Ignore very old message " << format::as_hex(message_id)
<< " older than the oldest known message "
<< format::as_hex(saved_message_ids[0]));
return Status::Error(
2, PSLICE() << "Ignore very old " << message_id << " older than the oldest known " << saved_message_ids[0]);
}
auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id);
if (*it == message_id) {
return Status::Error(1, PSLICE() << "Ignore already processed message " << format::as_hex(message_id));
return Status::Error(1, PSLICE() << "Ignore already processed " << message_id);
}
std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]);
*it = message_id;
@ -105,7 +104,7 @@ std::vector<ServerSalt> AuthData::get_future_salts() const {
return res;
}
uint64 AuthData::next_message_id(double now) {
MessageId AuthData::next_message_id(double now) {
double server_time = get_server_time(now);
auto t = static_cast<uint64>(server_time * (static_cast<uint64>(1) << 32));
@ -113,31 +112,31 @@ uint64 AuthData::next_message_id(double now) {
// TODO(perf) do not do this for systems with good precision?..
auto rx = Random::secure_int32();
auto to_xor = rx & ((1 << 22) - 1);
auto to_mul = ((rx >> 22) & 1023) + 1;
t ^= to_xor;
auto result = t & static_cast<uint64>(-4);
auto result = MessageId(t & static_cast<uint64>(-4));
if (last_message_id_ >= result) {
result = last_message_id_ + 8 * to_mul;
auto to_mul = ((rx >> 22) & 1023) + 1;
result = MessageId(last_message_id_.get() + 8 * to_mul);
}
LOG(DEBUG) << "Create message identifier " << format::as_hex(result) << " at " << now;
LOG(DEBUG) << "Create identifier for " << result << " at " << now;
last_message_id_ = result;
return result;
}
bool AuthData::is_valid_outbound_msg_id(uint64 message_id, double now) const {
bool AuthData::is_valid_outbound_msg_id(MessageId message_id, double now) const {
double server_time = get_server_time(now);
auto id_time = static_cast<double>(message_id) / static_cast<double>(static_cast<uint64>(1) << 32);
auto id_time = static_cast<double>(message_id.get()) / static_cast<double>(static_cast<uint64>(1) << 32);
return server_time - 150 < id_time && id_time < server_time + 30;
}
bool AuthData::is_valid_inbound_msg_id(uint64 message_id, double now) const {
bool AuthData::is_valid_inbound_msg_id(MessageId message_id, double now) const {
double server_time = get_server_time(now);
auto id_time = static_cast<double>(message_id) / static_cast<double>(static_cast<uint64>(1) << 32);
auto id_time = static_cast<double>(message_id.get()) / static_cast<double>(static_cast<uint64>(1) << 32);
return server_time - 300 < id_time && id_time < server_time + 30;
}
Status AuthData::check_packet(uint64 session_id, uint64 message_id, double now, bool &time_difference_was_updated) {
Status AuthData::check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated) {
// Client is to check that the session_id field in the decrypted message indeed equals to that of an active session
// created by the client.
if (get_session_id() != session_id) {
@ -147,22 +146,21 @@ Status AuthData::check_packet(uint64 session_id, uint64 message_id, double now,
// Client must check that msg_id has even parity for messages from client to server, and odd parity for messages
// from server to client.
if ((message_id & 1) == 0) {
return Status::Error(PSLICE() << "Receive invalid message identifier " << format::as_hex(message_id));
if ((message_id.get() & 1) == 0) {
return Status::Error(PSLICE() << "Receive invalid " << message_id);
}
TRY_STATUS(duplicate_checker_.check(message_id));
LOG(DEBUG) << "Receive packet " << format::as_hex(message_id) << " from session " << format::as_hex(session_id)
<< " at " << now;
time_difference_was_updated = update_server_time_difference(static_cast<uint32>(message_id >> 32) - now);
LOG(DEBUG) << "Receive packet in " << message_id << " from session " << session_id << " at " << now;
time_difference_was_updated = update_server_time_difference(static_cast<uint32>(message_id.get() >> 32) - now);
// In addition, msg_id values that belong over 30 seconds in the future or over 300 seconds in the past are to be
// ignored (recall that msg_id approximately equals unixtime * 2^32). This is especially important for the server.
// The client would also find this useful (to protect from a replay attack), but only if it is certain of its time
// (for example, if its time has been synchronized with that of the server).
if (server_time_difference_was_updated_ && !is_valid_inbound_msg_id(message_id, now)) {
return Status::Error(PSLICE() << "Ignore too old or too new message " << format::as_hex(message_id));
return Status::Error(PSLICE() << "Ignore too old or too new " << message_id);
}
return Status::OK();

View File

@ -7,6 +7,7 @@
#pragma once
#include "td/mtproto/AuthKey.h"
#include "td/mtproto/MessageId.h"
#include "td/utils/common.h"
#include "td/utils/Slice.h"
@ -37,17 +38,18 @@ void parse(ServerSalt &salt, ParserT &parser) {
salt.valid_until = parser.fetch_double();
}
Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id);
Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos,
MessageId message_id);
template <size_t max_size>
class MessageIdDuplicateChecker {
public:
Status check(uint64 message_id) {
Status check(MessageId message_id) {
return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
}
private:
std::array<uint64, 2 * max_size> saved_message_ids_;
std::array<MessageId, 2 * max_size> saved_message_ids_;
size_t end_pos_ = 0;
};
@ -232,19 +234,19 @@ class AuthData {
std::vector<ServerSalt> get_future_salts() const;
uint64 next_message_id(double now);
MessageId next_message_id(double now);
bool is_valid_outbound_msg_id(uint64 message_id, double now) const;
bool is_valid_outbound_msg_id(MessageId message_id, double now) const;
bool is_valid_inbound_msg_id(uint64 message_id, double now) const;
bool is_valid_inbound_msg_id(MessageId message_id, double now) const;
Status check_packet(uint64 session_id, uint64 message_id, double now, bool &time_difference_was_updated);
Status check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated);
Status check_update(uint64 message_id) {
Status check_update(MessageId message_id) {
return updates_duplicate_checker_.check(message_id);
}
Status recheck_update(uint64 message_id) {
Status recheck_update(MessageId message_id) {
return updates_duplicate_rechecker_.check(message_id);
}
@ -275,7 +277,7 @@ class AuthData {
bool server_time_difference_was_updated_ = false;
double server_time_difference_ = 0;
ServerSalt server_salt_;
uint64 last_message_id_ = 0;
MessageId last_message_id_;
int32 seq_no_ = 0;
string header_;
uint64 session_id_ = 0;

View File

@ -7,6 +7,7 @@
#pragma once
#include "td/mtproto/AuthData.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/MtprotoQuery.h"
#include "td/mtproto/PacketStorer.h"
#include "td/mtproto/utils.h"
@ -57,7 +58,7 @@ class ObjectImpl {
bool empty() const {
return !not_empty_;
}
uint64 get_message_id() const {
MessageId get_message_id() const {
return message_id_;
}
@ -65,7 +66,7 @@ class ObjectImpl {
bool not_empty_;
Object object_;
ObjectStorer object_storer_;
uint64 message_id_;
MessageId message_id_;
int32 seq_no_;
};
@ -96,7 +97,7 @@ class CancelVectorImpl {
bool not_empty() const {
return !storers_.empty();
}
uint64 get_message_id() const {
MessageId get_message_id() const {
CHECK(storers_.size() == 1);
return storers_[0].get_message_id();
}
@ -107,7 +108,7 @@ class CancelVectorImpl {
class InvokeAfter {
public:
explicit InvokeAfter(Span<uint64> message_ids) : message_ids_(message_ids) {
explicit InvokeAfter(Span<MessageId> message_ids) : message_ids_(message_ids) {
}
template <class StorerT>
void store(StorerT &storer) const {
@ -116,7 +117,7 @@ class InvokeAfter {
}
if (message_ids_.size() == 1) {
storer.store_int(static_cast<int32>(0xcb9f372d));
storer.store_binary(message_ids_[0]);
storer.store_binary(message_ids_[0].get());
return;
}
// invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector<long> query:!X = X;
@ -124,12 +125,12 @@ class InvokeAfter {
storer.store_int(static_cast<int32>(0x1cb5c415));
storer.store_int(narrow_cast<int32>(message_ids_.size()));
for (auto message_id : message_ids_) {
storer.store_binary(message_id);
storer.store_binary(message_id.get());
}
}
private:
Span<uint64> message_ids_;
Span<MessageId> message_ids_;
};
class QueryImpl {
@ -206,8 +207,8 @@ class CryptoImpl {
CryptoImpl(const vector<MtprotoQuery> &to_send, Slice header, vector<int64> &&to_ack, int64 ping_id, int ping_timeout,
int max_delay, int max_after, int max_wait, int future_salt_n, vector<int64> get_info,
vector<int64> resend, const vector<int64> &cancel, bool destroy_key, AuthData *auth_data,
uint64 *container_message_id, uint64 *get_info_message_id, uint64 *resend_message_id,
uint64 *ping_message_id, uint64 *parent_message_id)
MessageId *container_message_id, MessageId *get_info_message_id, MessageId *resend_message_id,
MessageId *ping_message_id, MessageId *parent_message_id)
: query_storer_(to_send, header)
, ack_empty_(to_ack.empty())
, ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data)
@ -362,7 +363,7 @@ class CryptoImpl {
Mixed
};
Type type_;
uint64 message_id_;
MessageId message_id_;
int32 seq_no_;
};

View File

@ -8,6 +8,7 @@
#include "td/mtproto/AuthKey.h"
#include "td/mtproto/Handshake.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/NoCryptoStorer.h"
#include "td/mtproto/PacketInfo.h"
#include "td/mtproto/PacketStorer.h"
@ -61,7 +62,7 @@ class HandshakeConnection final
unique_ptr<AuthKeyHandshakeContext> context_;
void send_no_crypto(const Storer &storer) final {
raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(0, storer));
raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(MessageId(), storer));
}
Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) final {

70
td/mtproto/MessageId.h Normal file
View File

@ -0,0 +1,70 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/HashTableUtils.h"
#include "td/utils/StringBuilder.h"
#include <type_traits>
namespace td {
namespace mtproto {
class MessageId {
uint64 message_id_ = 0;
public:
MessageId() = default;
explicit constexpr MessageId(uint64 message_id) : message_id_(message_id) {
}
template <class T, typename = std::enable_if_t<std::is_convertible<T, int64>::value>>
MessageId(T message_id) = delete;
uint64 get() const {
return message_id_;
}
bool operator==(const MessageId &other) const {
return message_id_ == other.message_id_;
}
bool operator!=(const MessageId &other) const {
return message_id_ != other.message_id_;
}
friend bool operator<(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() < rhs.get();
}
friend bool operator>(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() > rhs.get();
}
friend bool operator<=(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() <= rhs.get();
}
friend bool operator>=(const MessageId &lhs, const MessageId &rhs) {
return lhs.get() >= rhs.get();
}
};
struct MessageIdHash {
uint32 operator()(MessageId message_id) const {
return Hash<uint64>()(message_id.get());
}
};
inline StringBuilder &operator<<(StringBuilder &string_builder, MessageId message_id) {
return string_builder << "message " << format::as_hex(message_id.get());
}
} // namespace mtproto
} // namespace td

View File

@ -6,6 +6,8 @@
//
#pragma once
#include "td/mtproto/MessageId.h"
#include "td/utils/buffer.h"
#include "td/utils/common.h"
@ -13,11 +15,11 @@ namespace td {
namespace mtproto {
struct MtprotoQuery {
uint64 message_id;
MessageId message_id;
int32 seq_no;
BufferSlice packet;
bool gzip_flag;
vector<uint64> invoke_after_message_ids;
vector<MessageId> invoke_after_message_ids;
bool use_quick_ack;
};

View File

@ -6,6 +6,8 @@
//
#pragma once
#include "td/mtproto/MessageId.h"
#include "td/utils/Random.h"
#include "td/utils/StorerBase.h"
@ -14,7 +16,7 @@ namespace mtproto {
class NoCryptoImpl {
public:
NoCryptoImpl(uint64 message_id, const Storer &data, bool need_pad = true) : message_id_(message_id), data_(data) {
NoCryptoImpl(MessageId message_id, const Storer &data, bool need_pad = true) : message_id_(message_id), data_(data) {
if (need_pad) {
size_t pad_size = -static_cast<int>(data_.size()) & 15;
pad_size += 16 * (static_cast<size_t>(Random::secure_int32()) % 16);
@ -25,14 +27,14 @@ class NoCryptoImpl {
template <class StorerT>
void do_store(StorerT &storer) const {
storer.store_binary(message_id_);
storer.store_binary(message_id_.get());
storer.store_binary(static_cast<int32>(data_.size() + pad_.size()));
storer.store_storer(data_);
storer.store_slice(pad_);
}
private:
uint64 message_id_;
MessageId message_id_;
const Storer &data_;
std::string pad_;
};

View File

@ -6,6 +6,8 @@
//
#pragma once
#include "td/mtproto/MessageId.h"
#include "td/utils/common.h"
namespace td {
@ -18,7 +20,7 @@ struct PacketInfo {
uint64 salt{0};
uint64 session_id{0};
uint64 message_id{0};
MessageId message_id;
int32 seq_no{0};
int32 version{1};
bool no_crypto_flag{false};

View File

@ -8,6 +8,7 @@
#include "td/mtproto/AuthData.h"
#include "td/mtproto/AuthKey.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/mtproto_api.h"
#include "td/mtproto/NoCryptoStorer.h"
#include "td/mtproto/PacketInfo.h"
@ -47,7 +48,8 @@ class PingConnectionReqPQ final
if (!was_ping_) {
UInt128 nonce;
Random::secure_bytes(nonce.raw, sizeof(nonce));
raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(1, create_storer(mtproto_api::req_pq_multi(nonce))));
raw_connection_->send_no_crypto(PacketStorer<NoCryptoImpl>(MessageId(static_cast<uint64>(1)),
create_storer(mtproto_api::req_pq_multi(nonce))));
was_ping_ = true;
if (ping_count_ == 1) {
start_time_ = Time::now();
@ -129,13 +131,13 @@ class PingConnectionPingPong final
void on_server_time_difference_updated(bool force) final {
}
void on_new_session_created(uint64 unique_id, uint64 first_message_id) final {
void on_new_session_created(uint64 unique_id, MessageId first_message_id) final {
}
void on_session_failed(Status status) final {
}
void on_container_sent(uint64 container_message_id, vector<uint64> message_ids) final {
void on_container_sent(MessageId container_message_id, vector<MessageId> message_ids) final {
}
Status on_pong() final {
@ -153,21 +155,22 @@ class PingConnectionPingPong final
return Status::OK();
}
void on_message_ack(uint64 message_id) final {
void on_message_ack(MessageId message_id) final {
}
Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) final {
Status on_message_result_ok(MessageId message_id, BufferSlice packet, size_t original_size) final {
LOG(ERROR) << "Unexpected message";
return Status::OK();
}
void on_message_result_error(uint64 message_id, int code, string message) final {
void on_message_result_error(MessageId message_id, int code, string message) final {
}
void on_message_failed(uint64 message_id, Status status) final {
void on_message_failed(MessageId message_id, Status status) final {
}
void on_message_info(uint64 message_id, int32 state, uint64 answer_id, int32 answer_size, int32 source) final {
void on_message_info(MessageId message_id, int32 state, MessageId answer_message_id, int32 answer_size,
int32 source) final {
}
Status on_destroy_auth_key() final {

View File

@ -86,7 +86,7 @@ class RawConnectionDefault final : public RawConnection {
return packet_size;
}
uint64 send_no_crypto(const Storer &storer) final {
MessageId send_no_crypto(const Storer &storer) final {
PacketInfo packet_info;
packet_info.no_crypto_flag = true;
auto packet = Transport::write(storer, AuthKey(), &packet_info, transport_->max_prepend_size(),
@ -315,7 +315,7 @@ class RawConnectionHttp final : public RawConnection {
return packet_size;
}
uint64 send_no_crypto(const Storer &storer) final {
MessageId send_no_crypto(const Storer &storer) final {
PacketInfo packet_info;
packet_info.no_crypto_flag = true;
auto packet = Transport::write(storer, AuthKey(), &packet_info);

View File

@ -7,6 +7,7 @@
#pragma once
#include "td/mtproto/ConnectionManager.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/PacketInfo.h"
#include "td/mtproto/TransportType.h"
@ -50,7 +51,7 @@ class RawConnection {
virtual TransportType get_transport_type() const = 0;
virtual size_t send_crypto(const Storer &storer, uint64 session_id, int64 salt, const AuthKey &auth_key,
uint64 quick_ack_token) = 0;
virtual uint64 send_no_crypto(const Storer &storer) = 0;
virtual MessageId send_no_crypto(const Storer &storer) = 0;
virtual PollableFdInfo &get_poll_info() = 0;
virtual StatsCallback *stats_callback() = 0;
@ -63,7 +64,7 @@ class RawConnection {
virtual ~Callback() = default;
virtual Status on_raw_packet(const PacketInfo &packet_info, BufferSlice packet) = 0;
virtual Status on_quick_ack(uint64 quick_ack_token) {
return Status::Error("Quick acknowledgements are unsupported by the callback");
return Status::Error("Quick acknowledgements aren't supported by the callback");
}
virtual Status before_write() {
return Status::OK();

View File

@ -172,7 +172,7 @@ namespace mtproto {
*/
inline StringBuilder &operator<<(StringBuilder &string_builder, const SessionConnection::MsgInfo &info) {
return string_builder << "[msg_id:" << format::as_hex(info.message_id) << "][seq_no:" << info.seq_no << ']';
return string_builder << "with " << info.message_id << " and seq_no " << info.seq_no;
}
unique_ptr<RawConnection> SessionConnection::move_as_raw_connection() {
@ -190,7 +190,7 @@ Status SessionConnection::parse_message(TlParser &parser, MsgInfo *info, Slice *
if (parser.get_error() != nullptr) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: " << parser.get_error());
}
info->message_id = parser.fetch_long_unsafe();
info->message_id = MessageId(static_cast<uint64>(parser.fetch_long_unsafe()));
if (crypto_flag) {
info->seq_no = parser.fetch_int_unsafe();
}
@ -223,22 +223,22 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet)
if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error());
}
VLOG(mtproto) << "Receive container " << format::as_hex(container_message_id_) << " of size " << size;
VLOG(mtproto) << "Receive container " << container_message_id_ << " of size " << size;
for (int i = 0; i < size; i++) {
TRY_STATUS(parse_packet(parser));
}
return Status::OK();
}
void SessionConnection::reset_server_time_difference(uint64 message_id) {
void SessionConnection::reset_server_time_difference(MessageId message_id) {
VLOG(mtproto) << "Reset server time difference";
auth_data_->reset_server_time_difference(static_cast<uint32>(message_id >> 32) - Time::now());
auth_data_->reset_server_time_difference(static_cast<uint32>(message_id.get() >> 32) - Time::now());
callback_->on_server_time_difference_updated(true);
}
Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) {
TlParser parser(packet);
uint64 req_msg_id = parser.fetch_long();
uint64 req_msg_id = static_cast<uint64>(parser.fetch_long());
if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error());
}
@ -246,9 +246,9 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
LOG(ERROR) << "Receive an update in rpc_result " << info;
return Status::Error("Receive an update in rpc_result");
}
VLOG(mtproto) << "Receive result for request " << format::as_hex(req_msg_id) << " with " << info;
VLOG(mtproto) << "Receive result for request with " << MessageId(req_msg_id) << ' ' << info;
if (info.message_id < req_msg_id - (static_cast<uint64>(15) << 32)) {
if (info.message_id.get() < req_msg_id - (static_cast<uint64>(15) << 32)) {
reset_server_time_difference(info.message_id);
}
@ -258,7 +258,7 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_error: " << parser.get_error());
}
callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, rpc_error.error_message_.str());
callback_->on_message_result_error(MessageId(req_msg_id), rpc_error.error_code_, rpc_error.error_message_.str());
return Status::OK();
}
case mtproto_api::gzip_packed::ID: {
@ -269,11 +269,11 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
// yep, gzip in rpc_result
BufferSlice object = gzdecode(gzip.packed_data_);
// send header no more optimization
return callback_->on_message_result_ok(req_msg_id, std::move(object), info.size);
return callback_->on_message_result_ok(MessageId(req_msg_id), std::move(object), info.size);
}
default:
packet.remove_prefix(sizeof(req_msg_id));
return callback_->on_message_result_ok(req_msg_id, as_buffer_slice(packet), info.size);
return callback_->on_message_result_ok(MessageId(req_msg_id), as_buffer_slice(packet), info.size);
}
}
@ -284,17 +284,17 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) {
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) {
VLOG(mtproto) << "Receive destroy_auth_key_ok with " << info;
VLOG(mtproto) << "Receive destroy_auth_key_ok " << info;
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) {
VLOG(mtproto) << "Receive destroy_auth_key_none with " << info;
VLOG(mtproto) << "Receive destroy_auth_key_none " << info;
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) {
VLOG(mtproto) << "Receive destroy_auth_key_fail with " << info;
VLOG(mtproto) << "Receive destroy_auth_key_fail " << info;
return on_destroy_auth_key(destroy_auth_key);
}
@ -304,14 +304,14 @@ Status SessionConnection::on_destroy_auth_key(const mtproto_api::DestroyAuthKeyR
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_session_created &new_session_created) {
auto first_message_id = static_cast<uint64>(new_session_created.first_msg_id_);
VLOG(mtproto) << "Receive new_session_created with " << info << ": [first_msg_id:" << format::as_hex(first_message_id)
<< "] [unique_id:" << format::as_hex(new_session_created.unique_id_) << ']';
auto first_message_id = MessageId(static_cast<uint64>(new_session_created.first_msg_id_));
VLOG(mtproto) << "Receive new_session_created " << info << ": [first " << first_message_id
<< "] [unique_id:" << new_session_created.unique_id_ << ']';
auto it = service_queries_.find(first_message_id);
if (it != service_queries_.end()) {
first_message_id = it->second.container_message_id_;
LOG(INFO) << "Update first_message_id to container's " << format::as_hex(first_message_id);
LOG(INFO) << "Update first_message_id to container's " << first_message_id;
}
callback_->on_new_session_created(new_session_created.unique_id_, first_message_id);
@ -320,7 +320,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_
Status SessionConnection::on_packet(const MsgInfo &info,
const mtproto_api::bad_msg_notification &bad_msg_notification) {
MsgInfo bad_info{static_cast<uint64>(bad_msg_notification.bad_msg_id_), bad_msg_notification.bad_msg_seqno_, 0};
MsgInfo bad_info{MessageId(static_cast<uint64>(bad_msg_notification.bad_msg_id_)),
bad_msg_notification.bad_msg_seqno_, 0};
enum Code {
MsgIdTooLow = 16,
MsgIdTooHigh = 17,
@ -383,8 +384,8 @@ Status SessionConnection::on_packet(const MsgInfo &info,
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_server_salt &bad_server_salt) {
MsgInfo bad_info{static_cast<uint64>(bad_server_salt.bad_msg_id_), bad_server_salt.bad_msg_seqno_, 0};
VLOG(mtproto) << "Receive bad_server_salt with " << info << ": " << bad_info;
MsgInfo bad_info{MessageId(static_cast<uint64>(bad_server_salt.bad_msg_id_)), bad_server_salt.bad_msg_seqno_, 0};
VLOG(mtproto) << "Receive bad_server_salt " << info << ": " << bad_info;
auth_data_->set_server_salt(bad_server_salt.new_server_salt_, Time::now_cached());
callback_->on_server_salt_updated();
@ -393,8 +394,9 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_ack &msgs_ack) {
VLOG(mtproto) << "Receive msgs_ack with " << info << ": " << msgs_ack.msg_ids_;
for (auto message_id : msgs_ack.msg_ids_) {
auto message_ids = transform(msgs_ack.msg_ids_, [](int64 msg_id) { return MessageId(static_cast<uint64>(msg_id)); });
VLOG(mtproto) << "Receive msgs_ack " << info << ": " << message_ids;
for (auto message_id : message_ids) {
callback_->on_message_ack(message_id);
}
return Status::OK();
@ -407,8 +409,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::gzip
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::pong &pong) {
VLOG(mtproto) << "Receive pong with " << info;
if (info.message_id < static_cast<uint64>(pong.msg_id_) - (static_cast<uint64>(15) << 32)) {
VLOG(mtproto) << "Receive pong " << info;
if (info.message_id.get() < static_cast<uint64>(pong.msg_id_) - (static_cast<uint64>(15) << 32)) {
reset_server_time_difference(info.message_id);
}
last_pong_at_ = Time::now_cached();
@ -424,7 +426,7 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::futu
}
auto now = Time::now_cached();
auth_data_->set_future_salts(new_salts, now);
VLOG(mtproto) << "Receive future_salts with " << info << ": is_valid = " << auth_data_->is_server_salt_valid(now)
VLOG(mtproto) << "Receive future_salts " << info << ": is_valid = " << auth_data_->is_server_salt_valid(now)
<< ", has_salt = " << auth_data_->has_salt(now)
<< ", need_future_salts = " << auth_data_->need_future_salts(now);
callback_->on_server_salt_updated();
@ -438,14 +440,14 @@ Status SessionConnection::on_msgs_state_info(const vector<int64> &msg_ids, Slice
}
size_t i = 0;
for (auto msg_id : msg_ids) {
callback_->on_message_info(static_cast<uint64>(msg_id), info[i], 0, 0, 1);
callback_->on_message_info(MessageId(static_cast<uint64>(msg_id)), info[i], MessageId(), 0, 1);
i++;
}
return Status::OK();
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_state_info &msgs_state_info) {
auto message_id = static_cast<uint64>(msgs_state_info.req_msg_id_);
auto message_id = MessageId(static_cast<uint64>(msgs_state_info.req_msg_id_));
auto it = service_queries_.find(message_id);
if (it == service_queries_.end()) {
return Status::Error("Unknown msgs_state_info");
@ -456,26 +458,28 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs
if (query.type_ != ServiceQuery::GetStateInfo) {
return Status::Error("Receive msgs_state_info in response not to GetStateInfo");
}
VLOG(mtproto) << "Receive msgs_state_info with " << info;
VLOG(mtproto) << "Receive msgs_state_info " << info;
return on_msgs_state_info(query.msg_ids_, msgs_state_info.info_);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_all_info &msgs_all_info) {
VLOG(mtproto) << "Receive msgs_all_info with " << info;
VLOG(mtproto) << "Receive msgs_all_info " << info;
return on_msgs_state_info(msgs_all_info.msg_ids_, msgs_all_info.info_);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) {
VLOG(mtproto) << "Receive msg_detailed_info with " << info;
callback_->on_message_info(msg_detailed_info.msg_id_, msg_detailed_info.status_, msg_detailed_info.answer_msg_id_,
msg_detailed_info.bytes_, 2);
VLOG(mtproto) << "Receive msg_detailed_info " << info;
callback_->on_message_info(MessageId(static_cast<uint64>(msg_detailed_info.msg_id_)), msg_detailed_info.status_,
MessageId(static_cast<uint64>(msg_detailed_info.answer_msg_id_)), msg_detailed_info.bytes_,
2);
return Status::OK();
}
Status SessionConnection::on_packet(const MsgInfo &info,
const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) {
VLOG(mtproto) << "Receive msg_new_detailed_info with " << info;
callback_->on_message_info(0, 0, msg_new_detailed_info.answer_msg_id_, msg_new_detailed_info.bytes_, 0);
VLOG(mtproto) << "Receive msg_new_detailed_info " << info;
callback_->on_message_info(MessageId(), 0, MessageId(static_cast<uint64>(msg_new_detailed_info.answer_msg_id_)),
msg_new_detailed_info.bytes_, 0);
return Status::OK();
}
@ -517,9 +521,8 @@ Status SessionConnection::on_slice_packet(const MsgInfo &info, Slice packet) {
auto get_update_description = [&] {
return PSTRING() << "update from " << get_name() << " with auth key " << auth_data_->get_auth_key().id()
<< " active for " << (Time::now() - created_at_) << " seconds in container "
<< container_message_id_ << " from session " << auth_data_->get_session_id() << " with " << info
<< ", main_message_id = " << format::as_hex(main_message_id_)
<< " and original size = " << info.size;
<< container_message_id_ << " from session " << auth_data_->get_session_id() << ' ' << info
<< ", main " << main_message_id_ << " and original size = " << info.size;
};
// It is an update... I hope.
@ -560,8 +563,8 @@ Status SessionConnection::on_main_packet(const PacketInfo &packet_info, Slice pa
}
VLOG(raw_mtproto) << "Receive packet of size " << packet.size() << ':' << format::as_hex_dump<4>(packet);
VLOG(mtproto) << "Receive packet with seq_no " << packet_info.seq_no << " and msg_id "
<< format::as_hex(packet_info.message_id) << " of size " << packet.size();
VLOG(mtproto) << "Receive packet with " << packet_info.message_id << " and seq_no " << packet_info.seq_no
<< " of size " << packet.size();
if (packet_info.no_crypto_flag) {
return Status::Error("Unencrypted packet");
@ -576,7 +579,7 @@ Status SessionConnection::on_main_packet(const PacketInfo &packet_info, Slice pa
return Status::OK();
}
void SessionConnection::on_message_failed(uint64 message_id, Status status) {
void SessionConnection::on_message_failed(MessageId message_id, Status status) {
callback_->on_message_failed(message_id, std::move(status));
sent_destroy_auth_key_ = false;
@ -584,8 +587,8 @@ void SessionConnection::on_message_failed(uint64 message_id, Status status) {
if (message_id == last_ping_message_id_ || message_id == last_ping_container_message_id_) {
// restart ping immediately
last_ping_at_ = 0;
last_ping_message_id_ = 0;
last_ping_container_message_id_ = 0;
last_ping_message_id_ = {};
last_ping_container_message_id_ = {};
}
auto cit = container_to_service_message_id_.find(message_id);
@ -599,7 +602,7 @@ void SessionConnection::on_message_failed(uint64 message_id, Status status) {
}
}
void SessionConnection::on_message_failed_inner(uint64 message_id) {
void SessionConnection::on_message_failed_inner(MessageId message_id) {
auto it = service_queries_.find(message_id);
if (it == service_queries_.end()) {
return;
@ -610,12 +613,12 @@ void SessionConnection::on_message_failed_inner(uint64 message_id) {
switch (query.type_) {
case ServiceQuery::ResendAnswer:
for (auto msg_id : query.msg_ids_) {
resend_answer(static_cast<uint64>(msg_id));
resend_answer(MessageId(static_cast<uint64>(msg_id)));
}
break;
case ServiceQuery::GetStateInfo:
for (auto msg_id : query.msg_ids_) {
get_state_info(static_cast<uint64>(msg_id));
get_state_info(MessageId(static_cast<uint64>(msg_id)));
}
break;
default:
@ -726,7 +729,7 @@ Status SessionConnection::on_raw_packet(const PacketInfo &packet_info, BufferSli
}
Status SessionConnection::on_quick_ack(uint64 quick_ack_token) {
callback_->on_message_ack(quick_ack_token);
callback_->on_message_ack(MessageId(quick_ack_token));
return Status::OK();
}
@ -773,8 +776,8 @@ void SessionConnection::set_online(bool online_flag, bool is_main) {
last_read_at_ = now;
}
last_ping_at_ = 0;
last_ping_message_id_ = 0;
last_ping_container_message_id_ = 0;
last_ping_message_id_ = {};
last_ping_container_message_id_ = {};
}
void SessionConnection::do_close(Status status) {
@ -790,10 +793,10 @@ void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token
auth_data_->get_auth_key(), quick_ack_token);
}
Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id,
vector<uint64> invoke_after_message_ids, bool use_quick_ack) {
Result<MessageId> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, MessageId message_id,
vector<MessageId> invoke_after_message_ids, bool use_quick_ack) {
CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait"
if (message_id == 0) {
if (message_id == MessageId()) {
message_id = auth_data_->next_message_id(Time::now_cached());
}
auto seq_no = auth_data_->next_seq_no(true);
@ -802,28 +805,28 @@ Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag,
}
to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, std::move(invoke_after_message_ids),
use_quick_ack});
VLOG(mtproto) << "Invoke query with msg_id " << format::as_hex(message_id) << " and seq_no " << seq_no << " of size "
VLOG(mtproto) << "Invoke query with " << message_id << " and seq_no " << seq_no << " of size "
<< to_send_.back().packet.size() << " after " << invoke_after_message_ids
<< (use_quick_ack ? " with quick ack" : "");
return message_id;
}
void SessionConnection::get_state_info(uint64 message_id) {
void SessionConnection::get_state_info(MessageId message_id) {
if (to_get_state_info_message_ids_.empty()) {
send_before(Time::now_cached());
}
to_get_state_info_message_ids_.push_back(message_id);
}
void SessionConnection::resend_answer(uint64 message_id) {
void SessionConnection::resend_answer(MessageId message_id) {
if (to_resend_answer_message_ids_.empty()) {
send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
}
to_resend_answer_message_ids_.push_back(message_id);
}
void SessionConnection::cancel_answer(uint64 message_id) {
void SessionConnection::cancel_answer(MessageId message_id) {
if (to_cancel_answer_message_ids_.empty()) {
send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
}
@ -835,7 +838,7 @@ void SessionConnection::destroy_key() {
need_destroy_auth_key_ = true;
}
std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at) {
std::pair<MessageId, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at) {
int64 temp_key = auth_data_->get_tmp_auth_key().id();
mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->get_session_id(), expires_at);
@ -865,8 +868,8 @@ void SessionConnection::force_ack() {
}
}
void SessionConnection::send_ack(uint64 message_id) {
VLOG(mtproto) << "Send ack: [msg_id:" << format::as_hex(message_id) << "]";
void SessionConnection::send_ack(MessageId message_id) {
VLOG(mtproto) << "Send ack for " << message_id;
if (to_ack_message_ids_.empty()) {
send_before(Time::now_cached() + ACK_DELAY);
}
@ -881,7 +884,7 @@ void SessionConnection::send_ack(uint64 message_id) {
}
}
// don't send ping in poll mode.
// don't send ping in poll mode
bool SessionConnection::may_ping() const {
return last_ping_at_ == 0 || (mode_ != Mode::HttpLongPoll && last_ping_at_ + ping_may_delay() < Time::now_cached());
}
@ -893,7 +896,7 @@ bool SessionConnection::must_ping() const {
void SessionConnection::flush_packet() {
bool has_salt = auth_data_->has_salt(Time::now_cached());
// ping
uint64 container_message_id = 0;
MessageId container_message_id;
int64 ping_id = 0;
if (has_salt && may_ping()) {
ping_id = ++cur_ping_id_;
@ -963,53 +966,54 @@ void SessionConnection::flush_packet() {
<< tag("cancel", to_cancel_answer_message_ids_.size()) << tag("destroy_key", destroy_auth_key)
<< tag("auth_key_id", auth_data_->get_auth_key().id());
auto cut_tail = [](vector<uint64> &v, size_t size, Slice name) {
if (size >= v.size()) {
auto result = transform(v, [](uint64 x) { return static_cast<int64>(x); });
v.clear();
auto cut_tail = [](vector<MessageId> &message_ids, size_t size, Slice name) {
if (size >= message_ids.size()) {
auto result = transform(message_ids, [](MessageId message_id) { return static_cast<int64>(message_id.get()); });
message_ids.clear();
return result;
}
LOG(WARNING) << "Too many message identifiers in container " << name << ": " << v.size() << " instead of " << size;
auto new_size = v.size() - size;
LOG(WARNING) << "Too many message identifiers in container " << name << ": " << message_ids.size() << " instead of "
<< size;
auto new_size = message_ids.size() - size;
vector<int64> result(size);
for (size_t i = 0; i < size; i++) {
result[i] = static_cast<int64>(v[i + new_size]);
result[i] = static_cast<int64>(message_ids[i + new_size].get());
}
v.resize(new_size);
message_ids.resize(new_size);
return result;
};
// no more than 8192 message identifiers per container..
auto to_resend_answer = cut_tail(to_resend_answer_message_ids_, 8192, "resend_answer");
uint64 resend_answer_message_id = 0;
MessageId resend_answer_message_id;
CHECK(queries.size() <= 1020);
auto to_cancel_answer = cut_tail(to_cancel_answer_message_ids_, 1020 - queries.size(), "cancel_answer");
auto to_get_state_info = cut_tail(to_get_state_info_message_ids_, 8192, "get_state_info");
uint64 get_state_info_message_id = 0;
MessageId get_state_info_message_id;
auto to_ack = cut_tail(to_ack_message_ids_, 8192, "ack");
uint64 ping_message_id = 0;
MessageId ping_message_id;
bool use_quick_ack =
std::any_of(queries.begin(), queries.end(), [](const auto &query) { return query.use_quick_ack; });
{
// LOG(ERROR) << (auth_data_->get_header().empty() ? '-' : '+');
uint64 parent_message_id = 0;
MessageId parent_message_id;
auto storer = PacketStorer<CryptoImpl>(
queries, auth_data_->get_header(), std::move(to_ack), ping_id, static_cast<int>(ping_disconnect_delay() + 2.0),
max_delay, max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer,
destroy_auth_key, auth_data_, &container_message_id, &get_state_info_message_id, &resend_answer_message_id,
&ping_message_id, &parent_message_id);
auto quick_ack_token = use_quick_ack ? parent_message_id : 0;
auto quick_ack_token = use_quick_ack ? parent_message_id.get() : 0;
send_crypto(storer, quick_ack_token);
}
if (resend_answer_message_id) {
if (resend_answer_message_id != MessageId()) {
service_queries_.emplace(resend_answer_message_id, ServiceQuery{ServiceQuery::ResendAnswer, container_message_id,
std::move(to_resend_answer)});
}
if (get_state_info_message_id) {
if (get_state_info_message_id != MessageId()) {
service_queries_.emplace(get_state_info_message_id, ServiceQuery{ServiceQuery::GetStateInfo, container_message_id,
std::move(to_get_state_info)});
}
@ -1018,8 +1022,8 @@ void SessionConnection::flush_packet() {
last_ping_message_id_ = ping_message_id;
}
if (container_message_id != 0) {
auto message_ids = transform(queries, [](const MtprotoQuery &x) { return static_cast<uint64>(x.message_id); });
if (container_message_id != MessageId()) {
auto message_ids = transform(queries, [](const MtprotoQuery &x) { return x.message_id; });
// some acks may be lost here. Nobody will resend them if something goes wrong with query.
// It is mostly problem for server. We will just drop this answers in next connection
@ -1028,10 +1032,10 @@ void SessionConnection::flush_packet() {
// So I will re-ask salt if have no answer in 60 second.
callback_->on_container_sent(container_message_id, std::move(message_ids));
if (resend_answer_message_id) {
if (resend_answer_message_id != MessageId()) {
container_to_service_message_id_[container_message_id].push_back(resend_answer_message_id);
}
if (get_state_info_message_id) {
if (get_state_info_message_id != MessageId()) {
container_to_service_message_id_[container_message_id].push_back(get_state_info_message_id);
}
}

View File

@ -6,6 +6,7 @@
//
#pragma once
#include "td/mtproto/MessageId.h"
#include "td/mtproto/MtprotoQuery.h"
#include "td/mtproto/PacketInfo.h"
#include "td/mtproto/RawConnection.h"
@ -67,14 +68,14 @@ class SessionConnection final
unique_ptr<RawConnection> move_as_raw_connection();
// Interface
Result<uint64> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id = 0,
vector<uint64> invoke_after_message_ids = {},
bool use_quick_ack = false);
std::pair<uint64, BufferSlice> encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at);
Result<MessageId> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, MessageId message_id = {},
vector<MessageId> invoke_after_message_ids = {},
bool use_quick_ack = false);
std::pair<MessageId, BufferSlice> encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at);
void get_state_info(uint64 message_id);
void resend_answer(uint64 message_id);
void cancel_answer(uint64 message_id);
void get_state_info(MessageId message_id);
void resend_answer(MessageId message_id);
void cancel_answer(MessageId message_id);
void destroy_key();
void set_online(bool online_flag, bool is_main);
@ -95,19 +96,20 @@ class SessionConnection final
virtual void on_server_salt_updated() = 0;
virtual void on_server_time_difference_updated(bool force) = 0;
virtual void on_new_session_created(uint64 unique_id, uint64 first_message_id) = 0;
virtual void on_new_session_created(uint64 unique_id, MessageId first_message_id) = 0;
virtual void on_session_failed(Status status) = 0;
virtual void on_container_sent(uint64 container_message_id, vector<uint64> message_ids) = 0;
virtual void on_container_sent(MessageId container_message_id, vector<MessageId> message_ids) = 0;
virtual Status on_pong() = 0;
virtual Status on_update(BufferSlice packet) = 0;
virtual void on_message_ack(uint64 message_id) = 0;
virtual Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) = 0;
virtual void on_message_result_error(uint64 message_id, int code, string message) = 0;
virtual void on_message_failed(uint64 message_id, Status status) = 0;
virtual void on_message_info(uint64 message_id, int32 state, uint64 answer_id, int32 answer_size, int32 source) = 0;
virtual void on_message_ack(MessageId message_id) = 0;
virtual Status on_message_result_ok(MessageId message_id, BufferSlice packet, size_t original_size) = 0;
virtual void on_message_result_error(MessageId message_id, int code, string message) = 0;
virtual void on_message_failed(MessageId message_id, Status status) = 0;
virtual void on_message_info(MessageId message_id, int32 state, MessageId answer_message_id, int32 answer_size,
int32 source) = 0;
virtual Status on_destroy_auth_key() = 0;
};
@ -123,7 +125,7 @@ class SessionConnection final
static constexpr double RESEND_ANSWER_DELAY = 0.001; // 0.001s
struct MsgInfo {
uint64 message_id;
MessageId message_id;
int32 seq_no;
size_t size;
};
@ -161,21 +163,21 @@ class SessionConnection final
static constexpr int HTTP_MAX_DELAY = 30; // 0.03s
vector<MtprotoQuery> to_send_;
vector<uint64> to_ack_message_ids_;
vector<MessageId> to_ack_message_ids_;
double force_send_at_ = 0;
struct ServiceQuery {
enum Type { GetStateInfo, ResendAnswer } type_;
uint64 container_message_id_;
MessageId container_message_id_;
vector<int64> msg_ids_;
};
vector<uint64> to_resend_answer_message_ids_;
vector<uint64> to_cancel_answer_message_ids_;
vector<uint64> to_get_state_info_message_ids_;
FlatHashMap<uint64, ServiceQuery> service_queries_;
vector<MessageId> to_resend_answer_message_ids_;
vector<MessageId> to_cancel_answer_message_ids_;
vector<MessageId> to_get_state_info_message_ids_;
FlatHashMap<MessageId, ServiceQuery, MessageIdHash> service_queries_;
// nobody cleans up this map. But it should be really small.
FlatHashMap<uint64, vector<uint64>> container_to_service_message_id_;
FlatHashMap<MessageId, vector<MessageId>, MessageIdHash> container_to_service_message_id_;
double random_delay_ = 0;
double last_read_at_ = 0;
@ -183,9 +185,9 @@ class SessionConnection final
double last_pong_at_ = 0;
double real_last_read_at_ = 0;
double real_last_pong_at_ = 0;
uint64 cur_ping_id_ = 0;
uint64 last_ping_message_id_ = 0;
uint64 last_ping_container_message_id_ = 0;
int64 cur_ping_id_ = 0;
MessageId last_ping_message_id_;
MessageId last_ping_container_message_id_;
uint64 last_read_size_ = 0;
uint64 last_write_size_ = 0;
@ -200,8 +202,8 @@ class SessionConnection final
Mode mode_;
bool connected_flag_ = false;
uint64 container_message_id_ = 0;
uint64 main_message_id_ = 0;
MessageId container_message_id_;
MessageId main_message_id_;
double created_at_ = 0;
unique_ptr<RawConnection> raw_connection_;
@ -218,7 +220,7 @@ class SessionConnection final
};
}
void reset_server_time_difference(uint64 message_id);
void reset_server_time_difference(MessageId message_id);
static Status parse_message(TlParser &parser, MsgInfo *info, Slice *packet,
bool crypto_flag = true) TD_WARN_UNUSED_RESULT;
@ -254,12 +256,12 @@ class SessionConnection final
Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT;
Status on_main_packet(const PacketInfo &packet_info, Slice packet) TD_WARN_UNUSED_RESULT;
void on_message_failed(uint64 message_id, Status status);
void on_message_failed_inner(uint64 message_id);
void on_message_failed(MessageId message_id, Status status);
void on_message_failed_inner(MessageId message_id);
void do_close(Status status);
void send_ack(uint64 message_id);
void send_ack(MessageId message_id);
void send_crypto(const Storer &storer, uint64 quick_ack_token);
void send_before(double tm);
bool may_ping() const;

View File

@ -8,6 +8,7 @@
#include "td/mtproto/AuthKey.h"
#include "td/mtproto/KDF.h"
#include "td/mtproto/MessageId.h"
#include "td/utils/as.h"
#include "td/utils/crypto.h"
@ -42,7 +43,7 @@ struct CryptoHeader {
// It is weird to generate message_id and seq_no while writing a packet.
//
// uint64 message_id;
// uint64 msg_id;
// uint32 seq_no;
// uint32 message_data_length;
uint8 data[0]; // use compiler extension
@ -68,7 +69,7 @@ struct CryptoHeader {
};
struct CryptoPrefix {
uint64 message_id;
uint64 msg_id;
uint32 seq_no;
uint32 message_data_length;
};
@ -108,9 +109,9 @@ struct EndToEndPrefix {
struct NoCryptoHeader {
uint64 auth_key_id;
// message_id is removed from CryptoHeader. Should be removed from here too.
// msg_id is removed from CryptoHeader. Should be removed from here too.
//
// uint64 message_id;
// uint64 msg_id;
// uint32 message_data_length;
uint8 data[0]; // use compiler extension
@ -309,7 +310,7 @@ Status Transport::read_crypto(MutableSlice message, const AuthKey &auth_key, Pac
packet_info->type = PacketInfo::Common;
packet_info->salt = header->salt;
packet_info->session_id = header->session_id;
packet_info->message_id = prefix->message_id;
packet_info->message_id = MessageId(prefix->msg_id);
packet_info->seq_no = prefix->seq_no;
return Status::OK();
}

View File

@ -534,7 +534,7 @@ void Session::hangup() {
}
void Session::raw_event(const Event::Raw &event) {
auto message_id = event.u64;
auto message_id = mtproto::MessageId(event.u64);
auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) {
return;
@ -552,7 +552,7 @@ void Session::raw_event(const Event::Raw &event) {
if (main_connection_.state_ == ConnectionInfo::State::Ready) {
main_connection_.connection_->cancel_answer(message_id);
} else {
to_cancel_.push_back(message_id);
to_cancel_message_ids_.push_back(message_id);
}
loop();
}
@ -697,8 +697,8 @@ void Session::on_closed(Status status) {
current_info_->state_ = ConnectionInfo::State::Empty;
}
void Session::on_new_session_created(uint64 unique_id, uint64 first_message_id) {
LOG(INFO) << "New session " << unique_id << " created with first message_id " << format::as_hex(first_message_id);
void Session::on_new_session_created(uint64 unique_id, mtproto::MessageId first_message_id) {
LOG(INFO) << "New session " << unique_id << " created with first " << first_message_id;
if (!use_pfs_ && !auth_data_.use_pfs()) {
last_success_timestamp_ = Time::now();
}
@ -712,9 +712,9 @@ void Session::on_new_session_created(uint64 unique_id, uint64 first_message_id)
auto first_query_it = sent_queries_.find(first_message_id);
if (first_query_it != sent_queries_.end()) {
first_message_id = first_query_it->second.container_message_id_;
LOG(INFO) << "Update first_message_id to container's " << format::as_hex(first_message_id);
LOG(INFO) << "Update first message to container's " << first_message_id;
} else {
LOG(INFO) << "Failed to find query " << format::as_hex(first_message_id) << " from the new session";
LOG(INFO) << "Failed to find sent " << first_message_id << " from the new session";
}
for (auto it = sent_queries_.begin(); it != sent_queries_.end();) {
Query *query_ptr = &it->second;
@ -741,10 +741,10 @@ void Session::on_session_failed(Status status) {
callback_->on_failed();
}
void Session::on_container_sent(uint64 container_message_id, vector<uint64> message_ids) {
CHECK(container_message_id != 0);
void Session::on_container_sent(mtproto::MessageId container_message_id, vector<mtproto::MessageId> message_ids) {
CHECK(container_message_id != mtproto::MessageId());
td::remove_if(message_ids, [&](uint64 message_id) {
td::remove_if(message_ids, [&](mtproto::MessageId message_id) {
auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) {
return true; // remove
@ -759,11 +759,11 @@ void Session::on_container_sent(uint64 container_message_id, vector<uint64> mess
sent_containers_.emplace(container_message_id, ContainerInfo{size, std::move(message_ids)});
}
void Session::on_message_ack(uint64 message_id) {
void Session::on_message_ack(mtproto::MessageId message_id) {
on_message_ack_impl(message_id, 1);
}
void Session::on_message_ack_impl(uint64 container_message_id, int32 type) {
void Session::on_message_ack_impl(mtproto::MessageId container_message_id, int32 type) {
auto cit = sent_containers_.find(container_message_id);
if (cit != sent_containers_.end()) {
auto container_info = std::move(cit->second);
@ -778,7 +778,7 @@ void Session::on_message_ack_impl(uint64 container_message_id, int32 type) {
on_message_ack_impl_inner(container_message_id, type, false);
}
void Session::on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_container) {
void Session::on_message_ack_impl_inner(mtproto::MessageId message_id, int32 type, bool in_container) {
auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) {
return;
@ -796,7 +796,7 @@ void Session::on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_c
mark_as_known(it->first, &it->second);
}
void Session::dec_container(uint64 container_message_id, Query *query) {
void Session::dec_container(mtproto::MessageId container_message_id, Query *query) {
if (query->container_message_id_ == container_message_id) {
// message was sent without any container
return;
@ -812,7 +812,7 @@ void Session::dec_container(uint64 container_message_id, Query *query) {
}
}
void Session::cleanup_container(uint64 container_message_id, Query *query) {
void Session::cleanup_container(mtproto::MessageId container_message_id, Query *query) {
if (query->container_message_id_ == container_message_id) {
// message was sent without any container
return;
@ -823,7 +823,7 @@ void Session::cleanup_container(uint64 container_message_id, Query *query) {
sent_containers_.erase(query->container_message_id_);
}
void Session::mark_as_known(uint64 message_id, Query *query) {
void Session::mark_as_known(mtproto::MessageId message_id, Query *query) {
{
auto lock = query->net_query_->lock();
query->net_query_->get_data_unsafe().unknown_state_ = false;
@ -839,7 +839,7 @@ void Session::mark_as_known(uint64 message_id, Query *query) {
}
}
void Session::mark_as_unknown(uint64 message_id, Query *query) {
void Session::mark_as_unknown(mtproto::MessageId message_id, Query *query) {
{
auto lock = query->net_query_->lock();
query->net_query_->get_data_unsafe().unknown_state_ = true;
@ -849,7 +849,7 @@ void Session::mark_as_unknown(uint64 message_id, Query *query) {
}
VLOG(net_query) << "Mark as unknown " << query->net_query_;
query->is_unknown_ = true;
CHECK(message_id != 0);
CHECK(message_id != mtproto::MessageId());
unknown_queries_.insert(message_id);
}
@ -866,16 +866,16 @@ Status Session::on_update(BufferSlice packet) {
return Status::OK();
}
Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) {
Status Session::on_message_result_ok(mtproto::MessageId message_id, BufferSlice packet, size_t original_size) {
last_success_timestamp_ = Time::now();
TlParser parser(packet.as_slice());
int32 response_id = parser.fetch_int();
int32 response_tl_id = parser.fetch_int();
auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) {
LOG(DEBUG) << "Drop result to " << tag("message_id", format::as_hex(message_id))
<< tag("original_size", original_size) << tag("response_id", format::as_hex(response_id));
LOG(DEBUG) << "Drop result to " << message_id << tag("original_size", original_size)
<< tag("response_tl", format::as_hex(response_tl_id));
if (original_size > 16 * 1024) {
dropped_size_ += original_size;
@ -896,9 +896,9 @@ Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size
if (!parser.get_error()) {
// Steal authorization information.
// It is a dirty hack, yep.
if (response_id == telegram_api::auth_authorization::ID ||
response_id == telegram_api::auth_loginTokenSuccess::ID ||
response_id == telegram_api::auth_sentCodeSuccess::ID) {
if (response_tl_id == telegram_api::auth_authorization::ID ||
response_tl_id == telegram_api::auth_loginTokenSuccess::ID ||
response_tl_id == telegram_api::auth_sentCodeSuccess::ID) {
if (query_ptr->net_query_->tl_constructor() != telegram_api::auth_importAuthorization::ID) {
G()->net_query_dispatcher().set_main_dc_id(raw_dc_id_);
}
@ -918,7 +918,7 @@ Status Session::on_message_result_ok(uint64 message_id, BufferSlice packet, size
return Status::OK();
}
void Session::on_message_result_error(uint64 message_id, int error_code, string message) {
void Session::on_message_result_error(mtproto::MessageId message_id, int error_code, string message) {
if (!check_utf8(message)) {
LOG(ERROR) << "Receive invalid error message \"" << message << '"';
message = "INVALID_UTF8_ERROR_MESSAGE";
@ -970,7 +970,7 @@ void Session::on_message_result_error(uint64 message_id, int error_code, string
error_code = 500;
}
if (message_id == 0) {
if (message_id == mtproto::MessageId()) {
LOG(ERROR) << "Receive an error without message_id";
return;
}
@ -998,8 +998,8 @@ void Session::on_message_result_error(uint64 message_id, int error_code, string
sent_queries_.erase(it);
}
void Session::on_message_failed_inner(uint64 message_id, bool in_container) {
LOG(INFO) << "Message inner failed " << message_id;
void Session::on_message_failed_inner(mtproto::MessageId message_id, bool in_container) {
LOG(INFO) << "Message inner failed for " << message_id;
auto it = sent_queries_.find(message_id);
if (it == sent_queries_.end()) {
return;
@ -1016,8 +1016,8 @@ void Session::on_message_failed_inner(uint64 message_id, bool in_container) {
sent_queries_.erase(it);
}
void Session::on_message_failed(uint64 message_id, Status status) {
LOG(INFO) << "Message failed: " << tag("message_id", message_id) << tag("status", status);
void Session::on_message_failed(mtproto::MessageId message_id, Status status) {
LOG(INFO) << "Failed to send " << message_id << ": " << status;
status.ignore();
auto cit = sent_containers_.find(message_id);
@ -1034,8 +1034,8 @@ void Session::on_message_failed(uint64 message_id, Status status) {
on_message_failed_inner(message_id, false);
}
void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_message_id, int32 answer_size,
int32 source) {
void Session::on_message_info(mtproto::MessageId message_id, int32 state, mtproto::MessageId answer_message_id,
int32 answer_size, int32 source) {
auto it = sent_queries_.find(message_id);
if (it != sent_queries_.end()) {
if (it->second.net_query_->update_is_ready()) {
@ -1049,7 +1049,7 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess
return;
}
}
if (message_id != 0) {
if (message_id != mtproto::MessageId()) {
if (it == sent_queries_.end()) {
return;
}
@ -1060,15 +1060,16 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess
return on_message_failed(message_id,
Status::Error("Message wasn't received by the server and must be re-sent"));
case 0:
if (answer_message_id == 0) {
LOG(ERROR) << "Unexpected message_info.state == 0 " << tag("message_id", message_id) << tag("state", state)
<< tag("answer_message_id", answer_message_id);
if (answer_message_id == mtproto::MessageId()) {
LOG(ERROR) << "Unexpected message_info.state == 0 for " << message_id << ": " << tag("state", state)
<< tag("answer", answer_message_id);
return on_message_failed(message_id, Status::Error("Unexpected message_info.state == 0"));
}
// fallthrough
case 4:
CHECK(0 <= source && source <= 3);
on_message_ack_impl(message_id, (answer_message_id ? 2 : 0) | (((state | source) & ((1 << 28) - 1)) << 2));
on_message_ack_impl(message_id, (answer_message_id != mtproto::MessageId() ? 2 : 0) |
(((state | source) & ((1 << 28) - 1)) << 2));
break;
default:
LOG(ERROR) << "Invalid message info " << tag("state", state);
@ -1076,10 +1077,10 @@ void Session::on_message_info(uint64 message_id, int32 state, uint64 answer_mess
}
// ok, we are waiting for result of message_id. let's ask to resend it
if (answer_message_id != 0) {
if (answer_message_id != mtproto::MessageId()) {
if (it != sent_queries_.end()) {
VLOG_IF(net_query, message_id != 0) << "Resend answer " << tag("answer_message_id", answer_message_id)
<< tag("answer_size", answer_size) << it->second.net_query_;
VLOG_IF(net_query, message_id != mtproto::MessageId())
<< "Resend answer " << answer_message_id << ": " << tag("answer_size", answer_size) << it->second.net_query_;
it->second.net_query_->debug(PSTRING() << get_name() << ": resend answer");
}
current_info_->connection_->resend_answer(answer_message_id);
@ -1114,7 +1115,7 @@ void Session::add_query(NetQueryPtr &&net_query) {
pending_queries_.push(std::move(net_query));
}
void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id) {
void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, mtproto::MessageId message_id) {
CHECK(info->state_ == ConnectionInfo::State::Ready);
current_info_ = info;
@ -1123,10 +1124,10 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
}
Span<NetQueryRef> invoke_after = net_query->invoke_after();
vector<uint64> invoke_after_message_ids;
vector<mtproto::MessageId> invoke_after_message_ids;
for (auto &ref : invoke_after) {
auto invoke_after_message_id = ref->message_id();
if (ref->session_id() != auth_data_.get_session_id() || invoke_after_message_id == 0) {
auto invoke_after_message_id = mtproto::MessageId(ref->message_id());
if (ref->session_id() != auth_data_.get_session_id() || invoke_after_message_id == mtproto::MessageId()) {
net_query->set_error_resend_invoke_after();
return return_query(std::move(net_query));
}
@ -1155,30 +1156,27 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
}
message_id = r_message_id.ok();
} else {
if (message_id == 0) {
if (message_id == mtproto::MessageId()) {
message_id = auth_data_.next_message_id(now);
}
}
net_query->set_message_id(message_id);
VLOG(net_query) << "Send query to connection " << net_query
<< tag("invoke_after", transform(invoke_after_message_ids, [](auto message_id) {
return PSTRING() << format::as_hex(message_id);
}));
net_query->set_message_id(message_id.get());
VLOG(net_query) << "Send query to connection " << net_query << tag("invoke_after", invoke_after_message_ids);
{
auto lock = net_query->lock();
net_query->get_data_unsafe().unknown_state_ = false;
net_query->get_data_unsafe().ack_state_ = 0;
}
if (!net_query->cancel_slot_.empty()) {
LOG(DEBUG) << "Set event for net_query cancellation " << tag("message_id", format::as_hex(message_id));
net_query->cancel_slot_.set_event(EventCreator::raw(actor_id(), message_id));
LOG(DEBUG) << "Set event for net_query cancellation for " << message_id;
net_query->cancel_slot_.set_event(EventCreator::raw(actor_id(), message_id.get()));
}
auto status =
sent_queries_.emplace(message_id, Query{message_id, std::move(net_query), main_connection_.connection_id_, now});
LOG_CHECK(status.second) << message_id;
sent_queries_list_.put(status.first->second.get_list_node());
if (!status.second) {
LOG(FATAL) << "Duplicate message_id [message_id = " << message_id << "]";
LOG(FATAL) << "Duplicate " << message_id;
}
if (immediately_fail_query) {
on_message_result_error(message_id, 401, "TEST_ERROR");
@ -1308,10 +1306,10 @@ void Session::connection_open_finish(ConnectionInfo *info,
for (auto &message_id : unknown_queries_) {
info->connection_->get_state_info(message_id);
}
for (auto &message_id : to_cancel_) {
for (auto &message_id : to_cancel_message_ids_) {
info->connection_->cancel_answer(message_id);
}
to_cancel_.clear();
to_cancel_message_ids_.clear();
}
yield();
}
@ -1378,7 +1376,7 @@ bool Session::connection_send_bind_key(ConnectionInfo *info) {
int64 perm_auth_key_id = auth_data_.get_main_auth_key().id();
int64 nonce = Random::secure_int64();
auto expires_at = static_cast<int32>(auth_data_.get_server_time(auth_data_.get_tmp_auth_key().expires_at()));
uint64 message_id;
mtproto::MessageId message_id;
BufferSlice encrypted;
std::tie(message_id, encrypted) = info->connection_->encrypted_bind(perm_auth_key_id, nonce, expires_at);

View File

@ -14,6 +14,7 @@
#include "td/mtproto/AuthKey.h"
#include "td/mtproto/ConnectionManager.h"
#include "td/mtproto/Handshake.h"
#include "td/mtproto/MessageId.h"
#include "td/mtproto/SessionConnection.h"
#include "td/actor/actor.h"
@ -78,7 +79,7 @@ class Session final
private:
struct Query final : private ListNode {
uint64 container_message_id_;
mtproto::MessageId container_message_id_;
NetQueryPtr net_query_;
bool is_acknowledged_ = false;
@ -87,7 +88,7 @@ class Session final
const int8 connection_id_;
const double sent_at_;
Query(uint64 message_id, NetQueryPtr &&net_query, int8 connection_id, double sent_at)
Query(mtproto::MessageId message_id, NetQueryPtr &&net_query, int8 connection_id, double sent_at)
: container_message_id_(message_id)
, net_query_(std::move(net_query))
, connection_id_(connection_id)
@ -131,8 +132,8 @@ class Session final
double last_bind_success_timestamp_ = 0; // time when auth_key and Session definitely was valid and authorized
size_t dropped_size_ = 0;
FlatHashSet<uint64> unknown_queries_;
vector<int64> to_cancel_;
FlatHashSet<mtproto::MessageId, mtproto::MessageIdHash> unknown_queries_;
vector<mtproto::MessageId> to_cancel_message_ids_;
// Do not invalidate iterators of these two containers!
// TODO: better data structures
@ -145,7 +146,7 @@ class Session final
std::map<int8, VectorQueue<NetQueryPtr>, std::greater<>> queries_;
};
PriorityQueue pending_queries_;
std::map<uint64, Query> sent_queries_;
std::map<mtproto::MessageId, Query> sent_queries_;
std::deque<NetQueryPtr> pending_invoke_after_queries_;
ListNode sent_queries_list_;
@ -181,9 +182,9 @@ class Session final
struct ContainerInfo {
size_t ref_cnt;
vector<uint64> message_ids;
vector<mtproto::MessageId> message_ids;
};
FlatHashMap<uint64, ContainerInfo> sent_containers_;
FlatHashMap<mtproto::MessageId, ContainerInfo, mtproto::MessageIdHash> sent_containers_;
friend class GenAuthKeyActor;
struct HandshakeInfo {
@ -214,33 +215,34 @@ class Session final
void on_server_salt_updated() final;
void on_server_time_difference_updated(bool force) final;
void on_new_session_created(uint64 unique_id, uint64 first_message_id) final;
void on_new_session_created(uint64 unique_id, mtproto::MessageId first_message_id) final;
void on_session_failed(Status status) final;
void on_container_sent(uint64 container_message_id, vector<uint64> message_ids) final;
void on_container_sent(mtproto::MessageId container_message_id, vector<mtproto::MessageId> message_ids) final;
Status on_update(BufferSlice packet) final;
void on_message_ack(uint64 message_id) final;
Status on_message_result_ok(uint64 message_id, BufferSlice packet, size_t original_size) final;
void on_message_result_error(uint64 message_id, int error_code, string message) final;
void on_message_failed(uint64 message_id, Status status) final;
void on_message_ack(mtproto::MessageId message_id) final;
Status on_message_result_ok(mtproto::MessageId message_id, BufferSlice packet, size_t original_size) final;
void on_message_result_error(mtproto::MessageId message_id, int error_code, string message) final;
void on_message_failed(mtproto::MessageId message_id, Status status) final;
void on_message_info(uint64 message_id, int32 state, uint64 answer_message_id, int32 answer_size, int32 source) final;
void on_message_info(mtproto::MessageId message_id, int32 state, mtproto::MessageId answer_message_id,
int32 answer_size, int32 source) final;
Status on_destroy_auth_key() final;
void flush_pending_invoke_after_queries();
bool has_queries() const;
void dec_container(uint64 container_message_id, Query *query);
void cleanup_container(uint64 container_message_id, Query *query);
void mark_as_known(uint64 message_id, Query *query);
void mark_as_unknown(uint64 message_id, Query *query);
void dec_container(mtproto::MessageId container_message_id, Query *query);
void cleanup_container(mtproto::MessageId container_message_id, Query *query);
void mark_as_known(mtproto::MessageId message_id, Query *query);
void mark_as_unknown(mtproto::MessageId message_id, Query *query);
void on_message_ack_impl(uint64 container_message_id, int32 type);
void on_message_ack_impl_inner(uint64 message_id, int32 type, bool in_container);
void on_message_failed_inner(uint64 message_id, bool in_container);
void on_message_ack_impl(mtproto::MessageId container_message_id, int32 type);
void on_message_ack_impl_inner(mtproto::MessageId message_id, int32 type, bool in_container);
void on_message_failed_inner(mtproto::MessageId message_id, bool in_container);
// send NetQueryPtr to parent
void return_query(NetQueryPtr &&query);
@ -255,7 +257,7 @@ class Session final
void connection_online_update(double now, bool force);
void connection_close(ConnectionInfo *info);
void connection_flush(ConnectionInfo *info);
void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id = 0);
void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, mtproto::MessageId message_id = {});
bool need_send_bind_key() const;
bool need_send_query() const;
bool can_destroy_auth_key() const;