From 5fe3a7ca949a18ac716b0b853877a9a859a2b18c Mon Sep 17 00:00:00 2001 From: levlam Date: Wed, 19 Oct 2022 20:43:30 +0300 Subject: [PATCH] Add class TranscriptionInfo. --- CMakeLists.txt | 3 + td/telegram/TranscriptionInfo.cpp | 108 ++++++++++++++++++++++++++++++ td/telegram/TranscriptionInfo.h | 56 ++++++++++++++++ td/telegram/TranscriptionInfo.hpp | 31 +++++++++ td/telegram/VoiceNotesManager.cpp | 83 +++++------------------ td/telegram/VoiceNotesManager.h | 7 +- td/telegram/VoiceNotesManager.hpp | 19 +++--- 7 files changed, 228 insertions(+), 79 deletions(-) create mode 100644 td/telegram/TranscriptionInfo.cpp create mode 100644 td/telegram/TranscriptionInfo.h create mode 100644 td/telegram/TranscriptionInfo.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index fd69f9229..7c318c193 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -454,6 +454,7 @@ set(TDLIB_SOURCE td/telegram/ThemeManager.cpp td/telegram/TopDialogCategory.cpp td/telegram/TopDialogManager.cpp + td/telegram/TranscriptionInfo.cpp td/telegram/UpdatesManager.cpp td/telegram/Usernames.cpp td/telegram/Venue.cpp @@ -718,6 +719,7 @@ set(TDLIB_SOURCE td/telegram/ThemeManager.h td/telegram/TopDialogCategory.h td/telegram/TopDialogManager.h + td/telegram/TranscriptionInfo.h td/telegram/UniqueId.h td/telegram/UpdatesManager.h td/telegram/UserId.h @@ -769,6 +771,7 @@ set(TDLIB_SOURCE td/telegram/SendCodeHelper.hpp td/telegram/StickerSetId.hpp td/telegram/StickersManager.hpp + td/telegram/TranscriptionInfo.hpp td/telegram/VideoNotesManager.hpp td/telegram/VideosManager.hpp td/telegram/VoiceNotesManager.hpp diff --git a/td/telegram/TranscriptionInfo.cpp b/td/telegram/TranscriptionInfo.cpp new file mode 100644 index 000000000..abce5f265 --- /dev/null +++ b/td/telegram/TranscriptionInfo.cpp @@ -0,0 +1,108 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/telegram/TranscriptionInfo.h" + +namespace td { + +bool TranscriptionInfo::start_recognize_speech(Promise &&promise) { + if (is_transcribed_) { + promise.set_value(Unit()); + return false; + } + speech_recognition_queries_.push_back(std::move(promise)); + if (speech_recognition_queries_.size() == 1) { + last_transcription_error_ = Status::OK(); + return true; + } + return false; +} + +vector> TranscriptionInfo::on_final_transcription(string &&text, int64 transcription_id) { + CHECK(!is_transcribed_); + CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id); + CHECK(transcription_id != 0); + transcription_id_ = transcription_id; + is_transcribed_ = true; + text_ = std::move(text); + last_transcription_error_ = Status::OK(); + + CHECK(!speech_recognition_queries_.empty()); + auto promises = std::move(speech_recognition_queries_); + speech_recognition_queries_.clear(); + + return std::move(promises); +} + +bool TranscriptionInfo::on_partial_transcription(string &&text, int64 transcription_id) { + CHECK(!is_transcribed_); + CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id); + CHECK(transcription_id != 0); + bool is_changed = text_ != text; + transcription_id_ = transcription_id; + text_ = std::move(text); + last_transcription_error_ = Status::OK(); + + return is_changed; +} + +vector> TranscriptionInfo::on_failed_transcription(Status &&error) { + CHECK(!is_transcribed_); + transcription_id_ = 0; + text_.clear(); + last_transcription_error_ = std::move(error); + + CHECK(!speech_recognition_queries_.empty()); + auto promises = std::move(speech_recognition_queries_); + speech_recognition_queries_.clear(); + return promises; +} + +unique_ptr TranscriptionInfo::copy_if_transcribed(const unique_ptr &info) { + if (info == nullptr || !info->is_transcribed_) { + return nullptr; + } + auto result = make_unique(); + result->is_transcribed_ = true; + result->transcription_id_ = info->transcription_id_; + result->text_ = info->text_; + return result; +} + +bool TranscriptionInfo::update_from(unique_ptr &old_info, unique_ptr &&new_info) { + if (new_info == nullptr || !new_info->is_transcribed_) { + return false; + } + CHECK(new_info->transcription_id_ != 0); + CHECK(new_info->last_transcription_error_.is_ok()); + CHECK(new_info->speech_recognition_queries_.empty()); + if (old_info == nullptr) { + old_info = std::move(new_info); + return true; + } + if (old_info->transcription_id_ != 0 || !old_info->speech_recognition_queries_.empty()) { + return false; + } + CHECK(!old_info->is_transcribed_); + old_info = std::move(new_info); + return true; +} + +td_api::object_ptr TranscriptionInfo::get_speech_recognition_result_object() const { + if (is_transcribed_) { + return td_api::make_object(text_); + } + if (!speech_recognition_queries_.empty()) { + return td_api::make_object(text_); + } + if (last_transcription_error_.is_error()) { + return td_api::make_object(td_api::make_object( + last_transcription_error_.code(), last_transcription_error_.message().str())); + } + return nullptr; +} + +} // namespace td diff --git a/td/telegram/TranscriptionInfo.h b/td/telegram/TranscriptionInfo.h new file mode 100644 index 000000000..747012700 --- /dev/null +++ b/td/telegram/TranscriptionInfo.h @@ -0,0 +1,56 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/telegram/td_api.h" + +#include "td/utils/common.h" +#include "td/utils/Promise.h" +#include "td/utils/Status.h" + +namespace td { + +class TranscriptionInfo { + bool is_transcribed_ = false; + int64 transcription_id_ = 0; + string text_; + + // temporary state + Status last_transcription_error_; + vector> speech_recognition_queries_; + + public: + bool is_transcribed() const { + return is_transcribed_; + } + + int64 get_transcription_id() const { + return transcription_id_; + } + + bool start_recognize_speech(Promise &&promise); + + vector> on_final_transcription(string &&text, int64 transcription_id); + + bool on_partial_transcription(string &&text, int64 transcription_id); + + vector> on_failed_transcription(Status &&error); + + static unique_ptr copy_if_transcribed(const unique_ptr &info); + + static bool update_from(unique_ptr &old_info, unique_ptr &&new_info); + + td_api::object_ptr get_speech_recognition_result_object() const; + + template + void store(StorerT &storer) const; + + template + void parse(ParserT &parser); +}; + +} // namespace td diff --git a/td/telegram/TranscriptionInfo.hpp b/td/telegram/TranscriptionInfo.hpp new file mode 100644 index 000000000..224d637dc --- /dev/null +++ b/td/telegram/TranscriptionInfo.hpp @@ -0,0 +1,31 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/telegram/TranscriptionInfo.h" + +#include "td/utils/common.h" +#include "td/utils/tl_helpers.h" + +namespace td { + +template +void TranscriptionInfo::store(StorerT &storer) const { + CHECK(is_transcribed()); + td::store(transcription_id_, storer); + td::store(text_, storer); +} + +template +void TranscriptionInfo::parse(ParserT &parser) { + is_transcribed_ = true; + td::parse(transcription_id_, parser); + td::parse(text_, parser); + CHECK(transcription_id_ != 0); +} + +} // namespace td diff --git a/td/telegram/VoiceNotesManager.cpp b/td/telegram/VoiceNotesManager.cpp index af2de051d..d1fc9f2f7 100644 --- a/td/telegram/VoiceNotesManager.cpp +++ b/td/telegram/VoiceNotesManager.cpp @@ -136,22 +136,9 @@ tl_object_ptr VoiceNotesManager::get_voice_note_object(FileId auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); - - auto speech_recognition_result = [this, voice_note]() -> td_api::object_ptr { - if (voice_note->is_transcribed) { - return td_api::make_object(voice_note->text); - } - if (speech_recognition_queries_.count(voice_note->file_id) != 0) { - return td_api::make_object(voice_note->text); - } - if (voice_note->last_transcription_error.is_error()) { - return td_api::make_object( - td_api::make_object(voice_note->last_transcription_error.error().code(), - voice_note->last_transcription_error.error().message().str())); - } - return nullptr; - }(); - + auto speech_recognition_result = voice_note->transcription_info == nullptr + ? nullptr + : voice_note->transcription_info->get_speech_recognition_result_object(); return make_tl_object(voice_note->duration, voice_note->waveform, voice_note->mime_type, std::move(speech_recognition_result), td_->file_manager_->get_file_object(file_id)); @@ -175,13 +162,7 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr new_voice_note v->duration = new_voice_note->duration; v->waveform = std::move(new_voice_note->waveform); } - if (new_voice_note->is_transcribed && v->transcription_id == 0) { - CHECK(!v->is_transcribed); - CHECK(new_voice_note->transcription_id != 0); - v->is_transcribed = true; - v->transcription_id = new_voice_note->transcription_id; - v->text = std::move(new_voice_note->text); - v->last_transcription_error = Status::OK(); + if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_voice_note->transcription_info))) { on_voice_note_transcription_completed(file_id); } } @@ -207,10 +188,7 @@ FileId VoiceNotesManager::dup_voice_note(FileId new_id, FileId old_id) { new_voice_note->mime_type = old_voice_note->mime_type; new_voice_note->duration = old_voice_note->duration; new_voice_note->waveform = old_voice_note->waveform; - if (old_voice_note->is_transcribed) { - new_voice_note->is_transcribed = old_voice_note->is_transcribed; - new_voice_note->text = old_voice_note->text; - } + new_voice_note->transcription_info = TranscriptionInfo::copy_if_transcribed(old_voice_note->transcription_info); return new_id; } @@ -278,14 +256,11 @@ void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise< auto file_id = it->second; auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); - if (voice_note->is_transcribed) { - return promise.set_value(Unit()); + if (voice_note->transcription_info == nullptr) { + voice_note->transcription_info = make_unique(); } - auto &queries = speech_recognition_queries_[file_id]; - queries.push_back(std::move(promise)); - if (queries.size() == 1) { + if (voice_note->transcription_info->start_recognize_speech(std::move(promise))) { td_->create_handler()->send(file_id, full_message_id); - voice_note->last_transcription_error = Status::OK(); on_voice_note_transcription_updated(file_id); } } @@ -294,25 +269,13 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, bool is_final) { auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); - CHECK(!voice_note->is_transcribed); - CHECK(voice_note->transcription_id == 0 || voice_note->transcription_id == transcription_id); - CHECK(transcription_id != 0); - bool is_changed = voice_note->is_transcribed != is_final || voice_note->text != text; - voice_note->transcription_id = transcription_id; - voice_note->is_transcribed = is_final; - voice_note->text = std::move(text); - voice_note->last_transcription_error = Status::OK(); - + CHECK(voice_note->transcription_info != nullptr); if (is_final) { - auto it = speech_recognition_queries_.find(file_id); - CHECK(it != speech_recognition_queries_.end()); - CHECK(!it->second.empty()); - auto promises = std::move(it->second); - speech_recognition_queries_.erase(it); - + auto promises = voice_note->transcription_info->on_final_transcription(std::move(text), transcription_id); on_voice_note_transcription_completed(file_id); set_promises(promises); } else { + auto is_changed = voice_note->transcription_info->on_partial_transcription(std::move(text), transcription_id); if (is_changed) { on_voice_note_transcription_updated(file_id); } @@ -330,19 +293,9 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) { auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); - CHECK(!voice_note->is_transcribed); - CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_id) == 0); - - voice_note->transcription_id = 0; - voice_note->text.clear(); - voice_note->last_transcription_error = error.clone(); - - auto it = speech_recognition_queries_.find(file_id); - CHECK(it != speech_recognition_queries_.end()); - CHECK(!it->second.empty()); - auto promises = std::move(it->second); - speech_recognition_queries_.erase(it); - + CHECK(voice_note->transcription_info != nullptr); + CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_info->get_transcription_id()) == 0); + auto promises = voice_note->transcription_info->on_failed_transcription(error.clone()); on_voice_note_transcription_updated(file_id); fail_promises(promises, std::move(error)); } @@ -396,12 +349,12 @@ void VoiceNotesManager::rate_speech_recognition(FullMessageId full_message_id, b auto file_id = it->second; auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); - if (!voice_note->is_transcribed) { + if (voice_note->transcription_info == nullptr || !voice_note->transcription_info->is_transcribed()) { return promise.set_value(Unit()); } - CHECK(voice_note->transcription_id != 0); - td_->create_handler(std::move(promise)) - ->send(full_message_id, voice_note->transcription_id, is_good); + auto transcription_id = voice_note->transcription_info->get_transcription_id(); + CHECK(transcription_id != 0); + td_->create_handler(std::move(promise))->send(full_message_id, transcription_id, is_good); } SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id, diff --git a/td/telegram/VoiceNotesManager.h b/td/telegram/VoiceNotesManager.h index e77b94f57..3e95578d0 100644 --- a/td/telegram/VoiceNotesManager.h +++ b/td/telegram/VoiceNotesManager.h @@ -11,6 +11,7 @@ #include "td/telegram/SecretInputMedia.h" #include "td/telegram/td_api.h" #include "td/telegram/telegram_api.h" +#include "td/telegram/TranscriptionInfo.h" #include "td/actor/actor.h" #include "td/actor/MultiTimeout.h" @@ -79,11 +80,8 @@ class VoiceNotesManager final : public Actor { public: string mime_type; int32 duration = 0; - bool is_transcribed = false; string waveform; - int64 transcription_id = 0; - string text; - Status last_transcription_error; + unique_ptr transcription_info; FileId file_id; }; @@ -109,7 +107,6 @@ class VoiceNotesManager final : public Actor { WaitFreeHashMap, FileIdHash> voice_notes_; - FlatHashMap>, FileIdHash> speech_recognition_queries_; FlatHashMap pending_voice_note_transcription_queries_; MultiTimeout voice_note_transcription_timeout_{"VoiceNoteTranscriptionTimeout"}; diff --git a/td/telegram/VoiceNotesManager.hpp b/td/telegram/VoiceNotesManager.hpp index a17279bd2..49cf69b52 100644 --- a/td/telegram/VoiceNotesManager.hpp +++ b/td/telegram/VoiceNotesManager.hpp @@ -9,6 +9,7 @@ #include "td/telegram/VoiceNotesManager.h" #include "td/telegram/files/FileId.hpp" +#include "td/telegram/TranscriptionInfo.hpp" #include "td/telegram/Version.h" #include "td/utils/common.h" @@ -23,11 +24,12 @@ void VoiceNotesManager::store_voice_note(FileId file_id, StorerT &storer) const bool has_mime_type = !voice_note->mime_type.empty(); bool has_duration = voice_note->duration != 0; bool has_waveform = !voice_note->waveform.empty(); + bool is_transcribed = voice_note->transcription_info != nullptr && voice_note->transcription_info->is_transcribed(); BEGIN_STORE_FLAGS(); STORE_FLAG(has_mime_type); STORE_FLAG(has_duration); STORE_FLAG(has_waveform); - STORE_FLAG(voice_note->is_transcribed); + STORE_FLAG(is_transcribed); END_STORE_FLAGS(); if (has_mime_type) { store(voice_note->mime_type, storer); @@ -38,9 +40,8 @@ void VoiceNotesManager::store_voice_note(FileId file_id, StorerT &storer) const if (has_waveform) { store(voice_note->waveform, storer); } - if (voice_note->is_transcribed) { - store(voice_note->transcription_id, storer); - store(voice_note->text, storer); + if (is_transcribed) { + store(voice_note->transcription_info, storer); } store(file_id, storer); } @@ -51,18 +52,19 @@ FileId VoiceNotesManager::parse_voice_note(ParserT &parser) { bool has_mime_type; bool has_duration; bool has_waveform; + bool is_transcribed; if (parser.version() >= static_cast(Version::AddVoiceNoteFlags)) { BEGIN_PARSE_FLAGS(); PARSE_FLAG(has_mime_type); PARSE_FLAG(has_duration); PARSE_FLAG(has_waveform); - PARSE_FLAG(voice_note->is_transcribed); + PARSE_FLAG(is_transcribed); END_PARSE_FLAGS(); } else { has_mime_type = true; has_duration = true; has_waveform = true; - voice_note->is_transcribed = false; + is_transcribed = false; } if (has_mime_type) { parse(voice_note->mime_type, parser); @@ -73,9 +75,8 @@ FileId VoiceNotesManager::parse_voice_note(ParserT &parser) { if (has_waveform) { parse(voice_note->waveform, parser); } - if (voice_note->is_transcribed) { - parse(voice_note->transcription_id, parser); - parse(voice_note->text, parser); + if (is_transcribed) { + parse(voice_note->transcription_info, parser); } parse(voice_note->file_id, parser); if (parser.get_error() != nullptr || !voice_note->file_id.is_valid()) {