Explicitly subscribe to updateTranscribedAudio updates.

This commit is contained in:
levlam 2022-10-20 17:52:32 +03:00
parent e079b684f0
commit 46562f56d0
5 changed files with 90 additions and 66 deletions

View File

@ -31,10 +31,6 @@ class TranscriptionInfo {
return is_transcribed_;
}
// Returns the value of the transcription_id_ member.
int64 get_transcription_id() const {
return transcription_id_;
}
bool start_recognize_speech(Promise<Unit> &&promise);
vector<Promise<Unit>> on_final_transcription(string &&text, int64 transcription_id);

View File

@ -183,6 +183,9 @@ const double UpdatesManager::MAX_PTS_SAVE_DELAY = 0.05;
UpdatesManager::UpdatesManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) {
  // Both "last saved" timestamps start at the current time minus twice MAX_PTS_SAVE_DELAY.
  last_qts_save_time_ = Time::now() - 2 * MAX_PTS_SAVE_DELAY;
  last_pts_save_time_ = last_qts_save_time_;

  // Expired audio transcription timeouts are reported through the static callback,
  // which receives the owning Td pointer as its opaque callback data.
  auto callback_data = static_cast<void *>(td_);
  pending_audio_transcription_timeout_.set_callback(on_pending_audio_transcription_timeout_callback);
  pending_audio_transcription_timeout_.set_callback_data(callback_data);
}
void UpdatesManager::tear_down() {
@ -309,6 +312,20 @@ void UpdatesManager::fill_gap(void *td, const char *source) {
updates_manager->get_difference("fill_gap");
}
void UpdatesManager::on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id) {
if (G()->close_flag()) {
return;
}
CHECK(td != nullptr);
if (!static_cast<Td *>(td)->auth_manager_->is_authorized()) {
return;
}
auto updates_manager = static_cast<Td *>(td)->updates_manager_.get();
send_closure_later(updates_manager->actor_id(updates_manager), &UpdatesManager::on_pending_audio_transcription_failed,
transcription_id, Status::Error(500, "Timeout expired"));
}
void UpdatesManager::get_difference(const char *source) {
if (G()->close_flag() || !td_->auth_manager_->is_authorized()) {
return;
@ -1739,6 +1756,31 @@ void UpdatesManager::try_reload_data() {
schedule_data_reload();
}
// Registers on_update as the handler for updateTranscribedAudio updates carrying
// transcription_id and arms a timeout of AUDIO_TRANSCRIPTION_TIMEOUT for it.
//
// If a handler is already registered for the same identifier, that older subscription is
// first failed with a 500 "Receive duplicate speech recognition identifier" error.
//
// NOTE(review): on_pending_audio_transcription_failed returns early without erasing the
// entry when G()->close_flag() is set, in which case the CHECK below would fire for a
// duplicate identifier - confirm this function is never reached while closing.
void UpdatesManager::subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update) {
  const bool is_already_subscribed =
      pending_audio_transcriptions_.find(transcription_id) != pending_audio_transcriptions_.end();
  if (is_already_subscribed) {
    on_pending_audio_transcription_failed(transcription_id,
                                          Status::Error(500, "Receive duplicate speech recognition identifier"));
  }

  auto insertion = pending_audio_transcriptions_.emplace(transcription_id, std::move(on_update));
  bool is_inserted = insertion.second;
  CHECK(is_inserted);

  pending_audio_transcription_timeout_.set_timeout_in(transcription_id, AUDIO_TRANSCRIPTION_TIMEOUT);
}
// Fails the subscription registered for transcription_id, if there is one.
//
// Does nothing while closing or when no handler is registered. Otherwise the handler is
// removed from the pending map, its timeout is cancelled, and only then is the handler
// invoked with error.
void UpdatesManager::on_pending_audio_transcription_failed(int64 transcription_id, Status &&error) {
  if (G()->close_flag()) {
    return;
  }

  auto subscription = pending_audio_transcriptions_.find(transcription_id);
  if (subscription == pending_audio_transcriptions_.end()) {
    return;
  }

  // Take ownership of the handler first, so it is called after all bookkeeping is done.
  TranscribedAudioHandler handler = std::move(subscription->second);
  pending_audio_transcriptions_.erase(subscription);
  pending_audio_transcription_timeout_.cancel_timeout(transcription_id);

  handler(std::move(error));
}
void UpdatesManager::on_pending_updates(vector<tl_object_ptr<telegram_api::Update>> &&updates, int32 seq_begin,
int32 seq_end, int32 date, double receive_time, Promise<Unit> &&promise,
const char *source) {
@ -3614,8 +3656,18 @@ void UpdatesManager::on_update(tl_object_ptr<telegram_api::updateSavedRingtones>
}
// Handles updateTranscribedAudio: delivers the update to the handler registered for its
// transcription identifier, if any, and then completes promise.
// NOTE(review): this listing looks like a flattened diff, so the VoiceNotesManager call
// directly below may be the pre-change side of the hunk - confirm against the real file.
// If both paths are live, update->text_ is handed over as an rvalue before the subscribed
// handler receives update.
void UpdatesManager::on_update(tl_object_ptr<telegram_api::updateTranscribedAudio> update, Promise<Unit> &&promise) {
td_->voice_notes_manager_->on_update_transcribed_audio(std::move(update->text_), update->transcription_id_,
!update->pending_);
auto it = pending_audio_transcriptions_.find(update->transcription_id_);
if (it == pending_audio_transcriptions_.end()) {
// Nobody is subscribed to this identifier: just acknowledge the update.
return promise.set_value(Unit());
}
if (!update->pending_) {
// Non-pending result: drop the subscription and cancel its timeout before calling the handler.
auto on_update = std::move(it->second);
pending_audio_transcriptions_.erase(it);
pending_audio_transcription_timeout_.cancel_timeout(update->transcription_id_);
on_update(std::move(update));
} else {
// Still pending: deliver the update and keep the subscription registered.
it->second(std::move(update));
}
promise.set_value(Unit());
}

View File

@ -17,9 +17,11 @@
#include "td/telegram/UserId.h"
#include "td/actor/actor.h"
#include "td/actor/MultiTimeout.h"
#include "td/actor/Timeout.h"
#include "td/utils/common.h"
#include "td/utils/FlatHashMap.h"
#include "td/utils/FlatHashSet.h"
#include "td/utils/logging.h"
#include "td/utils/Promise.h"
@ -114,6 +116,10 @@ class UpdatesManager final : public Actor {
static int32 get_update_edit_message_pts(const telegram_api::Updates *updates_ptr, FullMessageId full_message_id);
// Callback type used by transcription subscribers: it is called with a Result holding
// either an updateTranscribedAudio object or an error Status.
using TranscribedAudioHandler =
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
// Registers on_update as the handler for updateTranscribedAudio updates with the given transcription_id.
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
void get_difference(const char *source);
void schedule_get_difference(const char *source);
@ -133,6 +139,7 @@ class UpdatesManager final : public Actor {
static const double MAX_PTS_SAVE_DELAY;
static constexpr bool DROP_PTS_UPDATES = false;
static constexpr const char *AFTER_GET_DIFFERENCE_SOURCE = "after get difference";
// Timeout armed for each transcribed audio subscription (presumably in seconds - TODO confirm).
static constexpr int32 AUDIO_TRANSCRIPTION_TIMEOUT = 60;
friend class OnUpdate;
@ -232,6 +239,9 @@ class UpdatesManager final : public Actor {
int32 min_postponed_update_qts_ = 0;
double get_difference_start_time_ = 0; // time from which we started to get difference without success
// Handlers waiting for updateTranscribedAudio, keyed by transcription identifier.
FlatHashMap<int64, TranscribedAudioHandler> pending_audio_transcriptions_;
// Timeouts for the pending handlers, keyed by the same transcription identifier.
MultiTimeout pending_audio_transcription_timeout_{"PendingAudioTranscriptionTimeout"};
void start_up() final;
void tear_down() final;
@ -322,6 +332,8 @@ class UpdatesManager final : public Actor {
static void fill_gap(void *td, const char *source);
static void on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id);
void set_pts_gap_timeout(double timeout);
void set_seq_gap_timeout(double timeout);
@ -366,6 +378,8 @@ class UpdatesManager final : public Actor {
static vector<tl_object_ptr<telegram_api::Update>> *get_updates(telegram_api::Updates *updates_ptr);
void on_pending_audio_transcription_failed(int64 transcription_id, Status &&error);
bool is_acceptable_user(UserId user_id) const;
bool is_acceptable_chat(ChatId chat_id) const;

View File

@ -16,6 +16,7 @@
#include "td/telegram/Td.h"
#include "td/telegram/td_api.h"
#include "td/telegram/telegram_api.h"
#include "td/telegram/UpdatesManager.h"
#include "td/utils/buffer.h"
#include "td/utils/logging.h"
@ -50,7 +51,7 @@ class TranscribeAudioQuery final : public Td::ResultHandler {
return on_error(Status::Error(500, "Receive no recognition identifier"));
}
td_->voice_notes_manager_->on_voice_note_transcribed(file_id_, std::move(result->text_), result->transcription_id_,
!result->pending_);
true, !result->pending_);
}
void on_error(Status status) final {
@ -60,8 +61,6 @@ class TranscribeAudioQuery final : public Td::ResultHandler {
};
// Stores td and parent, then wires the voice note transcription timeout to the static
// timeout callback with this manager as the callback data.
VoiceNotesManager::VoiceNotesManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) {
voice_note_transcription_timeout_.set_callback(on_voice_note_transcription_timeout_callback);
voice_note_transcription_timeout_.set_callback_data(static_cast<void *>(this));
}
VoiceNotesManager::~VoiceNotesManager() {
@ -73,18 +72,6 @@ void VoiceNotesManager::tear_down() {
parent_.reset();
}
// Static timeout callback: unless closing, asynchronously fails the pending transcription
// identified by transcription_id with a 500 "Timeout expired" error.
// voice_notes_manager_ptr is cast to VoiceNotesManager * without a null check.
void VoiceNotesManager::on_voice_note_transcription_timeout_callback(void *voice_notes_manager_ptr,
int64 transcription_id) {
if (G()->close_flag()) {
return;
}
auto voice_notes_manager = static_cast<VoiceNotesManager *>(voice_notes_manager_ptr);
// The failure is posted to the actor instead of being handled inline.
send_closure_later(voice_notes_manager->actor_id(voice_notes_manager),
&VoiceNotesManager::on_pending_voice_note_transcription_failed, transcription_id,
Status::Error(500, "Timeout expired"));
}
int32 VoiceNotesManager::get_voice_note_duration(FileId file_id) const {
auto voice_note = get_voice_note(file_id);
if (voice_note == nullptr) {
@ -230,7 +217,7 @@ void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise<
}
void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id,
bool is_final) {
bool is_initial, bool is_final) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(voice_note->transcription_info != nullptr);
@ -244,50 +231,34 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text,
on_voice_note_transcription_updated(file_id);
}
if (pending_voice_note_transcription_queries_.count(transcription_id) != 0) {
on_pending_voice_note_transcription_failed(transcription_id,
Status::Error(500, "Receive duplicate recognition identifier"));
if (is_initial) {
td_->updates_manager_->subscribe_to_transcribed_audio_updates(
transcription_id, [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, std::move(r_update));
});
}
bool is_inserted = pending_voice_note_transcription_queries_.emplace(transcription_id, file_id).second;
CHECK(is_inserted);
voice_note_transcription_timeout_.set_timeout_in(transcription_id, TRANSCRIPTION_TIMEOUT);
}
}
// Receives the outcome of a transcribed audio subscription for the voice note file_id.
//
// An error result is reported through on_voice_note_transcription_failed. A successful
// result is forwarded to on_voice_note_transcribed as a non-initial result, which is
// final exactly when the update is no longer pending.
void VoiceNotesManager::on_transcribed_audio_update(
    FileId file_id, Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
  if (r_update.is_error()) {
    on_voice_note_transcription_failed(file_id, r_update.move_as_error());
    return;
  }

  auto transcribed_audio = r_update.move_as_ok();
  const bool is_final = !transcribed_audio->pending_;
  on_voice_note_transcribed(file_id, std::move(transcribed_audio->text_), transcribed_audio->transcription_id_,
                            false, is_final);
}
// Records a failed transcription for the voice note file_id and fails the promises
// returned by its transcription info with error.
// The voice note and its transcription_info must exist (enforced with CHECK).
void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(voice_note->transcription_info != nullptr);
// No pending query may still be registered under this note's transcription identifier.
CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_info->get_transcription_id()) == 0);
// A clone of the error goes to the transcription info; the original is moved into fail_promises below.
auto promises = voice_note->transcription_info->on_failed_transcription(error.clone());
on_voice_note_transcription_updated(file_id);
fail_promises(promises, std::move(error));
}
// Applies a transcription update to the voice note whose query is pending under
// transcription_id; does nothing when no such query is registered.
void VoiceNotesManager::on_update_transcribed_audio(string &&text, int64 transcription_id, bool is_final) {
auto it = pending_voice_note_transcription_queries_.find(transcription_id);
if (it == pending_voice_note_transcription_queries_.end()) {
return;
}
// Copy the file identifier out before erasing the entry it comes from.
auto file_id = it->second;
pending_voice_note_transcription_queries_.erase(it);
voice_note_transcription_timeout_.cancel_timeout(transcription_id);
on_voice_note_transcribed(file_id, std::move(text), transcription_id, is_final);
}
// Fails the voice note transcription whose query is pending under transcription_id;
// does nothing when no such query is registered.
void VoiceNotesManager::on_pending_voice_note_transcription_failed(int64 transcription_id, Status &&error) {
auto it = pending_voice_note_transcription_queries_.find(transcription_id);
if (it == pending_voice_note_transcription_queries_.end()) {
return;
}
// Copy the file identifier out before erasing the entry it comes from.
auto file_id = it->second;
pending_voice_note_transcription_queries_.erase(it);
voice_note_transcription_timeout_.cancel_timeout(transcription_id);
on_voice_note_transcription_failed(file_id, std::move(error));
}
void VoiceNotesManager::on_voice_note_transcription_updated(FileId file_id) {
auto it = voice_note_messages_.find(file_id);
if (it != voice_note_messages_.end()) {

View File

@ -14,7 +14,6 @@
#include "td/telegram/TranscriptionInfo.h"
#include "td/actor/actor.h"
#include "td/actor/MultiTimeout.h"
#include "td/utils/common.h"
#include "td/utils/FlatHashMap.h"
@ -50,9 +49,7 @@ class VoiceNotesManager final : public Actor {
void rate_speech_recognition(FullMessageId full_message_id, bool is_good, Promise<Unit> &&promise);
void on_update_transcribed_audio(string &&text, int64 transcription_id, bool is_final);
void on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, bool is_final);
void on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, bool is_initial, bool is_final);
void on_voice_note_transcription_failed(FileId file_id, Status &&error);
@ -74,8 +71,6 @@ class VoiceNotesManager final : public Actor {
FileId parse_voice_note(ParserT &parser);
private:
static constexpr int32 TRANSCRIPTION_TIMEOUT = 60;
class VoiceNote {
public:
string mime_type;
@ -86,20 +81,19 @@ class VoiceNotesManager final : public Actor {
FileId file_id;
};
static void on_voice_note_transcription_timeout_callback(void *voice_notes_manager_ptr, int64 transcription_id);
VoiceNote *get_voice_note(FileId file_id);
const VoiceNote *get_voice_note(FileId file_id) const;
FileId on_get_voice_note(unique_ptr<VoiceNote> new_voice_note, bool replace);
void on_pending_voice_note_transcription_failed(int64 transcription_id, Status &&error);
void on_voice_note_transcription_updated(FileId file_id);
void on_voice_note_transcription_completed(FileId file_id);
void on_transcribed_audio_update(FileId file_id,
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
void tear_down() final;
Td *td_;
@ -107,9 +101,6 @@ class VoiceNotesManager final : public Actor {
WaitFreeHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
FlatHashMap<int64, FileId> pending_voice_note_transcription_queries_;
MultiTimeout voice_note_transcription_timeout_{"VoiceNoteTranscriptionTimeout"};
FlatHashMap<FileId, FlatHashSet<FullMessageId, FullMessageIdHash>, FileIdHash> voice_note_messages_;
FlatHashMap<FullMessageId, FileId, FullMessageIdHash> message_voice_notes_;
};