diff --git a/benchmark/bench_misc.cpp b/benchmark/bench_misc.cpp index e3671a3e7..aeec2232a 100644 --- a/benchmark/bench_misc.cpp +++ b/benchmark/bench_misc.cpp @@ -533,40 +533,40 @@ class MessageIdDuplicateCheckerNewSimple { std::set saved_message_ids_; }; -template +template class MessageIdDuplicateCheckerArray { public: static td::string get_description() { - return PSTRING() << "Array" << MAX_SAVED_MESSAGE_IDS; + return PSTRING() << "Array" << max_size; } td::Status check(td::int64 message_id) { - if (end_pos == 2 * MAX_SAVED_MESSAGE_IDS) { - std::copy_n(&saved_message_ids_[MAX_SAVED_MESSAGE_IDS], MAX_SAVED_MESSAGE_IDS, &saved_message_ids_[0]); - end_pos = MAX_SAVED_MESSAGE_IDS; + if (end_pos_ == 2 * max_size) { + std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]); + end_pos_ = max_size; } - if (end_pos == 0 || message_id > saved_message_ids_[end_pos - 1]) { + if (end_pos_ == 0 || message_id > saved_message_ids_[end_pos_ - 1]) { // fast path - saved_message_ids_[end_pos++] = message_id; + saved_message_ids_[end_pos_++] = message_id; return td::Status::OK(); } - if (end_pos >= MAX_SAVED_MESSAGE_IDS && message_id < saved_message_ids_[0]) { + if (end_pos_ >= max_size && message_id < saved_message_ids_[0]) { return td::Status::Error(2, PSLICE() << "Ignore very old message_id " << td::tag("oldest message_id", saved_message_ids_[0]) << td::tag("got message_id", message_id)); } - auto it = std::lower_bound(&saved_message_ids_[0], &saved_message_ids_[end_pos], message_id); + auto it = std::lower_bound(&saved_message_ids_[0], &saved_message_ids_[end_pos_], message_id); if (*it == message_id) { return td::Status::Error(1, PSLICE() << "Ignore duplicated message_id " << td::tag("message_id", message_id)); } - std::copy_backward(it, &saved_message_ids_[end_pos], &saved_message_ids_[end_pos + 1]); + std::copy_backward(it, &saved_message_ids_[end_pos_], &saved_message_ids_[end_pos_ + 1]); *it = message_id; - ++end_pos; + ++end_pos_; return td::Status::OK(); } private: - std::array saved_message_ids_; - std::size_t end_pos = 0; + std::array saved_message_ids_; + std::size_t end_pos_ = 0; }; template diff --git a/td/mtproto/AuthData.cpp b/td/mtproto/AuthData.cpp index 189c74cd1..2dfff4bcd 100644 --- a/td/mtproto/AuthData.cpp +++ b/td/mtproto/AuthData.cpp @@ -17,25 +17,31 @@ namespace td { namespace mtproto { -Status MessageIdDuplicateChecker::check(int64 message_id) { +Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id) { // In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if // a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be // ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is // greater than N, the oldest (i. e. the lowest) is forgotten. - auto insert_result = saved_message_ids_.insert(message_id); - if (!insert_result.second) { + if (end_pos == 2 * max_size) { + std::copy_n(&saved_message_ids[max_size], max_size, &saved_message_ids[0]); + end_pos = max_size; + } + if (end_pos == 0 || message_id > saved_message_ids[end_pos - 1]) { + // fast path + saved_message_ids[end_pos++] = message_id; + return Status::OK(); + } + if (end_pos >= max_size && message_id < saved_message_ids[0]) { + return Status::Error(2, PSLICE() << "Ignore very old message_id " << tag("oldest message_id", saved_message_ids[0]) + << tag("got message_id", message_id)); + } + auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id); + if (*it == message_id) { return Status::Error(1, PSLICE() << "Ignore duplicated message_id " << tag("message_id", message_id)); } - if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS + 1) { - auto begin_it = saved_message_ids_.begin(); - bool is_very_old = begin_it == insert_result.first; - saved_message_ids_.erase(begin_it); - if (is_very_old) { - return Status::Error(2, PSLICE() << "Ignore very old message_id " - << tag("oldest message_id", *saved_message_ids_.begin()) - << tag("got message_id", message_id)); - } - } + std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]); + *it = message_id; + ++end_pos; return Status::OK(); } diff --git a/td/mtproto/AuthData.h b/td/mtproto/AuthData.h index 4a9717048..89d34f02b 100644 --- a/td/mtproto/AuthData.h +++ b/td/mtproto/AuthData.h @@ -12,7 +12,7 @@ #include "td/utils/Slice.h" #include "td/utils/Status.h" -#include +#include namespace td { namespace mtproto { @@ -37,13 +37,18 @@ void parse(ServerSalt &salt, ParserT &parser) { salt.valid_until = parser.fetch_double(); } +Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id); + +template class MessageIdDuplicateChecker { public: - Status check(int64 message_id); + Status check(int64 message_id) { + return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id); + } private: - static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; - std::set saved_message_ids_; + std::array saved_message_ids_; + size_t end_pos_ = 0; }; class AuthData { @@ -274,8 +279,8 @@ class AuthData { std::vector future_salts_; - MessageIdDuplicateChecker duplicate_checker_; - MessageIdDuplicateChecker updates_duplicate_checker_; + MessageIdDuplicateChecker<1000> duplicate_checker_; + MessageIdDuplicateChecker<1000> updates_duplicate_checker_; void update_salt(double now); };