Use array-based MessageIdDuplicateChecker.

This commit is contained in:
levlam 2021-08-22 22:08:46 +03:00
parent b3aa31d398
commit 6507fb7602
3 changed files with 43 additions and 32 deletions

View File

@ -533,40 +533,40 @@ class MessageIdDuplicateCheckerNewSimple {
std::set<td::int64> saved_message_ids_; std::set<td::int64> saved_message_ids_;
}; };
template <size_t MAX_SAVED_MESSAGE_IDS> template <size_t max_size>
class MessageIdDuplicateCheckerArray { class MessageIdDuplicateCheckerArray {
public: public:
static td::string get_description() { static td::string get_description() {
return PSTRING() << "Array" << MAX_SAVED_MESSAGE_IDS; return PSTRING() << "Array" << max_size;
} }
td::Status check(td::int64 message_id) { td::Status check(td::int64 message_id) {
if (end_pos == 2 * MAX_SAVED_MESSAGE_IDS) { if (end_pos_ == 2 * max_size) {
std::copy_n(&saved_message_ids_[MAX_SAVED_MESSAGE_IDS], MAX_SAVED_MESSAGE_IDS, &saved_message_ids_[0]); std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]);
end_pos = MAX_SAVED_MESSAGE_IDS; end_pos_ = max_size;
} }
if (end_pos == 0 || message_id > saved_message_ids_[end_pos - 1]) { if (end_pos_ == 0 || message_id > saved_message_ids_[end_pos_ - 1]) {
// fast path // fast path
saved_message_ids_[end_pos++] = message_id; saved_message_ids_[end_pos_++] = message_id;
return td::Status::OK(); return td::Status::OK();
} }
if (end_pos >= MAX_SAVED_MESSAGE_IDS && message_id < saved_message_ids_[0]) { if (end_pos_ >= max_size && message_id < saved_message_ids_[0]) {
return td::Status::Error(2, PSLICE() << "Ignore very old message_id " return td::Status::Error(2, PSLICE() << "Ignore very old message_id "
<< td::tag("oldest message_id", saved_message_ids_[0]) << td::tag("oldest message_id", saved_message_ids_[0])
<< td::tag("got message_id", message_id)); << td::tag("got message_id", message_id));
} }
auto it = std::lower_bound(&saved_message_ids_[0], &saved_message_ids_[end_pos], message_id); auto it = std::lower_bound(&saved_message_ids_[0], &saved_message_ids_[end_pos_], message_id);
if (*it == message_id) { if (*it == message_id) {
return td::Status::Error(1, PSLICE() << "Ignore duplicated message_id " << td::tag("message_id", message_id)); return td::Status::Error(1, PSLICE() << "Ignore duplicated message_id " << td::tag("message_id", message_id));
} }
std::copy_backward(it, &saved_message_ids_[end_pos], &saved_message_ids_[end_pos + 1]); std::copy_backward(it, &saved_message_ids_[end_pos_], &saved_message_ids_[end_pos_ + 1]);
*it = message_id; *it = message_id;
++end_pos; ++end_pos_;
return td::Status::OK(); return td::Status::OK();
} }
private: private:
std::array<td::int64, 2 * MAX_SAVED_MESSAGE_IDS> saved_message_ids_; std::array<td::int64, 2 * max_size> saved_message_ids_;
std::size_t end_pos = 0; std::size_t end_pos_ = 0;
}; };
template <class T> template <class T>

View File

@ -17,25 +17,31 @@
namespace td { namespace td {
namespace mtproto { namespace mtproto {
Status MessageIdDuplicateChecker::check(int64 message_id) { Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id) {
// In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if // In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if
// a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be // a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be
// ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is // ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is
// greater than N, the oldest (i. e. the lowest) is forgotten. // greater than N, the oldest (i. e. the lowest) is forgotten.
auto insert_result = saved_message_ids_.insert(message_id); if (end_pos == 2 * max_size) {
if (!insert_result.second) { std::copy_n(&saved_message_ids[max_size], max_size, &saved_message_ids[0]);
end_pos = max_size;
}
if (end_pos == 0 || message_id > saved_message_ids[end_pos - 1]) {
// fast path
saved_message_ids[end_pos++] = message_id;
return Status::OK();
}
if (end_pos >= max_size && message_id < saved_message_ids[0]) {
return Status::Error(2, PSLICE() << "Ignore very old message_id " << tag("oldest message_id", saved_message_ids[0])
<< tag("got message_id", message_id));
}
auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id);
if (*it == message_id) {
return Status::Error(1, PSLICE() << "Ignore duplicated message_id " << tag("message_id", message_id)); return Status::Error(1, PSLICE() << "Ignore duplicated message_id " << tag("message_id", message_id));
} }
if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS + 1) { std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]);
auto begin_it = saved_message_ids_.begin(); *it = message_id;
bool is_very_old = begin_it == insert_result.first; ++end_pos;
saved_message_ids_.erase(begin_it);
if (is_very_old) {
return Status::Error(2, PSLICE() << "Ignore very old message_id "
<< tag("oldest message_id", *saved_message_ids_.begin())
<< tag("got message_id", message_id));
}
}
return Status::OK(); return Status::OK();
} }

View File

@ -12,7 +12,7 @@
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
#include <set> #include <array>
namespace td { namespace td {
namespace mtproto { namespace mtproto {
@ -37,13 +37,18 @@ void parse(ServerSalt &salt, ParserT &parser) {
salt.valid_until = parser.fetch_double(); salt.valid_until = parser.fetch_double();
} }
Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id);
template <size_t max_size>
class MessageIdDuplicateChecker { class MessageIdDuplicateChecker {
public: public:
Status check(int64 message_id); Status check(int64 message_id) {
return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
}
private: private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000; std::array<int64, 2 * max_size> saved_message_ids_;
std::set<int64> saved_message_ids_; size_t end_pos_ = 0;
}; };
class AuthData { class AuthData {
@ -274,8 +279,8 @@ class AuthData {
std::vector<ServerSalt> future_salts_; std::vector<ServerSalt> future_salts_;
MessageIdDuplicateChecker duplicate_checker_; MessageIdDuplicateChecker<1000> duplicate_checker_;
MessageIdDuplicateChecker updates_duplicate_checker_; MessageIdDuplicateChecker<1000> updates_duplicate_checker_;
void update_salt(double now); void update_salt(double now);
}; };