Add basic support for nested entities.

GitOrigin-RevId: 127b89671b4551947552e94bcfdb9cab70ef37c0
This commit is contained in:
levlam 2019-10-07 19:45:36 +03:00
parent b355e0c5c0
commit 1b2e4c79f3

View File

@ -1026,18 +1026,59 @@ vector<std::pair<Slice, bool>> find_urls(Slice str) {
return result;
}
// sorts entities, removes intersecting and empty entities
static void fix_entities(vector<MessageEntity> &entities) {
if (entities.empty()) {
return;
// Filters the entity list in place: properly nested entities are kept, while empty,
// out-of-range and mutually intersecting (partially overlapping) entities are removed.
// An entity nested in another one is also removed if it has the same type as its
// innermost enclosing entity or if that enclosing entity is Code, Pre or PreCode.
// entities must be pre-sorted; the relative order of the kept entities is preserved
static void remove_unallowed_entities(vector<MessageEntity> &entities) {
  // chain of already accepted entities which still cover the current offset, innermost last;
  // the pointers always refer to slots below left_entities, which are never overwritten again,
  // and the vector isn't resized inside the loop, so they stay valid
  vector<const MessageEntity *> nested_entities_stack;
  size_t left_entities = 0;  // number of accepted entities, compacted to the front of the vector
  for (size_t i = 0; i < entities.size(); i++) {
    const auto &entity = entities[i];
    if (entity.offset < 0 || entity.length <= 0 || entity.offset > 1000000 || entity.length > 1000000) {
      // empty or out-of-range entity
      continue;
    }

    while (!nested_entities_stack.empty() &&
           entity.offset >= nested_entities_stack.back()->offset + nested_entities_stack.back()->length) {
      // remove non-intersecting entities from the stack
      nested_entities_stack.pop_back();
    }

    if (!nested_entities_stack.empty()) {
      // entity intersects some previous entity
      const auto *parent = nested_entities_stack.back();
      if (entity.offset + entity.length > parent->offset + parent->length) {
        // it must be nested
        continue;
      }

      auto parent_type = parent->type;
      if (entity.type == parent_type) {
        // the type must be different
        continue;
      }
      if (parent_type == MessageEntity::Type::Code || parent_type == MessageEntity::Type::Pre ||
          parent_type == MessageEntity::Type::PreCode) {
        // Pre and Code can't contain nested entities
        continue;
      }
    }

    if (i != left_entities) {
      entities[left_entities] = std::move(entities[i]);
    }
    nested_entities_stack.push_back(&entities[left_entities++]);
  }

  // no re-sorting here: the kept prefix is already in the input order, and sorting the whole
  // vector would mix the leftover tail (rejected and moved-from elements) back into it
  entities.erase(entities.begin() + left_entities, entities.end());
}
// removes all intersecting entities, including nested
// entities must be pre-sorted and pre-validated
static void remove_intersecting_entities(vector<MessageEntity> &entities) {
int32 last_entity_end = 0;
size_t left_entities = 0;
for (size_t i = 0; i < entities.size(); i++) {
if (entities[i].length > 0 && entities[i].offset >= last_entity_end) {
CHECK(entities[i].length > 0);
if (entities[i].offset >= last_entity_end) {
last_entity_end = entities[i].offset + entities[i].length;
if (i != left_entities) {
entities[left_entities] = std::move(entities[i]);
@ -1048,6 +1089,17 @@ static void fix_entities(vector<MessageEntity> &entities) {
entities.erase(entities.begin() + left_entities, entities.end());
}
// Brings a non-empty entity list into order with std::sort and then passes it to
// remove_unallowed_entities; an empty list is left untouched.
static void fix_entities(vector<MessageEntity> &entities) {
  if (!entities.empty()) {
    std::sort(entities.begin(), entities.end());
    remove_unallowed_entities(entities);
  }
}
vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool only_urls) {
vector<MessageEntity> entities;
@ -1083,7 +1135,6 @@ vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool onl
auto urls = find_urls(text);
for (auto &url : urls) {
// TODO better find messageEntityUrl
auto type = url.second ? MessageEntity::Type::EmailAddress : MessageEntity::Type::Url;
if (only_urls && type != MessageEntity::Type::Url) {
continue;
@ -1097,9 +1148,11 @@ vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool onl
return entities;
}
fix_entities(entities);
std::sort(entities.begin(), entities.end());
// fix offsets to utf16 offsets
remove_intersecting_entities(entities);
// fix offsets to UTF-16 offsets
const unsigned char *begin = text.ubegin();
const unsigned char *ptr = begin;
const unsigned char *end = text.uend();
@ -1119,7 +1172,7 @@ vector<MessageEntity> find_entities(Slice text, bool skip_bot_commands, bool onl
while (ptr != end && cnt > 0) {
unsigned char c = ptr[0];
utf16_pos += 1 + (c >= 0xf0);
ptr = next_utf8_unsafe(ptr, nullptr, "match_urls 8");
ptr = next_utf8_unsafe(ptr, nullptr, "find_entities");
pos = static_cast<int32>(ptr - begin);
if (entity_begin == pos) {
@ -2317,10 +2370,18 @@ vector<MessageEntity> get_message_entities(vector<tl_object_ptr<secret_api::Mess
}
// like clean_input_string but also fixes entities
// entities must be sorted, can be nested, but must not intersect each other
static Result<string> clean_input_string_with_entities(const string &text, vector<MessageEntity> &entities) {
bool in_entity = false;
struct EntityInfo {
MessageEntity *entity;
int32 utf16_skipped_before;
EntityInfo(MessageEntity *entity, int32 utf16_skipped_before)
: entity(entity), utf16_skipped_before(utf16_skipped_before) {
}
};
vector<EntityInfo> nested_entities_stack;
size_t current_entity = 0;
int32 skipped_before_current_entity = 0;
int32 utf16_offset = 0;
int32 utf16_skipped = 0;
@ -2333,27 +2394,30 @@ static Result<string> clean_input_string_with_entities(const string &text, vecto
auto c = static_cast<unsigned char>(text[pos]);
bool is_utf8_character_begin = is_utf8_character_first_code_unit(c);
if (is_utf8_character_begin) {
if (in_entity) {
CHECK(current_entity < entities.size());
if (utf16_offset >= entities[current_entity].offset + entities[current_entity].length) {
if (utf16_offset != entities[current_entity].offset + entities[current_entity].length) {
CHECK(utf16_offset == entities[current_entity].offset + entities[current_entity].length + 1);
return Status::Error(16, PSLICE() << "Entity beginning at UTF-16 offset " << entities[current_entity].offset
<< " ends in a middle of a UTF-16 symbol at byte offset " << pos);
}
entities[current_entity].offset -= skipped_before_current_entity;
entities[current_entity].length -= utf16_skipped - skipped_before_current_entity;
in_entity = false;
current_entity++;
while (!nested_entities_stack.empty()) {
auto *entity = nested_entities_stack.back().entity;
auto entity_end = entity->offset + entity->length;
if (utf16_offset < entity_end) {
break;
}
if (utf16_offset != entity_end) {
CHECK(utf16_offset == entity_end + 1);
return Status::Error(400, PSLICE() << "Entity beginning at UTF-16 offset " << entity->offset
<< " ends in a middle of a UTF-16 symbol at byte offset " << pos);
}
auto skipped_before_current_entity = nested_entities_stack.back().utf16_skipped_before;
entity->offset -= skipped_before_current_entity;
entity->length -= utf16_skipped - skipped_before_current_entity;
nested_entities_stack.pop_back();
}
if (!in_entity && current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) {
while (current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) {
if (utf16_offset != entities[current_entity].offset) {
CHECK(utf16_offset == entities[current_entity].offset + 1);
return Status::Error(16, PSLICE() << "Entity begins in a middle of a UTF-16 symbol at byte offset " << pos);
return Status::Error(400, PSLICE() << "Entity begins in a middle of a UTF-16 symbol at byte offset " << pos);
}
in_entity = true;
skipped_before_current_entity = utf16_skipped;
nested_entities_stack.emplace_back(&entities[current_entity++], utf16_skipped);
}
}
if (pos == text_size) {
@ -2433,42 +2497,56 @@ static Result<string> clean_input_string_with_entities(const string &text, vecto
}
}
entities.resize(current_entity);
if (current_entity != entities.size()) {
return Status::Error(400, PSLICE() << "Entity begins after the end of the text at UTF-16 offset "
<< entities[current_entity].offset);
}
if (!nested_entities_stack.empty()) {
auto *entity = nested_entities_stack.back().entity;
return Status::Error(400, PSLICE() << "Entity beginning at UTF-16 offset " << entity->offset
<< " ends after the end of the text at UTF-16 offset "
<< entity->offset + entity->length);
}
return result;
}
// removes entities containing whitespaces only
static std::pair<size_t, int32> remove_invalid_entities(const string &text, vector<MessageEntity> &entities) {
size_t left_entities = 0;
vector<MessageEntity *> nested_entities_stack;
size_t current_entity = 0;
size_t text_size = text.size();
size_t last_non_whitespace_pos = text_size;
size_t last_non_whitespace_pos = text.size();
int32 utf16_offset = 0;
int32 last_space_utf16_offset = -1;
int32 last_non_whitespace_utf16_offset = -1;
for (size_t pos = 0; pos <= text.size(); pos++) {
if (current_entity < entities.size() &&
utf16_offset == entities[current_entity].offset + entities[current_entity].length) {
auto entity_offset = entities[current_entity].offset;
auto entity_type = entities[current_entity].type;
auto have_hidden_data =
entity_type == MessageEntity::Type::TextUrl || entity_type == MessageEntity::Type::MentionName;
if (last_non_whitespace_utf16_offset >= entity_offset ||
(last_space_utf16_offset >= entity_offset && have_hidden_data)) {
// TODO check entities for validness, for example, that mentions, hashtags, cashtags and URLs are valid
if (current_entity != left_entities) {
entities[left_entities] = std::move(entities[current_entity]);
}
left_entities++;
while (!nested_entities_stack.empty()) {
auto *entity = nested_entities_stack.back();
auto entity_end = entity->offset + entity->length;
if (utf16_offset < entity_end) {
break;
}
current_entity++;
auto have_hidden_data =
entity->type == MessageEntity::Type::TextUrl || entity->type == MessageEntity::Type::MentionName;
if (last_non_whitespace_utf16_offset >= entity->offset ||
(last_space_utf16_offset >= entity->offset && have_hidden_data)) {
// TODO check entity for validness, for example, that mentions, hashtags, cashtags and URLs are valid
// keep entity
} else {
entity->length = 0;
}
nested_entities_stack.pop_back();
}
while (current_entity < entities.size() && utf16_offset >= entities[current_entity].offset) {
nested_entities_stack.push_back(&entities[current_entity++]);
}
if (pos == text_size) {
if (pos == text.size()) {
break;
}
@ -2480,7 +2558,7 @@ static std::pair<size_t, int32> remove_invalid_entities(const string &text, vect
last_space_utf16_offset = utf16_offset;
break;
default:
while (pos + 1 < text_size && !is_utf8_character_first_code_unit(static_cast<unsigned char>(text[pos + 1]))) {
while (!is_utf8_character_first_code_unit(static_cast<unsigned char>(text[pos + 1]))) {
pos++;
}
utf16_offset += (c >= 0xf0); // >= 4 bytes in symbol => surrogaite pair
@ -2491,7 +2569,12 @@ static std::pair<size_t, int32> remove_invalid_entities(const string &text, vect
utf16_offset++;
}
entities.erase(entities.begin() + left_entities, entities.end());
CHECK(nested_entities_stack.empty());
CHECK(current_entity == entities.size());
entities.erase(
std::remove_if(entities.begin(), entities.end(), [](const auto &entity) { return entity.length == 0; }),
entities.end());
return {last_non_whitespace_pos, last_non_whitespace_utf16_offset};
}
@ -2567,9 +2650,13 @@ Status fix_formatted_text(string &text, vector<MessageEntity> &entities, bool al
new_size--;
}
text.resize(new_size);
while (!entities.empty() && entities.back().offset + entities.back().length > 8192) {
entities.pop_back();
}
entities.erase(
std::remove_if(entities.begin(), entities.end(),
[text_utf16_length = narrow_cast<int32>(utf8_utf16_length(text))](const auto &entity) {
return entity.offset + entity.length > text_utf16_length;
}),
entities.end());
}
if (!skip_new_entities) {