Use array-based MessageIdDuplicateChecker.

This commit is contained in:
levlam 2021-08-22 22:08:46 +03:00
parent b3aa31d398
commit 6507fb7602
3 changed files with 43 additions and 32 deletions

View File

@ -533,40 +533,40 @@ class MessageIdDuplicateCheckerNewSimple {
std::set<td::int64> saved_message_ids_;
};
template <size_t MAX_SAVED_MESSAGE_IDS>
template <size_t max_size>
class MessageIdDuplicateCheckerArray {
public:
static td::string get_description() {
return PSTRING() << "Array" << MAX_SAVED_MESSAGE_IDS;
return PSTRING() << "Array" << max_size;
}
td::Status check(td::int64 message_id) {
if (end_pos == 2 * MAX_SAVED_MESSAGE_IDS) {
std::copy_n(&saved_message_ids_[MAX_SAVED_MESSAGE_IDS], MAX_SAVED_MESSAGE_IDS, &saved_message_ids_[0]);
end_pos = MAX_SAVED_MESSAGE_IDS;
if (end_pos_ == 2 * max_size) {
std::copy_n(&saved_message_ids_[max_size], max_size, &saved_message_ids_[0]);
end_pos_ = max_size;
}
if (end_pos == 0 || message_id > saved_message_ids_[end_pos - 1]) {
if (end_pos_ == 0 || message_id > saved_message_ids_[end_pos_ - 1]) {
// fast path
saved_message_ids_[end_pos++] = message_id;
saved_message_ids_[end_pos_++] = message_id;
return td::Status::OK();
}
if (end_pos >= MAX_SAVED_MESSAGE_IDS && message_id < saved_message_ids_[0]) {
if (end_pos_ >= max_size && message_id < saved_message_ids_[0]) {
return td::Status::Error(2, PSLICE() << "Ignore very old message_id "
<< td::tag("oldest message_id", saved_message_ids_[0])
<< td::tag("got message_id", message_id));
}
auto it = std::lower_bound(&saved_message_ids_[0], &saved_message_ids_[end_pos], message_id);
auto it = std::lower_bound(&saved_message_ids_[0], &saved_message_ids_[end_pos_], message_id);
if (*it == message_id) {
return td::Status::Error(1, PSLICE() << "Ignore duplicated message_id " << td::tag("message_id", message_id));
}
std::copy_backward(it, &saved_message_ids_[end_pos], &saved_message_ids_[end_pos + 1]);
std::copy_backward(it, &saved_message_ids_[end_pos_], &saved_message_ids_[end_pos_ + 1]);
*it = message_id;
++end_pos;
++end_pos_;
return td::Status::OK();
}
private:
std::array<td::int64, 2 * MAX_SAVED_MESSAGE_IDS> saved_message_ids_;
std::size_t end_pos = 0;
std::array<td::int64, 2 * max_size> saved_message_ids_;
std::size_t end_pos_ = 0;
};
template <class T>

View File

@ -17,25 +17,31 @@
namespace td {
namespace mtproto {
Status MessageIdDuplicateChecker::check(int64 message_id) {
Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id) {
// In addition, the identifiers (msg_id) of the last N messages received from the other side must be stored, and if
// a message comes in with msg_id lower than all or equal to any of the stored values, that message is to be
// ignored. Otherwise, the new message msg_id is added to the set, and, if the number of stored msg_id values is
// greater than N, the oldest (i. e. the lowest) is forgotten.
auto insert_result = saved_message_ids_.insert(message_id);
if (!insert_result.second) {
if (end_pos == 2 * max_size) {
std::copy_n(&saved_message_ids[max_size], max_size, &saved_message_ids[0]);
end_pos = max_size;
}
if (end_pos == 0 || message_id > saved_message_ids[end_pos - 1]) {
// fast path
saved_message_ids[end_pos++] = message_id;
return Status::OK();
}
if (end_pos >= max_size && message_id < saved_message_ids[0]) {
return Status::Error(2, PSLICE() << "Ignore very old message_id " << tag("oldest message_id", saved_message_ids[0])
<< tag("got message_id", message_id));
}
auto it = std::lower_bound(&saved_message_ids[0], &saved_message_ids[end_pos], message_id);
if (*it == message_id) {
return Status::Error(1, PSLICE() << "Ignore duplicated message_id " << tag("message_id", message_id));
}
if (saved_message_ids_.size() == MAX_SAVED_MESSAGE_IDS + 1) {
auto begin_it = saved_message_ids_.begin();
bool is_very_old = begin_it == insert_result.first;
saved_message_ids_.erase(begin_it);
if (is_very_old) {
return Status::Error(2, PSLICE() << "Ignore very old message_id "
<< tag("oldest message_id", *saved_message_ids_.begin())
<< tag("got message_id", message_id));
}
}
std::copy_backward(it, &saved_message_ids[end_pos], &saved_message_ids[end_pos + 1]);
*it = message_id;
++end_pos;
return Status::OK();
}

View File

@ -12,7 +12,7 @@
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
#include <set>
#include <array>
namespace td {
namespace mtproto {
@ -37,13 +37,18 @@ void parse(ServerSalt &salt, ParserT &parser) {
salt.valid_until = parser.fetch_double();
}
Status check_message_id_duplicates(int64 *saved_message_ids, size_t max_size, size_t &end_pos, int64 message_id);
template <size_t max_size>
class MessageIdDuplicateChecker {
public:
Status check(int64 message_id);
Status check(int64 message_id) {
return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
}
private:
static constexpr size_t MAX_SAVED_MESSAGE_IDS = 1000;
std::set<int64> saved_message_ids_;
std::array<int64, 2 * max_size> saved_message_ids_;
size_t end_pos_ = 0;
};
class AuthData {
@ -274,8 +279,8 @@ class AuthData {
std::vector<ServerSalt> future_salts_;
MessageIdDuplicateChecker duplicate_checker_;
MessageIdDuplicateChecker updates_duplicate_checker_;
MessageIdDuplicateChecker<1000> duplicate_checker_;
MessageIdDuplicateChecker<1000> updates_duplicate_checker_;
void update_salt(double now);
};