Use FlatHashSet to store unallowed boundaries.

This commit is contained in:
levlam 2023-05-19 14:03:22 +03:00
parent d510bc8435
commit 45dfc6e51b

View File

@ -32,7 +32,6 @@
#include <cstring>
#include <limits>
#include <tuple>
#include <unordered_set>
namespace td {
@ -2447,15 +2446,15 @@ static FormattedText parse_text_url_entities_v3(Slice text, const vector<Message
}
static vector<MessageEntity> find_splittable_entities_v3(Slice text, const vector<MessageEntity> &entities) {
std::unordered_set<int32, Hash<int32>> unallowed_boundaries;
FlatHashSet<int32, Hash<int32>> unallowed_boundaries;
for (auto &entity : entities) {
unallowed_boundaries.insert(entity.offset);
unallowed_boundaries.insert(entity.offset + entity.length);
unallowed_boundaries.insert(entity.offset + 1);
unallowed_boundaries.insert(entity.offset + entity.length + 1);
if (entity.type == MessageEntity::Type::Mention || entity.type == MessageEntity::Type::Hashtag ||
entity.type == MessageEntity::Type::BotCommand || entity.type == MessageEntity::Type::Cashtag ||
entity.type == MessageEntity::Type::PhoneNumber || entity.type == MessageEntity::Type::BankCardNumber) {
for (int32 i = 1; i < entity.length; i++) {
unallowed_boundaries.insert(entity.offset + i);
unallowed_boundaries.insert(entity.offset + i + 1);
}
}
}
@ -2466,7 +2465,7 @@ static vector<MessageEntity> find_splittable_entities_v3(Slice text, const vecto
});
for (auto &entity : found_entities) {
for (int32 i = 0; i <= entity.length; i++) {
unallowed_boundaries.insert(entity.offset + i);
unallowed_boundaries.insert(entity.offset + i + 1);
}
}
@ -2479,10 +2478,10 @@ static vector<MessageEntity> find_splittable_entities_v3(Slice text, const vecto
utf16_offset += 1 + (c >= 0xf0); // >= 4 bytes in symbol => surrogate pair
}
if ((c == '_' || c == '*' || c == '~' || c == '|') && text[i] == text[i + 1] &&
unallowed_boundaries.count(utf16_offset) == 0) {
unallowed_boundaries.count(utf16_offset + 1) == 0) {
auto j = i + 2;
while (j != text.size() && text[j] == text[i] &&
unallowed_boundaries.count(utf16_offset + static_cast<int32>(j - i - 1)) == 0) {
unallowed_boundaries.count(utf16_offset + static_cast<int32>(j - i)) == 0) {
j++;
}
if (j == i + 2) {