diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c45963bd..1040f8a33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -271,6 +271,7 @@ set(TL_C_SCHEME_SOURCE ) set(TDLIB_SOURCE + td/mtproto/AuthData.cpp td/mtproto/crypto.cpp td/mtproto/Handshake.cpp td/mtproto/HandshakeActor.cpp diff --git a/td/mtproto/AuthData.cpp b/td/mtproto/AuthData.cpp new file mode 100644 index 000000000..6fff40535 --- /dev/null +++ b/td/mtproto/AuthData.cpp @@ -0,0 +1,164 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/mtproto/AuthData.h" + +#include "td/utils/format.h" +#include "td/utils/Random.h" +#include "td/utils/Time.h" + +#include + +namespace td { +namespace mtproto { + +Status MessageIdDuplicateChecker::check(int64 message_id) { + // In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if + // a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be + // ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is + // greater than N, the oldest (i. e. the lowest) is forgotten. + if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS) { + auto oldest_message_id = *saved_message_ids_.begin(); + if (message_id < oldest_message_id) { + return Status::Error(1, PSLICE() << "Ignore very old message_id " << tag("oldest message_id", oldest_message_id) + << tag("got message_id", message_id)); + } + } + if (saved_message_ids_.count(message_id) != 0) { + return Status::Error(1, PSLICE() << "Ignore duplicated_message id " << tag("message_id", message_id)); + } + + saved_message_ids_.insert(message_id); + if (saved_message_ids_.size() > MAX_SAVED_MESSAGE_IDS) { + saved_message_ids_.erase(saved_message_ids_.begin()); + } + return Status::OK(); +} + +AuthData::AuthData() { + server_salt_.salt = Random::secure_int64(); + server_salt_.valid_since = -1e10; + server_salt_.valid_until = -1e10; +} + +bool AuthData::is_ready(double now) { + if (!has_main_auth_key()) { + LOG(INFO) << "Need main auth key"; + return false; + } + if (use_pfs() && !has_tmp_auth_key(now)) { + LOG(INFO) << "Need tmp auth key"; + return false; + } + if (!has_salt(now)) { + LOG(INFO) << "no salt"; + return false; + } + return true; +} + +bool AuthData::update_server_time_difference(double diff) { + if (!server_time_difference_was_updated_) { + server_time_difference_was_updated_ = true; + LOG(DEBUG) << "UPDATE_SERVER_TIME_DIFFERENCE: " << server_time_difference_ << " -> " << diff; + server_time_difference_ = diff; + } else if (server_time_difference_ + 1e-4 < diff) { + LOG(DEBUG) << "UPDATE_SERVER_TIME_DIFFERENCE: " << server_time_difference_ << " -> " << diff; + server_time_difference_ = diff; + } else { + return false; + } + LOG(DEBUG) << "SERVER_TIME: " << format::as_hex(static_cast(get_server_time(Time::now_cached()))); + return true; +} + +void AuthData::set_future_salts(const std::vector &salts, double now) { + if (salts.empty()) { + return; + } + future_salts_ = salts; + std::sort(future_salts_.begin(), future_salts_.end(), + [](const ServerSalt &a, const ServerSalt &b) { return a.valid_since > b.valid_since; }); + update_salt(now); +} + +std::vector AuthData::get_future_salts() const { + auto res = future_salts_; + res.push_back(server_salt_); + return res; +} + +int64 AuthData::next_message_id(double now) { + double server_time = get_server_time(now); + int64 t = static_cast(server_time * (1ll << 32)); + + // randomize lower bits for clocks with low precision + // TODO(perf) do not do this for systems with good precision?.. + auto rx = Random::secure_int32(); + auto to_xor = rx & ((1 << 22) - 1); + auto to_mul = ((rx >> 22) & 1023) + 1; + + t ^= to_xor; + auto result = t & -4; + if (last_message_id_ >= result) { + result = last_message_id_ + 8 * to_mul; + } + last_message_id_ = result; + return result; +} + +bool AuthData::is_valid_outbound_msg_id(int64 id, double now) { + double server_time = get_server_time(now); + auto id_time = static_cast(id / (1ll << 32)); + return server_time - 300 / 2 < id_time && id_time < server_time + 60 / 2; +} +bool AuthData::is_valid_inbound_msg_id(int64 id, double now) { + double server_time = get_server_time(now); + auto id_time = static_cast(id / (1ll << 32)); + return server_time - 300 < id_time && id_time < server_time + 30; +} + +Status AuthData::check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated) { + // Client is to check that the session_id field in the decrypted message indeed equals to that of an active session + // created by the client. + if (get_session_id() != static_cast(session_id)) { + return Status::Error(PSLICE() << "Got packet from different session " + << tag("current session_id", get_session_id()) + << tag("got session_id", session_id)); + } + + // Client must check that msg_id has even parity for messages from client to server, and odd parity for messages + // from server to client. + if ((message_id & 1) == 0) { + return Status::Error(PSLICE() << "Got invalid message_id " << tag("message_id", message_id)); + } + + TRY_STATUS(duplicate_checker_.check(message_id)); + + time_difference_was_updated = update_server_time_difference(static_cast(message_id >> 32) - now); + + // In addition, msg_id values that belong over 30 seconds in the future or over 300 seconds in the past are to be + // ignored (recall that msg_id approximately equals unixtime * 2^32). This is especially important for the server. + // The client would also find this useful (to protect from a replay attack), but only if it is certain of its time + // (for example, if its time has been synchronized with that of the server). + if (server_time_difference_was_updated_ && !is_valid_inbound_msg_id(message_id, now)) { + return Status::Error(PSLICE() << "Ignore message with too old or too new message_id " + << tag("message_id", message_id)); + } + + return Status::OK(); +} + +void AuthData::update_salt(double now) { + double server_time = get_server_time(now); + while (!future_salts_.empty() && (future_salts_.back().valid_since < server_time)) { + server_salt_ = future_salts_.back(); + future_salts_.pop_back(); + } +} + +} // namespace mtproto +} // namespace td diff --git a/td/mtproto/AuthData.h b/td/mtproto/AuthData.h index 47e1a5cc4..7293ffd1c 100644 --- a/td/mtproto/AuthData.h +++ b/td/mtproto/AuthData.h @@ -8,14 +8,10 @@ #include "td/mtproto/AuthKey.h" -#include "td/utils/format.h" #include "td/utils/logging.h" -#include "td/utils/Random.h" #include "td/utils/Slice.h" #include "td/utils/Status.h" -#include "td/utils/Time.h" -#include #include namespace td { @@ -43,28 +39,7 @@ void parse(ServerSalt &salt, ParserT &parser) { class MessageIdDuplicateChecker { public: - Status check(int64 message_id) { - // In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if - // a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be - // ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is - // greater than N, the oldest (i. e. the lowest) is forgotten. - if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS) { - auto oldest_message_id = *saved_message_ids_.begin(); - if (message_id < oldest_message_id) { - return Status::Error(1, PSLICE() << "Ignore very old message_id " << tag("oldest message_id", oldest_message_id) - << tag("got message_id", message_id)); - } - } - if (saved_message_ids_.count(message_id) != 0) { - return Status::Error(1, PSLICE() << "Ignore duplicated_message id " << tag("message_id", message_id)); - } - - saved_message_ids_.insert(message_id); - if (saved_message_ids_.size() > MAX_SAVED_MESSAGE_IDS) { - saved_message_ids_.erase(saved_message_ids_.begin()); - } - return Status::OK(); - } + Status check(int64 message_id); private: static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; @@ -73,30 +48,14 @@ class MessageIdDuplicateChecker { class AuthData { public: - AuthData() { - server_salt_.salt = Random::secure_int64(); - server_salt_.valid_since = -1e10; - server_salt_.valid_until = -1e10; - } + AuthData(); AuthData(const AuthData &) = delete; AuthData &operator=(const AuthData &) = delete; - virtual ~AuthData() = default; + AuthData(AuthData &&) = delete; + AuthData &operator=(AuthData &&) = delete; + ~AuthData() = default; - bool is_ready(double now) { - if (!has_main_auth_key()) { - LOG(INFO) << "Need main auth key"; - return false; - } - if (use_pfs() && !has_tmp_auth_key(now)) { - LOG(INFO) << "Need tmp auth key"; - return false; - } - if (!has_salt(now)) { - LOG(INFO) << "no salt"; - return false; - } - return true; - } + bool is_ready(double now); uint64 session_id_; void set_main_auth_key(AuthKey auth_key) { @@ -225,20 +184,7 @@ class AuthData { // diff == msg_id / 2^32 - now == old_server_now - now <= server_now - now // server_time_difference >= max{diff} - bool update_server_time_difference(double diff) { - if (!server_time_difference_was_updated_) { - server_time_difference_was_updated_ = true; - LOG(DEBUG) << "UPDATE_SERVER_TIME_DIFFERENCE: " << server_time_difference_ << " -> " << diff; - server_time_difference_ = diff; - } else if (server_time_difference_ + 1e-4 < diff) { - LOG(DEBUG) << "UPDATE_SERVER_TIME_DIFFERENCE: " << server_time_difference_ << " -> " << diff; - server_time_difference_ = diff; - } else { - return false; - } - LOG(DEBUG) << "SERVER_TIME: " << format::as_hex(static_cast(get_server_time(Time::now_cached()))); - return true; - } + bool update_server_time_difference(double diff); void set_server_time_difference(double diff) { server_time_difference_was_updated_ = false; @@ -272,82 +218,17 @@ class AuthData { return future_salts_.empty() || !is_server_salt_valid(now); } - virtual void set_future_salts(const std::vector &salts, double now) { - if (salts.empty()) { - return; - } - future_salts_ = salts; - std::sort(future_salts_.begin(), future_salts_.end(), - [](const ServerSalt &a, const ServerSalt &b) { return a.valid_since > b.valid_since; }); - update_salt(now); - } + void set_future_salts(const std::vector &salts, double now); - std::vector get_future_salts() const { - auto res = future_salts_; - res.push_back(server_salt_); - return res; - } + std::vector get_future_salts() const; - int64 next_message_id(double now) { - double server_time = get_server_time(now); - int64 t = static_cast(server_time * (1ll << 32)); + int64 next_message_id(double now); - // randomize lower bits for clocks with low precision - // TODO(perf) do not do this for systems with good precision?.. - auto rx = Random::secure_int32(); - auto to_xor = rx & ((1 << 22) - 1); - auto to_mul = ((rx >> 22) & 1023) + 1; + bool is_valid_outbound_msg_id(int64 id, double now); - t ^= to_xor; - auto result = t & -4; - if (last_message_id_ >= result) { - result = last_message_id_ + 8 * to_mul; - } - last_message_id_ = result; - return result; - } + bool is_valid_inbound_msg_id(int64 id, double now); - bool is_valid_outbound_msg_id(int64 id, double now) { - double server_time = get_server_time(now); - auto id_time = static_cast(id / (1ll << 32)); - return server_time - 300 / 2 < id_time && id_time < server_time + 60 / 2; - } - bool is_valid_inbound_msg_id(int64 id, double now) { - double server_time = get_server_time(now); - auto id_time = static_cast(id / (1ll << 32)); - return server_time - 300 < id_time && id_time < server_time + 30; - } - - Status check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated) { - // Client is to check that the session_id field in the decrypted message indeed equals to that of an active session - // created by the client. - if (get_session_id() != static_cast(session_id)) { - return Status::Error(PSLICE() << "Got packet from different session " - << tag("current session_id", get_session_id()) - << tag("got session_id", session_id)); - } - - // Client must check that msg_id has even parity for messages from client to server, and odd parity for messages - // from server to client. - if ((message_id & 1) == 0) { - return Status::Error(PSLICE() << "Got invalid message_id " << tag("message_id", message_id)); - } - - TRY_STATUS(duplicate_checker_.check(message_id)); - - time_difference_was_updated = update_server_time_difference(static_cast(message_id >> 32) - now); - - // In addition, msg_id values that belong over 30 seconds in the future or over 300 seconds in the past are to be - // ignored (recall that msg_id approximately equals unixtime * 2^32). This is especially important for the server. - // The client would also find this useful (to protect from a replay attack), but only if it is certain of its time - // (for example, if its time has been synchronized with that of the server). - if (server_time_difference_was_updated_ && !is_valid_inbound_msg_id(message_id, now)) { - return Status::Error(PSLICE() << "Ignore message with too old or too new message_id " - << tag("message_id", message_id)); - } - - return Status::OK(); - } + Status check_packet(int64 session_id, int64 message_id, double now, bool &time_difference_was_updated); Status check_update(int64 message_id) { return updates_duplicate_checker_.check(message_id); @@ -389,13 +270,7 @@ class AuthData { MessageIdDuplicateChecker duplicate_checker_; MessageIdDuplicateChecker updates_duplicate_checker_; - void update_salt(double now) { - double server_time = get_server_time(now); - while (!future_salts_.empty() && (future_salts_.back().valid_since < server_time)) { - server_salt_ = future_salts_.back(); - future_salts_.pop_back(); - } - } + void update_salt(double now); }; } // namespace mtproto diff --git a/td/mtproto/AuthKey.h b/td/mtproto/AuthKey.h index 249e144f5..d71afc88e 100644 --- a/td/mtproto/AuthKey.h +++ b/td/mtproto/AuthKey.h @@ -9,8 +9,6 @@ #include "td/utils/common.h" #include "td/utils/Time.h" -#include - namespace td { namespace mtproto { class AuthKey { @@ -81,7 +79,6 @@ class AuthKey { private: uint64 auth_key_id_ = 0; - // TODO(perf): std::shared_ptr string auth_key_; bool auth_flag_ = false; bool was_auth_flag_ = false; diff --git a/td/telegram/net/AuthDataShared.h b/td/telegram/net/AuthDataShared.h index dc6ae831c..1f36db1b3 100644 --- a/td/telegram/net/AuthDataShared.h +++ b/td/telegram/net/AuthDataShared.h @@ -7,6 +7,7 @@ #pragma once #include "td/mtproto/AuthData.h" +#include "td/mtproto/Authkey.h" #include "td/telegram/net/DcId.h" #include "td/telegram/net/PublicRsaKeyShared.h"