Move transcription update subscription to TranscriptionManager.

This commit is contained in:
levlam 2023-11-21 20:32:23 +03:00
parent 1cdb210ed1
commit ab39c96b2c
6 changed files with 97 additions and 80 deletions

View File

@ -14,14 +14,6 @@
namespace td {
TranscriptionManager::TranscriptionManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) {
load_trial_parameters();
}
void TranscriptionManager::tear_down() {
parent_.reset();
}
void TranscriptionManager::TrialParameters::update_left_tries() {
if (cooldown_until_ <= G()->unix_time()) {
cooldown_until_ = 0;
@ -88,6 +80,32 @@ bool operator==(const TranscriptionManager::TrialParameters &lhs, const Transcri
lhs.left_tries_ == rhs.left_tries_ && lhs.cooldown_until_ == rhs.cooldown_until_;
}
TranscriptionManager::TranscriptionManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) {
load_trial_parameters();
pending_audio_transcription_timeout_.set_callback(on_pending_audio_transcription_timeout_callback);
pending_audio_transcription_timeout_.set_callback_data(static_cast<void *>(td_));
}
void TranscriptionManager::tear_down() {
parent_.reset();
}
void TranscriptionManager::on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id) {
if (G()->close_flag()) {
return;
}
CHECK(td != nullptr);
if (!static_cast<Td *>(td)->auth_manager_->is_authorized()) {
return;
}
auto transcription_manager = static_cast<Td *>(td)->transcription_manager_.get();
send_closure_later(transcription_manager->actor_id(transcription_manager),
&TranscriptionManager::on_pending_audio_transcription_failed, transcription_id,
Status::Error(500, "Timeout expired"));
}
string TranscriptionManager::get_trial_parameters_database_key() {
return "speech_recognition_trial";
}
@ -153,6 +171,50 @@ TranscriptionManager::TrialParameters::get_update_speech_recognition_trial_objec
cooldown_until_);
}
void TranscriptionManager::subscribe_to_transcribed_audio_updates(int64 transcription_id,
TranscribedAudioHandler on_update) {
CHECK(transcription_id != 0);
if (pending_audio_transcriptions_.count(transcription_id) != 0) {
on_pending_audio_transcription_failed(transcription_id,
Status::Error(500, "Receive duplicate speech recognition identifier"));
}
bool is_inserted = pending_audio_transcriptions_.emplace(transcription_id, std::move(on_update)).second;
CHECK(is_inserted);
pending_audio_transcription_timeout_.set_timeout_in(transcription_id, AUDIO_TRANSCRIPTION_TIMEOUT);
}
void TranscriptionManager::on_update_transcribed_audio(
telegram_api::object_ptr<telegram_api::updateTranscribedAudio> &&update) {
auto it = pending_audio_transcriptions_.find(update->transcription_id_);
if (it == pending_audio_transcriptions_.end()) {
return;
}
// flags_, dialog_id_ and message_id_ must not be used
if (!update->pending_) {
auto on_update = std::move(it->second);
pending_audio_transcriptions_.erase(it);
pending_audio_transcription_timeout_.cancel_timeout(update->transcription_id_);
on_update(std::move(update));
} else {
it->second(std::move(update));
}
}
void TranscriptionManager::on_pending_audio_transcription_failed(int64 transcription_id, Status &&error) {
if (G()->close_flag()) {
return;
}
auto it = pending_audio_transcriptions_.find(transcription_id);
if (it == pending_audio_transcriptions_.end()) {
return;
}
auto on_update = std::move(it->second);
pending_audio_transcriptions_.erase(it);
pending_audio_transcription_timeout_.cancel_timeout(transcription_id);
on_update(std::move(error));
}
void TranscriptionManager::get_current_state(vector<td_api::object_ptr<td_api::Update>> &updates) const {
if (!td_->auth_manager_->is_authorized() || td_->auth_manager_->is_bot()) {
return;

View File

@ -7,10 +7,16 @@
#pragma once
#include "td/telegram/td_api.h"
#include "td/telegram/telegram_api.h"
#include "td/actor/actor.h"
#include "td/actor/MultiTimeout.h"
#include "td/utils/common.h"
#include "td/utils/FlatHashMap.h"
#include "td/utils/Status.h"
#include <functional>
namespace td {
@ -22,11 +28,21 @@ class TranscriptionManager final : public Actor {
void on_update_trial_parameters(int32 weekly_number, int32 duration_max, int32 cooldown_until);
using TranscribedAudioHandler =
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
void on_update_transcribed_audio(telegram_api::object_ptr<telegram_api::updateTranscribedAudio> &&update);
void get_current_state(vector<td_api::object_ptr<td_api::Update>> &updates) const;
private:
static constexpr int32 AUDIO_TRANSCRIPTION_TIMEOUT = 60;
void tear_down() final;
static void on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id);
static string get_trial_parameters_database_key();
void load_trial_parameters();
@ -37,6 +53,8 @@ class TranscriptionManager final : public Actor {
td_api::object_ptr<td_api::updateSpeechRecognitionTrial> get_update_speech_recognition_trial_object() const;
void on_pending_audio_transcription_failed(int64 transcription_id, Status &&error);
struct TrialParameters {
int32 weekly_number_ = 0;
int32 duration_max_ = 0;
@ -60,6 +78,9 @@ class TranscriptionManager final : public Actor {
ActorShared<> parent_;
TrialParameters trial_parameters_;
FlatHashMap<int64, TranscribedAudioHandler> pending_audio_transcriptions_;
MultiTimeout pending_audio_transcription_timeout_{"PendingAudioTranscriptionTimeout"};
};
} // namespace td

View File

@ -68,6 +68,7 @@
#include "td/telegram/telegram_api.h"
#include "td/telegram/telegram_api.hpp"
#include "td/telegram/ThemeManager.h"
#include "td/telegram/TranscriptionManager.h"
#include "td/telegram/Usernames.h"
#include "td/telegram/WebPagesManager.h"
@ -257,9 +258,6 @@ class GetPtsUpdateQuery final : public Td::ResultHandler {
UpdatesManager::UpdatesManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) {
last_pts_save_time_ = last_qts_save_time_ = Time::now() - 2 * MAX_PTS_SAVE_DELAY;
pending_audio_transcription_timeout_.set_callback(on_pending_audio_transcription_timeout_callback);
pending_audio_transcription_timeout_.set_callback_data(static_cast<void *>(td_));
if (!td_->auth_manager_->is_authorized() || !td_->auth_manager_->is_bot()) {
skipped_postponed_updates_after_start_ = 0;
}
@ -452,20 +450,6 @@ void UpdatesManager::fill_gap(void *td, const char *source) {
updates_manager->get_difference("fill_gap");
}
void UpdatesManager::on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id) {
if (G()->close_flag()) {
return;
}
CHECK(td != nullptr);
if (!static_cast<Td *>(td)->auth_manager_->is_authorized()) {
return;
}
auto updates_manager = static_cast<Td *>(td)->updates_manager_.get();
send_closure_later(updates_manager->actor_id(updates_manager), &UpdatesManager::on_pending_audio_transcription_failed,
transcription_id, Status::Error(500, "Timeout expired"));
}
void UpdatesManager::get_difference(const char *source) {
if (G()->close_flag() || !td_->auth_manager_->is_authorized()) {
return;
@ -2291,32 +2275,6 @@ void UpdatesManager::on_data_reloaded() {
schedule_data_reload();
}
void UpdatesManager::subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update) {
CHECK(transcription_id != 0);
if (pending_audio_transcriptions_.count(transcription_id) != 0) {
on_pending_audio_transcription_failed(transcription_id,
Status::Error(500, "Receive duplicate speech recognition identifier"));
}
bool is_inserted = pending_audio_transcriptions_.emplace(transcription_id, std::move(on_update)).second;
CHECK(is_inserted);
pending_audio_transcription_timeout_.set_timeout_in(transcription_id, AUDIO_TRANSCRIPTION_TIMEOUT);
}
void UpdatesManager::on_pending_audio_transcription_failed(int64 transcription_id, Status &&error) {
if (G()->close_flag()) {
return;
}
auto it = pending_audio_transcriptions_.find(transcription_id);
if (it == pending_audio_transcriptions_.end()) {
return;
}
auto on_update = std::move(it->second);
pending_audio_transcriptions_.erase(it);
pending_audio_transcription_timeout_.cancel_timeout(transcription_id);
on_update(std::move(error));
}
void UpdatesManager::on_pending_updates(vector<tl_object_ptr<telegram_api::Update>> &&updates, int32 seq_begin,
int32 seq_end, int32 date, double receive_time, Promise<Unit> &&promise,
const char *source) {
@ -4327,19 +4285,7 @@ void UpdatesManager::on_update(tl_object_ptr<telegram_api::updateSavedRingtones>
}
void UpdatesManager::on_update(tl_object_ptr<telegram_api::updateTranscribedAudio> update, Promise<Unit> &&promise) {
auto it = pending_audio_transcriptions_.find(update->transcription_id_);
if (it == pending_audio_transcriptions_.end()) {
return promise.set_value(Unit());
}
// flags_, dialog_id_ and message_id_ must not be used
if (!update->pending_) {
auto on_update = std::move(it->second);
pending_audio_transcriptions_.erase(it);
pending_audio_transcription_timeout_.cancel_timeout(update->transcription_id_);
on_update(std::move(update));
} else {
it->second(std::move(update));
}
td_->transcription_manager_->on_update_transcribed_audio(std::move(update));
promise.set_value(Unit());
}

View File

@ -127,10 +127,6 @@ class UpdatesManager final : public Actor {
static int32 get_update_edit_message_pts(const telegram_api::Updates *updates_ptr, MessageFullId message_full_id);
using TranscribedAudioHandler =
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
void get_difference(const char *source);
void schedule_get_difference(const char *source);
@ -156,7 +152,6 @@ class UpdatesManager final : public Actor {
static constexpr double UPDATE_APPLY_WARNING_TIME = 0.25;
static constexpr bool DROP_PTS_UPDATES = false;
static constexpr const char *AFTER_GET_DIFFERENCE_SOURCE = "after get difference";
static constexpr int32 AUDIO_TRANSCRIPTION_TIMEOUT = 60;
friend class OnUpdate;
@ -271,9 +266,6 @@ class UpdatesManager final : public Actor {
double get_difference_start_time_ = 0; // time from which we started to get difference without success
int32 get_difference_retry_count_ = 0;
FlatHashMap<int64, TranscribedAudioHandler> pending_audio_transcriptions_;
MultiTimeout pending_audio_transcription_timeout_{"PendingAudioTranscriptionTimeout"};
struct SessionInfo {
uint64 update_count = 0;
double first_update_time = 0.0;
@ -383,8 +375,6 @@ class UpdatesManager final : public Actor {
static void fill_gap(void *td, const char *source);
static void on_pending_audio_transcription_timeout_callback(void *td, int64 transcription_id);
void repair_pts_gap();
void on_get_pts_update(int32 pts, telegram_api::object_ptr<telegram_api::updates_Difference> difference_ptr);
@ -441,8 +431,6 @@ class UpdatesManager final : public Actor {
static vector<tl_object_ptr<telegram_api::Update>> *get_updates(telegram_api::Updates *updates_ptr);
void on_pending_audio_transcription_failed(int64 transcription_id, Status &&error);
bool is_acceptable_user(UserId user_id) const;
bool is_acceptable_chat(ChatId chat_id) const;

View File

@ -16,7 +16,7 @@
#include "td/telegram/Td.h"
#include "td/telegram/td_api.h"
#include "td/telegram/telegram_api.h"
#include "td/telegram/UpdatesManager.h"
#include "td/telegram/TranscriptionManager.h"
#include "td/actor/actor.h"
@ -252,7 +252,7 @@ void VideoNotesManager::on_transcribed_audio_update(
}
if (is_initial) {
td_->updates_manager_->subscribe_to_transcribed_audio_updates(
td_->transcription_manager_->subscribe_to_transcribed_audio_updates(
transcription_id, [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VideoNotesManager::on_transcribed_audio_update, file_id, false,

View File

@ -15,7 +15,7 @@
#include "td/telegram/Td.h"
#include "td/telegram/td_api.h"
#include "td/telegram/telegram_api.h"
#include "td/telegram/UpdatesManager.h"
#include "td/telegram/TranscriptionManager.h"
#include "td/utils/buffer.h"
#include "td/utils/logging.h"
@ -216,7 +216,7 @@ void VoiceNotesManager::on_transcribed_audio_update(
}
if (is_initial) {
td_->updates_manager_->subscribe_to_transcribed_audio_updates(
td_->transcription_manager_->subscribe_to_transcribed_audio_updates(
transcription_id, [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, false,