Add basic support for nested entities.

GitOrigin-RevId: 127b89671b4551947552e94bcfdb9cab70ef37c0
This commit is contained in:
levlam 2019-10-07 19:45:36 +03:00
parent b355e0c5c0
commit 1b2e4c79f3

View File

@ -1026,18 +1026,59 @@ vector<std::pair<Slice, bool>> find_urls(Slice str) {
return result; return result;
} }
// sorts entities, removes intersecting and empty entities // keeps nested, but removes mutually intersecting and empty entities
static void fix_entities(vector<MessageEntity> &entities) { // entities must be pre-sorted
if (entities.empty()) { static void remove_unallowed_entities(vector<MessageEntity> &entities) {
return; vector<const MessageEntity *> nested_entities_stack;
size_t left_entities = 0;
for (size_t i = 0; i < entities.size(); i++) {
if (entities[i].offset < 0 || entities[i].length <= 0 || entities[i].offset > 1000000 ||
entities[i].length > 1000000) {
continue;
} }
std::sort(entities.begin(), entities.end()); while (!nested_entities_stack.empty() &&
entities[i].offset >= nested_entities_stack.back()->offset + nested_entities_stack.back()->length) {
// remove non-intersecting entities from the stack
nested_entities_stack.pop_back();
}
if (!nested_entities_stack.empty()) {
// entity intersects some previous entity
if (entities[i].offset + entities[i].length >
nested_entities_stack.back()->offset + nested_entities_stack.back()->length) {
// it must be nested
continue;
}
auto parent_type = nested_entities_stack.back()->type;
if (entities[i].type == parent_type) {
// the type must be different
continue;
}
if (parent_type == MessageEntity::Type::Code || parent_type == MessageEntity::Type::Pre ||
parent_type == MessageEntity::Type::PreCode) {
// Pre and Code can't contain nested entities
continue;
}
}
if (i != left_entities) {
entities[left_entities] = std::move(entities[i]);
}
nested_entities_stack.push_back(&entities[left_entities++]);
}
entities.erase(entities.begin() + left_entities, entities.end());
}
// removes all intersecting entities, including nested
// entities must be pre-sorted and pre-validated
static void remove_intersecting_entities(vector<MessageEntity> &entities) {
int32 last_entity_end = 0; int32 last_entity_end = 0;
size_t left_entities = 0; size_t left_entities = 0;
for (size_t i = 0; i < entities.size(); i++) { for (size_t i = 0; i < entities.size(); i++) {
if (entities[i].length > 0 && entities[i].offset >= last_entity_end) { CHECK(entities[i].length > 0);
if (entities[i].offset >= last_entity_end) {
last_entity_end = entities[i].offset + entities[i].length; last_entity_end = entities[i].offset + entities[i].length;
if (i != left_entities) { if (i != left_entities) {
entities[left_entities] = std::move(entities[i]); entities[left_entities] = std::move(entities[i]);
@ -1048,6 +1089,17 @@ static void fix_entities(vector<MessageEntity> &entities) {
entities.erase(entities.begin() + left_entities, entities.end()); entities.erase(entities.begin() + left_entities, entities.end());
} }
static void fix_entities(vector<MessageEntity> &entities) {
if (entities.empty()) {
// fast path
return;
}
std::sort(entities.begin(), entities.end());
remove_unallowed_entities(entities);
}
vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool only_urls) { vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool only_urls) {
vector<MessageEntity> entities; vector<MessageEntity> entities;
@ -1083,7 +1135,6 @@ vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool onl
auto urls = find_urls(text); auto urls = find_urls(text);
for (auto &url : urls) { for (auto &url : urls) {
// TODO better find messageEntityUrl
auto type = url.second ? MessageEntity::Type::EmailAddress : MessageEntity::Type::Url; auto type = url.second ? MessageEntity::Type::EmailAddress : MessageEntity::Type::Url;
if (only_urls && type != MessageEntity::Type::Url) { if (only_urls && type != MessageEntity::Type::Url) {
continue; continue;
@ -1097,9 +1148,11 @@ vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool onl
return entities; return entities;
} }
fix_entities(entities); std::sort(entities.begin(), entities.end());
// fix offsets to utf16 offsets remove_intersecting_entities(entities);
// fix offsets to UTF-16 offsets
const unsigned char *begin = text.ubegin(); const unsigned char *begin = text.ubegin();
const unsigned char *ptr = begin; const unsigned char *ptr = begin;
const unsigned char *end = text.uend(); const unsigned char *end = text.uend();
@ -1119,7 +1172,7 @@ vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool onl
while (ptr != end && cnt > 0) { while (ptr != end && cnt > 0) {
unsigned char c = ptr[0]; unsigned char c = ptr[0];
utf16_pos += 1 + (c >= 0xf0); utf16_pos += 1 + (c >= 0xf0);
ptr = next_utf8_unsafe(ptr, nullptr, "match_urls 8"); ptr = next_utf8_unsafe(ptr, nullptr, "find_entities");
pos = static_cast<int32>(ptr - begin); pos = static_cast<int32>(ptr - begin);
if (entity_begin == pos) { if (entity_begin == pos) {
@ -2317,10 +2370,18 @@ vector<MessageEntity> get_message_entities(vector<tl_object_ptr<secret_api::Mess
} }
// like clean_input_string but also fixes entities // like clean_input_string but also fixes entities
// entities must be sorted, can be nested, but must not intersect each other
static Result<string> clean_input_string_with_entities(const string &text, vector<MessageEntity> &entities) { static Result<string> clean_input_string_with_entities(const string &text, vector<MessageEntity> &entities) {
bool in_entity = false; struct EntityInfo {
MessageEntity *entity;
int32 utf16_skipped_before;
EntityInfo(MessageEntity *entity, int32 utf16_skipped_before)
: entity(entity), utf16_skipped_before(utf16_skipped_before) {
}
};
vector<EntityInfo> nested_entities_stack;
size_t current_entity = 0; size_t current_entity = 0;
int32 skipped_before_current_entity = 0;
int32 utf16_offset = 0; int32 utf16_offset = 0;
int32 utf16_skipped = 0; int32 utf16_skipped = 0;
@ -2333,27 +2394,30 @@ static Result<string> clean_input_string_with_entities(const string &text, vecto
auto c = static_cast<unsigned char>(text[pos]); auto c = static_cast<unsigned char>(text[pos]);
bool is_utf8_character_begin = is_utf8_character_first_code_unit(c); bool is_utf8_character_begin = is_utf8_character_first_code_unit(c);
if (is_utf8_character_begin) { if (is_utf8_character_begin) {
if (in_entity) { while (!nested_entities_stack.empty()) {
CHECK(current_entity < entities.size()); auto *entity = nested_entities_stack.back().entity;
if (utf16_offset >= entities[current_entity].offset + entities[current_entity].length) { auto entity_end = entity->offset + entity->length;
if (utf16_offset != entities[current_entity].offset + entities[current_entity].length) { if (utf16_offset < entity_end) {
CHECK(utf16_offset == entities[current_entity].offset + entities[current_entity].length + 1); break;
return Status::Error(16, PSLICE() << "Entity beginning at UTF-16 offset " << entities[current_entity].offset }
if (utf16_offset != entity_end) {
CHECK(utf16_offset == entity_end + 1);
return Status::Error(400, PSLICE() << "Entity beginning at UTF-16 offset " << entity->offset
<< " ends in a middle of a UTF-16 symbol at byte offset " << pos); << " ends in a middle of a UTF-16 symbol at byte offset " << pos);
} }
entities[current_entity].offset -= skipped_before_current_entity;
entities[current_entity].length -= utf16_skipped - skipped_before_current_entity; auto skipped_before_current_entity = nested_entities_stack.back().utf16_skipped_before;
in_entity = false; entity->offset -= skipped_before_current_entity;
current_entity++; entity->length -= utf16_skipped - skipped_before_current_entity;
nested_entities_stack.pop_back();
} }
} while (current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) {
if (!in_entity && current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) {
if (utf16_offset != entities[current_entity].offset) { if (utf16_offset != entities[current_entity].offset) {
CHECK(utf16_offset == entities[current_entity].offset + 1); CHECK(utf16_offset == entities[current_entity].offset + 1);
return Status::Error(16, PSLICE() << "Entity begins in a middle of a UTF-16 symbol at byte offset " << pos); return Status::Error(400, PSLICE() << "Entity begins in a middle of a UTF-16 symbol at byte offset " << pos);
} }
in_entity = true; nested_entities_stack.emplace_back(&entities[current_entity++], utf16_skipped);
skipped_before_current_entity = utf16_skipped;
} }
} }
if (pos == text_size) { if (pos == text_size) {
@ -2433,42 +2497,56 @@ static Result<string> clean_input_string_with_entities(const string &text, vecto
} }
} }
entities.resize(current_entity); if (current_entity != entities.size()) {
return Status::Error(400, PSLICE() << "Entity begins after the end of the text at UTF-16 offset "
<< entities[current_entity].offset);
}
if (!nested_entities_stack.empty()) {
auto *entity = nested_entities_stack.back().entity;
return Status::Error(400, PSLICE() << "Entity beginning at UTF-16 offset " << entity->offset
<< " ends after the end of the text at UTF-16 offset "
<< entity->offset + entity->length);
}
return result; return result;
} }
// removes entities containing whitespaces only // removes entities containing whitespaces only
static std::pair<size_t, int32> remove_invalid_entities(const string &text, vector<MessageEntity> &entities) { static std::pair<size_t, int32> remove_invalid_entities(const string &text, vector<MessageEntity> &entities) {
size_t left_entities = 0; vector<MessageEntity *> nested_entities_stack;
size_t current_entity = 0; size_t current_entity = 0;
size_t text_size = text.size(); size_t last_non_whitespace_pos = text.size();
size_t last_non_whitespace_pos = text_size;
int32 utf16_offset = 0; int32 utf16_offset = 0;
int32 last_space_utf16_offset = -1; int32 last_space_utf16_offset = -1;
int32 last_non_whitespace_utf16_offset = -1; int32 last_non_whitespace_utf16_offset = -1;
for (size_t pos = 0; pos <= text.size(); pos++) { for (size_t pos = 0; pos <= text.size(); pos++) {
if (current_entity < entities.size() && while (!nested_entities_stack.empty()) {
utf16_offset == entities[current_entity].offset + entities[current_entity].length) { auto *entity = nested_entities_stack.back();
auto entity_offset = entities[current_entity].offset; auto entity_end = entity->offset + entity->length;
auto entity_type = entities[current_entity].type; if (utf16_offset < entity_end) {
auto have_hidden_data = break;
entity_type == MessageEntity::Type::TextUrl || entity_type == MessageEntity::Type::MentionName;
if (last_non_whitespace_utf16_offset >= entity_offset ||
(last_space_utf16_offset >= entity_offset && have_hidden_data)) {
// TODO check entities for validness, for example, that mentions, hashtags, cashtags and URLs are valid
if (current_entity != left_entities) {
entities[left_entities] = std::move(entities[current_entity]);
}
left_entities++;
}
current_entity++;
} }
if (pos == text_size) { auto have_hidden_data =
entity->type == MessageEntity::Type::TextUrl || entity->type == MessageEntity::Type::MentionName;
if (last_non_whitespace_utf16_offset >= entity->offset ||
(last_space_utf16_offset >= entity->offset && have_hidden_data)) {
// TODO check entity for validness, for example, that mentions, hashtags, cashtags and URLs are valid
// keep entity
} else {
entity->length = 0;
}
nested_entities_stack.pop_back();
}
while (current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) {
nested_entities_stack.push_back(&entities[current_entity++]);
}
if (pos == text.size()) {
break; break;
} }
@ -2480,7 +2558,7 @@ static std::pair<size_t, int32> remove_invalid_entities(const string &text, vect
last_space_utf16_offset = utf16_offset; last_space_utf16_offset = utf16_offset;
break; break;
default: default:
while (pos + 1 < text_size && !is_utf8_character_first_code_unit(static_cast<unsigned char>(text[pos + 1]))) { while (!is_utf8_character_first_code_unit(static_cast<unsigned char>(text[pos + 1]))) {
pos++; pos++;
} }
utf16_offset += (c >= 0xf0); // >= 4 bytes in symbol => surrogaite pair utf16_offset += (c >= 0xf0); // >= 4 bytes in symbol => surrogaite pair
@ -2491,7 +2569,12 @@ static std::pair<size_t, int32> remove_invalid_entities(const string &text, vect
utf16_offset++; utf16_offset++;
} }
entities.erase(entities.begin() + left_entities, entities.end()); CHECK(nested_entities_stack.empty());
CHECK(current_entity == entities.size());
entities.erase(
std::remove_if(entities.begin(), entities.end(), [](const auto &entity) { return entity.length == 0; }),
entities.end());
return {last_non_whitespace_pos, last_non_whitespace_utf16_offset}; return {last_non_whitespace_pos, last_non_whitespace_utf16_offset};
} }
@ -2567,9 +2650,13 @@ Status fix_formatted_text(string &text, vector<MessageEntity> &entities, bool al
new_size--; new_size--;
} }
text.resize(new_size); text.resize(new_size);
while (!entities.empty() && entities.back().offset + entities.back().length > 8192) {
entities.pop_back(); entities.erase(
} std::remove_if(entities.begin(), entities.end(),
[text_utf16_length = narrow_cast<int32>(utf8_utf16_length(text))](const auto &entity) {
return entity.offset + entity.length > text_utf16_length;
}),
entities.end());
} }
if (!skip_new_entities) { if (!skip_new_entities) {