From 46562f56d0758797d13e2c1e7c2929605b566a70 Mon Sep 17 00:00:00 2001 From: levlam Date: Thu, 20 Oct 2022 17:52:32 +0300 Subject: [PATCH] Explicitly subscribe to updateTranscribedAudio updates. --- td/telegram/TranscriptionInfo.h | 4 -- td/telegram/UpdatesManager.cpp | 56 +++++++++++++++++++++++++- td/telegram/UpdatesManager.h | 14 +++++++ td/telegram/VoiceNotesManager.cpp | 65 +++++++++---------------------- td/telegram/VoiceNotesManager.h | 17 ++------ 5 files changed, 90 insertions(+), 66 deletions(-) diff --git a/td/telegram/TranscriptionInfo.h b/td/telegram/TranscriptionInfo.h index 7f89f032e..433cab5b5 100644 --- a/td/telegram/TranscriptionInfo.h +++ b/td/telegram/TranscriptionInfo.h @@ -31,10 +31,6 @@ class TranscriptionInfo { return is_transcribed_; } - int64 get_transcription_id() const { - return transcription_id_; - } - bool start_recognize_speech(Promise &&promise); vector> on_final_transcription(string &&text, int64 transcription_id); diff --git a/td/telegram/UpdatesManager.cpp b/td/telegram/UpdatesManager.cpp index f72768471..d25555590 100644 --- a/td/telegram/UpdatesManager.cpp +++ b/td/telegram/UpdatesManager.cpp @@ -183,6 +183,9 @@ const double UpdatesManager::MAX_PTS_SAVE_DELAY = 0.05; UpdatesManager::UpdatesManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) { last_pts_save_time_ = last_qts_save_time_ = Time::now() - 2 * MAX_PTS_SAVE_DELAY; + + pending_audio_transcription_timeout_.set_callback(on_pending_audio_transcription_timeout_callback); + pending_audio_transcription_timeout_.set_callback_data(static_cast(td_)); } void UpdatesManager::tear_down() { @@ -309,6 +312,20 @@ void UpdatesManager::fill_gap(void *td, const char *source) { updates_manager->get_difference("fill_gap"); } +void UpdatesManager::on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id) { + if (G()->close_flag()) { + return; + } + CHECK(td != nullptr); + if (!static_cast(td)->auth_manager_->is_authorized()) { + return; + } + + auto updates_manager = static_cast(td)->updates_manager_.get(); + send_closure_later(updates_manager->actor_id(updates_manager), &UpdatesManager::on_pending_audio_transcription_failed, + transcription_id, Status::Error(500, "Timeout expired")); +} + void UpdatesManager::get_difference(const char *source) { if (G()->close_flag() || !td_->auth_manager_->is_authorized()) { return; @@ -1739,6 +1756,31 @@ void UpdatesManager::try_reload_data() { schedule_data_reload(); } +void UpdatesManager::subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update) { + if (pending_audio_transcriptions_.count(transcription_id) != 0) { + on_pending_audio_transcription_failed(transcription_id, + Status::Error(500, "Receive duplicate speech recognition identifier")); + } + bool is_inserted = pending_audio_transcriptions_.emplace(transcription_id, std::move(on_update)).second; + CHECK(is_inserted); + pending_audio_transcription_timeout_.set_timeout_in(transcription_id, AUDIO_TRANSCRIPTION_TIMEOUT); +} + +void UpdatesManager::on_pending_audio_transcription_failed(int64 transcription_id, Status &&error) { + if (G()->close_flag()) { + return; + } + auto it = pending_audio_transcriptions_.find(transcription_id); + if (it == pending_audio_transcriptions_.end()) { + return; + } + auto on_update = std::move(it->second); + pending_audio_transcriptions_.erase(it); + pending_audio_transcription_timeout_.cancel_timeout(transcription_id); + + on_update(std::move(error)); +} + void UpdatesManager::on_pending_updates(vector> &&updates, int32 seq_begin, int32 seq_end, int32 date, double receive_time, Promise &&promise, const char *source) { @@ -3614,8 +3656,18 @@ void UpdatesManager::on_update(tl_object_ptr } void UpdatesManager::on_update(tl_object_ptr update, Promise &&promise) { - td_->voice_notes_manager_->on_update_transcribed_audio(std::move(update->text_), update->transcription_id_, - !update->pending_); + auto it = pending_audio_transcriptions_.find(update->transcription_id_); + if (it == pending_audio_transcriptions_.end()) { + return promise.set_value(Unit()); + } + if (!update->pending_) { + auto on_update = std::move(it->second); + pending_audio_transcriptions_.erase(it); + pending_audio_transcription_timeout_.cancel_timeout(update->transcription_id_); + on_update(std::move(update)); + } else { + it->second(std::move(update)); + } promise.set_value(Unit()); } diff --git a/td/telegram/UpdatesManager.h b/td/telegram/UpdatesManager.h index 40645785f..f8f8fc021 100644 --- a/td/telegram/UpdatesManager.h +++ b/td/telegram/UpdatesManager.h @@ -17,9 +17,11 @@ #include "td/telegram/UserId.h" #include "td/actor/actor.h" +#include "td/actor/MultiTimeout.h" #include "td/actor/Timeout.h" #include "td/utils/common.h" +#include "td/utils/FlatHashMap.h" #include "td/utils/FlatHashSet.h" #include "td/utils/logging.h" #include "td/utils/Promise.h" @@ -114,6 +116,10 @@ class UpdatesManager final : public Actor { static int32 get_update_edit_message_pts(const telegram_api::Updates *updates_ptr, FullMessageId full_message_id); + using TranscribedAudioHandler = + std::function>)>; + void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update); + void get_difference(const char *source); void schedule_get_difference(const char *source); @@ -133,6 +139,7 @@ class UpdatesManager final : public Actor { static const double MAX_PTS_SAVE_DELAY; static constexpr bool DROP_PTS_UPDATES = false; static constexpr const char *AFTER_GET_DIFFERENCE_SOURCE = "after get difference"; + static constexpr int32 AUDIO_TRANSCRIPTION_TIMEOUT = 60; friend class OnUpdate; @@ -232,6 +239,9 @@ class UpdatesManager final : public Actor { int32 min_postponed_update_qts_ = 0; double get_difference_start_time_ = 0; // time from which we started to get difference without success + FlatHashMap pending_audio_transcriptions_; + MultiTimeout pending_audio_transcription_timeout_{"PendingAudioTranscriptionTimeout"}; + void start_up() final; void tear_down() final; @@ -322,6 +332,8 @@ class UpdatesManager final : public Actor { static void fill_gap(void *td, const char *source); + static void on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id); + void set_pts_gap_timeout(double timeout); void set_seq_gap_timeout(double timeout); @@ -366,6 +378,8 @@ class UpdatesManager final : public Actor { static vector> *get_updates(telegram_api::Updates *updates_ptr); + void on_pending_audio_transcription_failed(int64 transcription_id, Status &&error); + bool is_acceptable_user(UserId user_id) const; bool is_acceptable_chat(ChatId chat_id) const; diff --git a/td/telegram/VoiceNotesManager.cpp b/td/telegram/VoiceNotesManager.cpp index 3df284d9e..2225fccea 100644 --- a/td/telegram/VoiceNotesManager.cpp +++ b/td/telegram/VoiceNotesManager.cpp @@ -16,6 +16,7 @@ #include "td/telegram/Td.h" #include "td/telegram/td_api.h" #include "td/telegram/telegram_api.h" +#include "td/telegram/UpdatesManager.h" #include "td/utils/buffer.h" #include "td/utils/logging.h" @@ -50,7 +51,7 @@ class TranscribeAudioQuery final : public Td::ResultHandler { return on_error(Status::Error(500, "Receive no recognition identifier")); } td_->voice_notes_manager_->on_voice_note_transcribed(file_id_, std::move(result->text_), result->transcription_id_, - !result->pending_); + true, !result->pending_); } void on_error(Status status) final { @@ -60,8 +61,6 @@ class TranscribeAudioQuery final : public Td::ResultHandler { }; VoiceNotesManager::VoiceNotesManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) { - voice_note_transcription_timeout_.set_callback(on_voice_note_transcription_timeout_callback); - voice_note_transcription_timeout_.set_callback_data(static_cast(this)); } VoiceNotesManager::~VoiceNotesManager() { @@ -73,18 +72,6 @@ void VoiceNotesManager::tear_down() { parent_.reset(); } -void VoiceNotesManager::on_voice_note_transcription_timeout_callback(void *voice_notes_manager_ptr, - int64 transcription_id) { - if (G()->close_flag()) { - return; - } - - auto voice_notes_manager = static_cast(voice_notes_manager_ptr); - send_closure_later(voice_notes_manager->actor_id(voice_notes_manager), - &VoiceNotesManager::on_pending_voice_note_transcription_failed, transcription_id, - Status::Error(500, "Timeout expired")); -} - int32 VoiceNotesManager::get_voice_note_duration(FileId file_id) const { auto voice_note = get_voice_note(file_id); if (voice_note == nullptr) { @@ -230,7 +217,7 @@ void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise< } void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, - bool is_final) { + bool is_initial, bool is_final) { auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); CHECK(voice_note->transcription_info != nullptr); @@ -244,50 +231,34 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, on_voice_note_transcription_updated(file_id); } - if (pending_voice_note_transcription_queries_.count(transcription_id) != 0) { - on_pending_voice_note_transcription_failed(transcription_id, - Status::Error(500, "Receive duplicate recognition identifier")); + if (is_initial) { + td_->updates_manager_->subscribe_to_transcribed_audio_updates( + transcription_id, [actor_id = actor_id(this), + file_id](Result> r_update) { + send_closure(actor_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, std::move(r_update)); + }); } - bool is_inserted = pending_voice_note_transcription_queries_.emplace(transcription_id, file_id).second; - CHECK(is_inserted); - voice_note_transcription_timeout_.set_timeout_in(transcription_id, TRANSCRIPTION_TIMEOUT); } } +void VoiceNotesManager::on_transcribed_audio_update( + FileId file_id, Result> r_update) { + if (r_update.is_error()) { + return on_voice_note_transcription_failed(file_id, r_update.move_as_error()); + } + auto update = r_update.move_as_ok(); + on_voice_note_transcribed(file_id, std::move(update->text_), update->transcription_id_, false, !update->pending_); +} + void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) { auto voice_note = get_voice_note(file_id); CHECK(voice_note != nullptr); CHECK(voice_note->transcription_info != nullptr); - CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_info->get_transcription_id()) == 0); auto promises = voice_note->transcription_info->on_failed_transcription(error.clone()); on_voice_note_transcription_updated(file_id); fail_promises(promises, std::move(error)); } -void VoiceNotesManager::on_update_transcribed_audio(string &&text, int64 transcription_id, bool is_final) { - auto it = pending_voice_note_transcription_queries_.find(transcription_id); - if (it == pending_voice_note_transcription_queries_.end()) { - return; - } - auto file_id = it->second; - pending_voice_note_transcription_queries_.erase(it); - voice_note_transcription_timeout_.cancel_timeout(transcription_id); - - on_voice_note_transcribed(file_id, std::move(text), transcription_id, is_final); -} - -void VoiceNotesManager::on_pending_voice_note_transcription_failed(int64 transcription_id, Status &&error) { - auto it = pending_voice_note_transcription_queries_.find(transcription_id); - if (it == pending_voice_note_transcription_queries_.end()) { - return; - } - auto file_id = it->second; - pending_voice_note_transcription_queries_.erase(it); - voice_note_transcription_timeout_.cancel_timeout(transcription_id); - - on_voice_note_transcription_failed(file_id, std::move(error)); -} - void VoiceNotesManager::on_voice_note_transcription_updated(FileId file_id) { auto it = voice_note_messages_.find(file_id); if (it != voice_note_messages_.end()) { diff --git a/td/telegram/VoiceNotesManager.h b/td/telegram/VoiceNotesManager.h index 3e95578d0..52b6c2253 100644 --- a/td/telegram/VoiceNotesManager.h +++ b/td/telegram/VoiceNotesManager.h @@ -14,7 +14,6 @@ #include "td/telegram/TranscriptionInfo.h" #include "td/actor/actor.h" -#include "td/actor/MultiTimeout.h" #include "td/utils/common.h" #include "td/utils/FlatHashMap.h" @@ -50,9 +49,7 @@ class VoiceNotesManager final : public Actor { void rate_speech_recognition(FullMessageId full_message_id, bool is_good, Promise &&promise); - void on_update_transcribed_audio(string &&text, int64 transcription_id, bool is_final); - - void on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, bool is_final); + void on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, bool is_initial, bool is_final); void on_voice_note_transcription_failed(FileId file_id, Status &&error); @@ -74,8 +71,6 @@ class VoiceNotesManager final : public Actor { FileId parse_voice_note(ParserT &parser); private: - static constexpr int32 TRANSCRIPTION_TIMEOUT = 60; - class VoiceNote { public: string mime_type; @@ -86,20 +81,19 @@ class VoiceNotesManager final : public Actor { FileId file_id; }; - static void on_voice_note_transcription_timeout_callback(void *voice_notes_manager_ptr, int64 transcription_id); - VoiceNote *get_voice_note(FileId file_id); const VoiceNote *get_voice_note(FileId file_id) const; FileId on_get_voice_note(unique_ptr new_voice_note, bool replace); - void on_pending_voice_note_transcription_failed(int64 transcription_id, Status &&error); - void on_voice_note_transcription_updated(FileId file_id); void on_voice_note_transcription_completed(FileId file_id); + void on_transcribed_audio_update(FileId file_id, + Result> r_update); + void tear_down() final; Td *td_; @@ -107,9 +101,6 @@ class VoiceNotesManager final : public Actor { WaitFreeHashMap, FileIdHash> voice_notes_; - FlatHashMap pending_voice_note_transcription_queries_; - MultiTimeout voice_note_transcription_timeout_{"VoiceNoteTranscriptionTimeout"}; - FlatHashMap, FileIdHash> voice_note_messages_; FlatHashMap message_voice_notes_; };