Fix message_id type.

This commit is contained in:
levlam 2023-05-14 22:42:58 +03:00
parent c9f83caf9f
commit f29774acc6
8 changed files with 45 additions and 44 deletions

View File

@ -17,7 +17,7 @@
namespace td {
namespace mtproto {
Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id) {
Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id) {
// In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if
// a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be
// ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is
@ -98,9 +98,9 @@ std::vector<ServerSalt> AuthData::get_future_salts() const {
return res;
}
int64 AuthData::next_message_id(double now) {
uint64 AuthData::next_message_id(double now) {
double server_time = get_server_time(now);
auto t = static_cast<int64>(server_time * (static_cast<int64>(1) << 32));
auto t = static_cast<uint64>(server_time * (static_cast<uint64>(1) << 32));
// randomize lower bits for clocks with low precision
// TODO(perf) do not do this for systems with good precision?..
@ -109,7 +109,7 @@ int64 AuthData::next_message_id(double now) {
auto to_mul = ((rx >> 22) & 1023) + 1;
t ^= to_xor;
auto result = t & -4;
auto result = t & static_cast<uint64>(-4);
if (last_message_id_ >= result) {
result = last_message_id_ + 8 * to_mul;
}
@ -117,19 +117,19 @@ int64 AuthData::next_message_id(double now) {
return result;
}
bool AuthData::is_valid_outbound_msg_id(int64 id, double now) const {
bool AuthData::is_valid_outbound_msg_id(uint64 message_id, double now) const {
double server_time = get_server_time(now);
auto id_time = static_cast<double>(id) / static_cast<double>(static_cast<int64>(1) << 32);
auto id_time = static_cast<double>(message_id) / static_cast<double>(static_cast<uint64>(1) << 32);
return server_time - 150 < id_time && id_time < server_time + 30;
}
bool AuthData::is_valid_inbound_msg_id(int64 id, double now) const {
bool AuthData::is_valid_inbound_msg_id(uint64 message_id, double now) const {
double server_time = get_server_time(now);
auto id_time = static_cast<double>(id) / static_cast<double>(static_cast<int64>(1) << 32);
auto id_time = static_cast<double>(message_id) / static_cast<double>(static_cast<uint64>(1) << 32);
return server_time - 300 < id_time && id_time < server_time + 30;
}
Status AuthData::check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated) {
Status AuthData::check_packet(int64 session_id, uint64 message_id, double now, bool &time_difference_was_updated) {
// Client is to check that the session_id field in the decrypted message indeed equals to that of an active session
// created by the client.
if (get_session_id() != static_cast<uint64>(session_id)) {

View File

@ -37,17 +37,17 @@ void parse(ServerSalt &salt, ParserT &parser) {
salt.valid_until = parser.fetch_double();
}
Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id);
Status check_message_id_duplicates(uint64 *saved_message_ids, size_t max_size, size_t &end_pos, uint64 message_id);
template <size_t max_size>
class MessageIdDuplicateChecker {
public:
Status check(int64 message_id) {
Status check(uint64 message_id) {
return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
}
private:
std::array<int64, 2 * max_size> saved_message_ids_;
std::array<uint64, 2 * max_size> saved_message_ids_;
size_t end_pos_ = 0;
};
@ -239,19 +239,19 @@ class AuthData {
std::vector<ServerSalt> get_future_salts() const;
int64 next_message_id(double now);
uint64 next_message_id(double now);
bool is_valid_outbound_msg_id(int64 id, double now) const;
bool is_valid_outbound_msg_id(uint64 message_id, double now) const;
bool is_valid_inbound_msg_id(int64 id, double now) const;
bool is_valid_inbound_msg_id(uint64 message_id, double now) const;
Status check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated);
Status check_packet(int64 session_id, uint64 message_id, double now, bool &time_difference_was_updated);
Status check_update(int64 message_id) {
Status check_update(uint64 message_id) {
return updates_duplicate_checker_.check(message_id);
}
Status recheck_update(int64 message_id) {
Status recheck_update(uint64 message_id) {
return updates_duplicate_rechecker_.check(message_id);
}
@ -282,9 +282,9 @@ class AuthData {
bool server_time_difference_was_updated_ = false;
double server_time_difference_ = 0;
ServerSalt server_salt_;
int64 last_message_id_ = 0;
uint64 last_message_id_ = 0;
int32 seq_no_ = 0;
std::string header_;
string header_;
uint64 session_id_ = 0;
std::vector<ServerSalt> future_salts_;

View File

@ -116,7 +116,7 @@ class InvokeAfter {
}
if (message_ids_.size() == 1) {
storer.store_int(static_cast<int32>(0xcb9f372d));
storer.store_long(static_cast<int64>(message_ids_[0]));
storer.store_binary(message_ids_[0]);
return;
}
// invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector<long> query:!X = X;
@ -124,7 +124,7 @@ class InvokeAfter {
storer.store_int(static_cast<int32>(0x1cb5c415));
storer.store_int(narrow_cast<int32>(message_ids_.size()));
for (auto message_id : message_ids_) {
storer.store_long(static_cast<int64>(message_id));
storer.store_binary(message_id);
}
}

View File

@ -13,11 +13,11 @@ namespace td {
namespace mtproto {
struct MtprotoQuery {
int64 message_id;
uint64 message_id;
int32 seq_no;
BufferSlice packet;
bool gzip_flag;
std::vector<uint64> invoke_after_ids;
vector<uint64> invoke_after_ids;
bool use_quick_ack;
};

View File

@ -302,7 +302,8 @@ Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_
Status SessionConnection::on_packet(const MsgInfo &info,
const mtproto_api::bad_msg_notification &bad_msg_notification) {
MsgInfo bad_info{info.session_id, bad_msg_notification.bad_msg_id_, bad_msg_notification.bad_msg_seqno_, 0};
MsgInfo bad_info{info.session_id, static_cast<uint64>(bad_msg_notification.bad_msg_id_),
bad_msg_notification.bad_msg_seqno_, 0};
enum Code {
MsgIdTooLow = 16,
MsgIdTooHigh = 17,
@ -381,7 +382,8 @@ Status SessionConnection::on_packet(const MsgInfo &info,
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_server_salt &bad_server_salt) {
MsgInfo bad_info{info.session_id, bad_server_salt.bad_msg_id_, bad_server_salt.bad_msg_seqno_, 0};
MsgInfo bad_info{info.session_id, static_cast<uint64>(bad_server_salt.bad_msg_id_), bad_server_salt.bad_msg_seqno_,
0};
VLOG(mtproto) << "BAD_SERVER_SALT: " << bad_info;
auth_data_->set_server_salt(bad_server_salt.new_server_salt_, Time::now_cached());
callback_->on_server_salt_updated();
@ -434,7 +436,7 @@ Status SessionConnection::on_msgs_state_info(const vector<int64> &message_ids, S
}
size_t i = 0;
for (auto message_id : message_ids) {
callback_->on_message_info(message_id, info[i], 0, 0);
callback_->on_message_info(static_cast<uint64>(message_id), info[i], 0, 0);
i++;
}
return Status::OK();
@ -780,7 +782,7 @@ void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token
auth_data_->get_auth_key(), quick_ack_token);
}
Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, int64 message_id,
Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id,
vector<uint64> invoke_after_ids, bool use_quick_ack) {
CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait"
if (message_id == 0) {
@ -798,24 +800,24 @@ Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag,
return message_id;
}
void SessionConnection::get_state_info(int64 message_id) {
void SessionConnection::get_state_info(uint64 message_id) {
if (to_get_state_info_.empty()) {
send_before(Time::now_cached());
}
to_get_state_info_.push_back(message_id);
to_get_state_info_.push_back(static_cast<int64>(message_id));
}
void SessionConnection::resend_answer(int64 message_id) {
void SessionConnection::resend_answer(uint64 message_id) {
if (to_resend_answer_.empty()) {
send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
}
to_resend_answer_.push_back(message_id);
to_resend_answer_.push_back(static_cast<int64>(message_id));
}
void SessionConnection::cancel_answer(int64 message_id) {
void SessionConnection::cancel_answer(uint64 message_id) {
if (to_cancel_answer_.empty()) {
send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
}
to_cancel_answer_.push_back(message_id);
to_cancel_answer_.push_back(static_cast<int64>(message_id));
}
void SessionConnection::destroy_key() {

View File

@ -30,7 +30,6 @@ namespace td {
extern int VERBOSITY_NAME(mtproto);
namespace mtproto_api {
class rpc_error;
class new_session_created;
class bad_msg_notification;
@ -55,7 +54,7 @@ class AuthData;
struct MsgInfo {
uint64 session_id;
int64 message_id;
uint64 message_id;
int32 seq_no;
size_t size;
};
@ -81,13 +80,13 @@ class SessionConnection final
unique_ptr<RawConnection> move_as_raw_connection();
// Interface
Result<uint64> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, int64 message_id = 0,
Result<uint64> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, uint64 message_id = 0,
std::vector<uint64> invoke_after_id = {}, bool use_quick_ack = false);
std::pair<uint64, BufferSlice> encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at);
void get_state_info(int64 message_id);
void resend_answer(int64 message_id);
void cancel_answer(int64 message_id);
void get_state_info(uint64 message_id);
void resend_answer(uint64 message_id);
void cancel_answer(uint64 message_id);
void destroy_key();
void set_online(bool online_flag, bool is_main);
@ -187,7 +186,7 @@ class SessionConnection final
double last_pong_at_ = 0;
double real_last_read_at_ = 0;
double real_last_pong_at_ = 0;
int64 cur_ping_id_ = 0;
uint64 cur_ping_id_ = 0;
uint64 last_ping_message_id_ = 0;
uint64 last_ping_container_id_ = 0;
@ -205,7 +204,7 @@ class SessionConnection final
bool connected_flag_ = false;
uint64 container_id_ = 0;
int64 main_message_id_ = 0;
uint64 main_message_id_ = 0;
double created_at_ = 0;
unique_ptr<RawConnection> raw_connection_;

View File

@ -111,7 +111,7 @@ struct NoCryptoHeader {
// message_id is removed from CryptoHeader. Should be removed from here too.
//
// int64 message_id;
// uint64 message_id;
// uint32 message_data_length;
uint8 data[0]; // use compiler extension

View File

@ -1382,7 +1382,7 @@ bool Session::connection_send_bind_key(ConnectionInfo *info) {
int64 perm_auth_key_id = auth_data_.get_main_auth_key().id();
int64 nonce = Random::secure_int64();
auto expires_at = static_cast<int32>(auth_data_.get_server_time(auth_data_.get_tmp_auth_key().expires_at()));
int64 message_id;
uint64 message_id;
BufferSlice encrypted;
std::tie(message_id, encrypted) = info->connection_->encrypted_bind(perm_auth_key_id, nonce, expires_at);