diff --git a/td/telegram/MessageEntity.cpp b/td/telegram/MessageEntity.cpp index d1642b223..eafd22941 100644 --- a/td/telegram/MessageEntity.cpp +++ b/td/telegram/MessageEntity.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -1135,49 +1136,109 @@ vector> find_urls(Slice str) { return result; } -// keeps nested, but removes mutually intersecting entities -// entities must be pre-sorted -static void remove_unallowed_entities(vector &entities) { +static void check_is_sorted(const vector &entities) { // CHECKs that entities are ordered by operator< + CHECK(std::is_sorted(entities.begin(), entities.end())); +} + +static void check_non_intersecting(const vector &entities) { // CHECKs that every entity ends at or before the start of the next one + for (size_t i = 0; i + 1 < entities.size(); i++) { + CHECK(entities[i].offset + entities[i].length <= entities[i + 1].offset); + } +} + +static int32 get_entity_type_mask(MessageEntity::Type type) { // one-bit mask; the bit index is the numeric value of type + return 1 << static_cast(type); +} + +static bool is_splittable_entity(MessageEntity::Type type) { // predicate: returns bool, the result is only used as a condition + return type == MessageEntity::Type::Bold || type == MessageEntity::Type::Italic || + type == MessageEntity::Type::Underline || type == MessageEntity::Type::Strikethrough; +} + +static bool is_blockquote_entity(MessageEntity::Type type) { + return type == MessageEntity::Type::BlockQuote; +} + +static bool is_continuous_entity(MessageEntity::Type type) { // every type listed here is treated as non-splittable and non-nestable by are_entities_valid + return type == MessageEntity::Type::Mention || type == MessageEntity::Type::Hashtag || + type == MessageEntity::Type::BotCommand || type == MessageEntity::Type::Url || + type == MessageEntity::Type::EmailAddress || type == MessageEntity::Type::TextUrl || + type == MessageEntity::Type::MentionName || type == MessageEntity::Type::Cashtag || + type == MessageEntity::Type::PhoneNumber || type == MessageEntity::Type::BankCardNumber; +} + +static bool is_pre_entity(MessageEntity::Type type) { + return type == MessageEntity::Type::Pre || type == MessageEntity::Type::Code || type == MessageEntity::Type::PreCode; +} + +static constexpr size_t SPLITTABLE_ENTITY_TYPE_COUNT = 4; + 
+static size_t get_splittable_entity_type_index(MessageEntity::Type type) { + if (static_cast(type) <= static_cast(MessageEntity::Type::Bold) + 1) { + // Bold or Italic; assumes Italic directly follows Bold in MessageEntity::Type + return static_cast(type) - static_cast(MessageEntity::Type::Bold); + } else { + // Underline or Strikethrough; assumes Strikethrough directly follows Underline in MessageEntity::Type + return static_cast(type) - static_cast(MessageEntity::Type::Underline) + 2; + } +} + +static bool are_entities_valid(const vector &entities) { + if (entities.empty()) { + return true; + } + check_is_sorted(entities); + + int32 end_pos[SPLITTABLE_ENTITY_TYPE_COUNT]; + std::fill_n(end_pos, SPLITTABLE_ENTITY_TYPE_COUNT, -1); vector nested_entities_stack; - size_t left_entities = 0; - for (size_t i = 0; i < entities.size(); i++) { + int32 nested_entity_type_mask = 0; + for (auto &entity : entities) { while (!nested_entities_stack.empty() && - entities[i].offset >= nested_entities_stack.back()->offset + nested_entities_stack.back()->length) { + entity.offset >= nested_entities_stack.back()->offset + nested_entities_stack.back()->length) { // remove non-intersecting entities from the stack + nested_entity_type_mask -= get_entity_type_mask(nested_entities_stack.back()->type); nested_entities_stack.pop_back(); } if (!nested_entities_stack.empty()) { - // entity intersects some previous entity - if (entities[i].offset + entities[i].length > - nested_entities_stack.back()->offset + nested_entities_stack.back()->length) { - // it must be nested - continue; + if (entity.offset + entity.length > nested_entities_stack.back()->offset + nested_entities_stack.back()->length) { + // entity intersects some previous entity + return false; + } + if ((nested_entity_type_mask & get_entity_type_mask(entity.type)) != 0) { + // entity has the same type as one of the previous nested + return false; } auto parent_type = nested_entities_stack.back()->type; - if (entities[i].type == parent_type) { - // the type must be different - continue; - } - if (parent_type == MessageEntity::Type::Code || parent_type == 
MessageEntity::Type::Pre || - parent_type == MessageEntity::Type::PreCode) { + if (is_pre_entity(parent_type)) { // Pre and Code can't contain nested entities - continue; + return false; + } + if (is_continuous_entity(parent_type) && + (is_pre_entity(entity.type) || is_continuous_entity(entity.type) || is_blockquote_entity(entity.type))) { + // continuous can't contain pre, other continuous and blockquote + return false; } } - if (i != left_entities) { - entities[left_entities] = std::move(entities[i]); + if (is_splittable_entity(entity.type)) { + auto index = get_splittable_entity_type_index(entity.type); + if (end_pos[index] >= entity.offset) { + // the entities may need to be merged + return false; + } + end_pos[index] = entity.offset + entity.length; } - nested_entities_stack.push_back(&entities[left_entities++]); + nested_entities_stack.push_back(&entity); + nested_entity_type_mask += get_entity_type_mask(entity.type); } - - entities.erase(entities.begin() + left_entities, entities.end()); + return true; } // removes all intersecting entities, including nested -// entities must be pre-sorted and pre-validated static void remove_intersecting_entities(vector &entities) { + check_is_sorted(entities); int32 last_entity_end = 0; size_t left_entities = 0; for (size_t i = 0; i < entities.size(); i++) { @@ -1193,6 +1254,35 @@ static void remove_intersecting_entities(vector &entities) { entities.erase(entities.begin() + left_entities, entities.end()); } +// removes entities that overlap a blockquote without lying entirely inside it; entities and blockquote_entities must be pre-sorted and non-overlapping +static void remove_entities_intersecting_blockquote(vector &entities, + const vector &blockquote_entities) { + check_non_intersecting(entities); + check_non_intersecting(blockquote_entities); + if (blockquote_entities.empty()) { + // fast path + return; + } + + auto blockquote_it = blockquote_entities.begin(); + size_t left_entities = 0; + for (size_t i = 0; i < entities.size(); i++) { + while (blockquote_it != blockquote_entities.end() 
&& + (blockquote_it->type != MessageEntity::Type::BlockQuote || + blockquote_it->offset + blockquote_it->length <= entities[i].offset)) { + ++blockquote_it; + } + if (blockquote_it != blockquote_entities.end() && (blockquote_it->offset + blockquote_it->length < entities[i].offset + entities[i].length || (entities[i].offset < blockquote_it->offset && blockquote_it->offset < entities[i].offset + entities[i].length))) { // drop entities that cross a blockquote border or contain the blockquote; entities lying entirely inside it are kept + continue; + } + if (i != left_entities) { + entities[left_entities] = std::move(entities[i]); + } + left_entities++; + } + entities.erase(entities.begin() + left_entities, entities.end()); +} + vector find_entities(Slice text, bool skip_bot_commands, bool only_urls) { vector entities; @@ -2611,7 +2701,9 @@ static Result clean_input_string_with_entities(const string &text, vecto } // removes entities containing whitespaces only +// returns {last_non_whitespace_pos, last_non_whitespace_utf16_offset} static std::pair remove_invalid_entities(const string &text, vector &entities) { + check_is_sorted(entities); vector nested_entities_stack; size_t current_entity = 0; @@ -2680,6 +2772,117 @@ static std::pair remove_invalid_entities(const string &text, vect return {last_non_whitespace_pos, last_non_whitespace_utf16_offset}; } +// entities must contain only splittable entities +void split_entities(vector &entities, const vector &other_entities) { + check_is_sorted(entities); + check_non_intersecting(other_entities); + + int32 begin_pos[SPLITTABLE_ENTITY_TYPE_COUNT] = {}; + int32 end_pos[SPLITTABLE_ENTITY_TYPE_COUNT] = {}; + auto it = entities.begin(); + vector result; + auto add_entities = [&](int32 end_offset) { + auto flush_entities = [&](int32 offset) { + for (auto type : {MessageEntity::Type::Bold, MessageEntity::Type::Italic, MessageEntity::Type::Underline, + MessageEntity::Type::Strikethrough}) { + auto index = get_splittable_entity_type_index(type); + if (end_pos[index] != 0 && begin_pos[index] < offset) { + if (end_pos[index] <= offset) { + result.emplace_back(type, begin_pos[index], end_pos[index] - begin_pos[index]); + begin_pos[index] = 0; + end_pos[index] = 0; + } else { + result.emplace_back(type, 
begin_pos[index], offset - begin_pos[index]); + begin_pos[index] = offset; + } + } + } + }; + + while (it != entities.end()) { + if (it->offset >= end_offset) { + break; + } + CHECK(is_splittable_entity(it->type)); + auto index = get_splittable_entity_type_index(it->type); + if (it->offset <= end_pos[index] && end_pos[index] != 0) { + if (it->offset + it->length > end_pos[index]) { + end_pos[index] = it->offset + it->length; + } + } else { + flush_entities(it->offset); + begin_pos[index] = it->offset; + end_pos[index] = it->offset + it->length; + } + ++it; + } + flush_entities(end_offset); + }; + for (auto &other_entity : other_entities) { + add_entities(other_entity.offset); + auto old_size = result.size(); + add_entities(other_entity.offset + other_entity.length); + if (is_pre_entity(other_entity.type)) { + result.resize(old_size); + } + } + add_entities(std::numeric_limits::max()); + entities = std::move(result); + // entities are sorted only by offset now, re-sort if needed + if (!std::is_sorted(entities.begin(), entities.end())) { + std::sort(entities.begin(), entities.end()); + } +} + +static void fix_entities(vector &entities) { + if (!std::is_sorted(entities.begin(), entities.end())) { + std::sort(entities.begin(), entities.end()); + } + + if (are_entities_valid(entities)) { + // fast path + return; + } + + vector continuous_entities; + vector blockquote_entities; + vector splittable_entities; + for (auto &entity : entities) { + if (is_splittable_entity(entity.type)) { + splittable_entities.push_back(std::move(entity)); + } else if (is_blockquote_entity(entity.type)) { + blockquote_entities.push_back(std::move(entity)); + } else { + continuous_entities.push_back(std::move(entity)); + } + } + remove_intersecting_entities(continuous_entities); // continuous entities can't intersect each other + + if (!blockquote_entities.empty()) { + remove_intersecting_entities(blockquote_entities); // blockquote entities can't intersect each other + split_entities(splittable_entities, 
blockquote_entities); + + // blockquote entities can contain continuous entities, but can't intersect them in the other ways + remove_entities_intersecting_blockquote(continuous_entities, blockquote_entities); + } + + split_entities(splittable_entities, continuous_entities); // split by remaining continuous entities; NOTE(review): split_entities merges touching same-type pieces, so pieces cut at blockquote borders by the call above appear to be re-joined here unless a continuous entity separates them; confirm and consider a single split pass + + if (!blockquote_entities.empty()) { + combine(continuous_entities, std::move(blockquote_entities)); + std::sort(continuous_entities.begin(), continuous_entities.end()); + } + + if (splittable_entities.empty()) { + splittable_entities = std::move(continuous_entities); + } else if (!continuous_entities.empty()) { + combine(splittable_entities, std::move(continuous_entities)); + std::sort(splittable_entities.begin(), splittable_entities.end()); + } + entities = std::move(splittable_entities); + check_is_sorted(entities); +} + Status fix_formatted_text(string &text, vector &entities, bool allow_empty, bool skip_new_entities, bool skip_bot_commands, bool for_draft) { if (!check_utf8(text)) { @@ -2696,10 +2899,7 @@ Status fix_formatted_text(string &text, vector &entities, bool al } td::remove_if(entities, [](const MessageEntity &entity) { return entity.length == 0; }); - if (!entities.empty()) { - std::sort(entities.begin(), entities.end()); - remove_unallowed_entities(entities); - } + fix_entities(entities); TRY_RESULT(result, clean_input_string_with_entities(text, entities)); @@ -2718,9 +2918,10 @@ Status fix_formatted_text(string &text, vector &entities, bool al return Status::Error(3, "Message must be non-empty"); } - if (!std::is_sorted(entities.begin(), entities.end())) { - std::sort(entities.begin(), entities.end()); // re-sort entities if needed after removal of some characters - } + // re-fix entities if needed after removal of some characters + // the sort order can be incorrect by type + // some splittable entities may be needed to be concatenated + fix_entities(entities); if (for_draft) { text = std::move(result);