From 1b2e4c79f370368ed920971f14d1e501a95b3740 Mon Sep 17 00:00:00 2001 From: levlam Date: Mon, 7 Oct 2019 19:45:36 +0300 Subject: [PATCH] Add basic support for nested entities. GitOrigin-RevId: 127b89671b4551947552e94bcfdb9cab70ef37c0 --- td/telegram/MessageEntity.cpp | 191 +++++++++++++++++++++++++--------- 1 file changed, 139 insertions(+), 52 deletions(-) diff --git a/td/telegram/MessageEntity.cpp b/td/telegram/MessageEntity.cpp index fcf5c767..49601cd8 100644 --- a/td/telegram/MessageEntity.cpp +++ b/td/telegram/MessageEntity.cpp @@ -1026,18 +1026,59 @@ vector> find_urls(Slice str) { return result; } -// sorts entities, removes intersecting and empty entities -static void fix_entities(vector &entities) { - if (entities.empty()) { - return; +// keeps nested, but removes mutually intersecting and empty entities +// entities must be pre-sorted +static void remove_unallowed_entities(vector &entities) { + vector nested_entities_stack; + size_t left_entities = 0; + for (size_t i = 0; i < entities.size(); i++) { + if (entities[i].offset < 0 || entities[i].length <= 0 || entities[i].offset > 1000000 || + entities[i].length > 1000000) { + continue; + } + + while (!nested_entities_stack.empty() && + entities[i].offset >= nested_entities_stack.back()->offset + nested_entities_stack.back()->length) { + // remove non-intersecting entities from the stack + nested_entities_stack.pop_back(); + } + + if (!nested_entities_stack.empty()) { + // entity intersects some previous entity + if (entities[i].offset + entities[i].length > + nested_entities_stack.back()->offset + nested_entities_stack.back()->length) { + // it must be nested + continue; + } + auto parent_type = nested_entities_stack.back()->type; + if (entities[i].type == parent_type) { + // the type must be different + continue; + } + if (parent_type == MessageEntity::Type::Code || parent_type == MessageEntity::Type::Pre || + parent_type == MessageEntity::Type::PreCode) { + // Pre and Code can't contain nested entities + continue; + } + } + + if (i != left_entities) { + entities[left_entities] = std::move(entities[i]); + } + nested_entities_stack.push_back(&entities[left_entities++]); } - std::sort(entities.begin(), entities.end()); + entities.erase(entities.begin() + left_entities, entities.end()); +} +// removes all intersecting entities, including nested +// entities must be pre-sorted and pre-validated +static void remove_intersecting_entities(vector &entities) { int32 last_entity_end = 0; size_t left_entities = 0; for (size_t i = 0; i < entities.size(); i++) { - if (entities[i].length > 0 && entities[i].offset >= last_entity_end) { + CHECK(entities[i].length > 0); + if (entities[i].offset >= last_entity_end) { last_entity_end = entities[i].offset + entities[i].length; if (i != left_entities) { entities[left_entities] = std::move(entities[i]); @@ -1048,6 +1089,17 @@ static void fix_entities(vector &entities) { entities.erase(entities.begin() + left_entities, entities.end()); } +static void fix_entities(vector &entities) { + if (entities.empty()) { + // fast path + return; + } + + std::sort(entities.begin(), entities.end()); + + remove_unallowed_entities(entities); +} + vector find_entities(Slice text, bool skip_bot_commands, bool only_urls) { vector entities; @@ -1083,7 +1135,6 @@ vector find_entities(Slice text, bool skip_bot_commands, bool onl auto urls = find_urls(text); for (auto &url : urls) { - // TODO better find messageEntityUrl auto type = url.second ? MessageEntity::Type::EmailAddress : MessageEntity::Type::Url; if (only_urls && type != MessageEntity::Type::Url) { continue; @@ -1097,9 +1148,11 @@ vector find_entities(Slice text, bool skip_bot_commands, bool onl return entities; } - fix_entities(entities); + std::sort(entities.begin(), entities.end()); - // fix offsets to utf16 offsets + remove_intersecting_entities(entities); + + // fix offsets to UTF-16 offsets const unsigned char *begin = text.ubegin(); const unsigned char *ptr = begin; const unsigned char *end = text.uend(); @@ -1119,7 +1172,7 @@ vector find_entities(Slice text, bool skip_bot_commands, bool onl while (ptr != end && cnt > 0) { unsigned char c = ptr[0]; utf16_pos += 1 + (c >= 0xf0); - ptr = next_utf8_unsafe(ptr, nullptr, "match_urls 8"); + ptr = next_utf8_unsafe(ptr, nullptr, "find_entities"); pos = static_cast(ptr - begin); if (entity_begin == pos) { @@ -2317,10 +2370,18 @@ vector get_message_entities(vector clean_input_string_with_entities(const string &text, vector &entities) { - bool in_entity = false; + struct EntityInfo { + MessageEntity *entity; + int32 utf16_skipped_before; + + EntityInfo(MessageEntity *entity, int32 utf16_skipped_before) + : entity(entity), utf16_skipped_before(utf16_skipped_before) { + } + }; + vector nested_entities_stack; size_t current_entity = 0; - int32 skipped_before_current_entity = 0; int32 utf16_offset = 0; int32 utf16_skipped = 0; @@ -2333,27 +2394,30 @@ static Result clean_input_string_with_entities(const string &text, vecto auto c = static_cast(text[pos]); bool is_utf8_character_begin = is_utf8_character_first_code_unit(c); if (is_utf8_character_begin) { - if (in_entity) { - CHECK(current_entity < entities.size()); - if (utf16_offset >= entities[current_entity].offset + entities[current_entity].length) { - if (utf16_offset != entities[current_entity].offset + entities[current_entity].length) { - CHECK(utf16_offset == entities[current_entity].offset + entities[current_entity].length + 1); - return Status::Error(16, PSLICE() << "Entity beginning at UTF-16 offset " << entities[current_entity].offset - << " ends in a middle of a UTF-16 symbol at byte offset " << pos); - } - entities[current_entity].offset -= skipped_before_current_entity; - entities[current_entity].length -= utf16_skipped - skipped_before_current_entity; - in_entity = false; - current_entity++; + while (!nested_entities_stack.empty()) { + auto *entity = nested_entities_stack.back().entity; + auto entity_end = entity->offset + entity->length; + if (utf16_offset < entity_end) { + break; } + + if (utf16_offset != entity_end) { + CHECK(utf16_offset == entity_end + 1); + return Status::Error(400, PSLICE() << "Entity beginning at UTF-16 offset " << entity->offset + << " ends in a middle of a UTF-16 symbol at byte offset " << pos); + } + + auto skipped_before_current_entity = nested_entities_stack.back().utf16_skipped_before; + entity->offset -= skipped_before_current_entity; + entity->length -= utf16_skipped - skipped_before_current_entity; + nested_entities_stack.pop_back(); } - if (!in_entity && current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) { + while (current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) { if (utf16_offset != entities[current_entity].offset) { CHECK(utf16_offset == entities[current_entity].offset + 1); - return Status::Error(16, PSLICE() << "Entity begins in a middle of a UTF-16 symbol at byte offset " << pos); + return Status::Error(400, PSLICE() << "Entity begins in a middle of a UTF-16 symbol at byte offset " << pos); } - in_entity = true; - skipped_before_current_entity = utf16_skipped; + nested_entities_stack.emplace_back(&entities[current_entity++], utf16_skipped); } } if (pos == text_size) { @@ -2433,42 +2497,56 @@ static Result clean_input_string_with_entities(const string &text, vecto } } - entities.resize(current_entity); + if (current_entity != entities.size()) { + return Status::Error(400, PSLICE() << "Entity begins after the end of the text at UTF-16 offset " + << entities[current_entity].offset); + } + if (!nested_entities_stack.empty()) { + auto *entity = nested_entities_stack.back().entity; + return Status::Error(400, PSLICE() << "Entity beginning at UTF-16 offset " << entity->offset + << " ends after the end of the text at UTF-16 offset " + << entity->offset + entity->length); + } return result; } // removes entities containing whitespaces only static std::pair remove_invalid_entities(const string &text, vector &entities) { - size_t left_entities = 0; + vector nested_entities_stack; size_t current_entity = 0; - size_t text_size = text.size(); - size_t last_non_whitespace_pos = text_size; + size_t last_non_whitespace_pos = text.size(); int32 utf16_offset = 0; int32 last_space_utf16_offset = -1; int32 last_non_whitespace_utf16_offset = -1; for (size_t pos = 0; pos <= text.size(); pos++) { - if (current_entity < entities.size() && - utf16_offset == entities[current_entity].offset + entities[current_entity].length) { - auto entity_offset = entities[current_entity].offset; - auto entity_type = entities[current_entity].type; - auto have_hidden_data = - entity_type == MessageEntity::Type::TextUrl || entity_type == MessageEntity::Type::MentionName; - if (last_non_whitespace_utf16_offset >= entity_offset || - (last_space_utf16_offset >= entity_offset && have_hidden_data)) { - // TODO check entities for validness, for example, that mentions, hashtags, cashtags and URLs are valid - if (current_entity != left_entities) { - entities[left_entities] = std::move(entities[current_entity]); - } - left_entities++; + while (!nested_entities_stack.empty()) { + auto *entity = nested_entities_stack.back(); + auto entity_end = entity->offset + entity->length; + if (utf16_offset < entity_end) { + break; } - current_entity++; + + auto have_hidden_data = + entity->type == MessageEntity::Type::TextUrl || entity->type == MessageEntity::Type::MentionName; + if (last_non_whitespace_utf16_offset >= entity->offset || + (last_space_utf16_offset >= entity->offset && have_hidden_data)) { + // TODO check entity for validness, for example, that mentions, hashtags, cashtags and URLs are valid + // keep entity + } else { + entity->length = 0; + } + + nested_entities_stack.pop_back(); + } + while (current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) { + nested_entities_stack.push_back(&entities[current_entity++]); } - if (pos == text_size) { + if (pos == text.size()) { break; } @@ -2480,7 +2558,7 @@ static std::pair remove_invalid_entities(const string &text, vect last_space_utf16_offset = utf16_offset; break; default: - while (pos + 1 < text_size && !is_utf8_character_first_code_unit(static_cast(text[pos + 1]))) { + while (!is_utf8_character_first_code_unit(static_cast(text[pos + 1]))) { pos++; } utf16_offset += (c >= 0xf0); // >= 4 bytes in symbol => surrogaite pair @@ -2491,7 +2569,12 @@ static std::pair remove_invalid_entities(const string &text, vect utf16_offset++; } - entities.erase(entities.begin() + left_entities, entities.end()); + CHECK(nested_entities_stack.empty()); + CHECK(current_entity == entities.size()); + + entities.erase( + std::remove_if(entities.begin(), entities.end(), [](const auto &entity) { return entity.length == 0; }), + entities.end()); return {last_non_whitespace_pos, last_non_whitespace_utf16_offset}; } @@ -2567,9 +2650,13 @@ Status fix_formatted_text(string &text, vector &entities, bool al new_size--; } text.resize(new_size); - while (!entities.empty() && entities.back().offset + entities.back().length > 8192) { - entities.pop_back(); - } + + entities.erase( + std::remove_if(entities.begin(), entities.end(), + [text_utf16_length = narrow_cast(utf8_utf16_length(text))](const auto &entity) { + return entity.offset + entity.length > text_utf16_length; + }), + entities.end()); } if (!skip_new_entities) {