Move speech recognition to TranscriptionManager.

This commit is contained in:
levlam 2023-11-23 13:25:56 +03:00
parent ab39c96b2c
commit ec109dfd4a
12 changed files with 230 additions and 363 deletions

View File

@ -7059,7 +7059,7 @@ translateText text:formattedText to_language_code:string = FormattedText;
//-"st", "sn", "sd", "si", "sk", "sl", "so", "es", "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr", "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "ji", "yo", "zu"
translateMessageText chat_id:int53 message_id:int53 to_language_code:string = FormattedText;
//@description Recognizes speech in a video note or a voice note message. The message must be successfully sent and must not be scheduled
//@description Recognizes speech in a video note or a voice note message. The message must be successfully sent, must not be scheduled, and must be from a non-secret chat
//@chat_id Identifier of the chat to which the message belongs
//@message_id Identifier of the message
recognizeSpeech chat_id:int53 message_id:int53 = Ok;

View File

@ -79,6 +79,7 @@
#include "td/telegram/Td.h"
#include "td/telegram/telegram_api.h"
#include "td/telegram/TopDialogManager.h"
#include "td/telegram/TranscriptionManager.h"
#include "td/telegram/UserId.h"
#include "td/telegram/Venue.h"
#include "td/telegram/Version.h"
@ -4885,7 +4886,8 @@ static CustomEmojiId get_custom_emoji_id(const FormattedText &text) {
void register_message_content(Td *td, const MessageContent *content, MessageFullId message_full_id,
const char *source) {
switch (content->get_type()) {
auto content_type = content->get_type();
switch (content_type) {
case MessageContentType::Text: {
auto text = static_cast<const MessageText *>(content);
if (text->web_page_id.is_valid()) {
@ -4897,11 +4899,11 @@ void register_message_content(Td *td, const MessageContent *content, MessageFull
return;
}
case MessageContentType::VideoNote:
return td->video_notes_manager_->register_video_note(static_cast<const MessageVideoNote *>(content)->file_id,
message_full_id, source);
return td->transcription_manager_->register_voice(static_cast<const MessageVideoNote *>(content)->file_id,
content_type, message_full_id, source);
case MessageContentType::VoiceNote:
return td->voice_notes_manager_->register_voice_note(static_cast<const MessageVoiceNote *>(content)->file_id,
message_full_id, source);
return td->transcription_manager_->register_voice(static_cast<const MessageVoiceNote *>(content)->file_id,
content_type, message_full_id, source);
case MessageContentType::Poll:
return td->poll_manager_->register_poll(static_cast<const MessagePoll *>(content)->poll_id, message_full_id,
source);
@ -5005,7 +5007,8 @@ void reregister_message_content(Td *td, const MessageContent *old_content, const
void unregister_message_content(Td *td, const MessageContent *content, MessageFullId message_full_id,
const char *source) {
switch (content->get_type()) {
auto content_type = content->get_type();
switch (content_type) {
case MessageContentType::Text: {
auto text = static_cast<const MessageText *>(content);
if (text->web_page_id.is_valid()) {
@ -5017,11 +5020,11 @@ void unregister_message_content(Td *td, const MessageContent *content, MessageFu
return;
}
case MessageContentType::VideoNote:
return td->video_notes_manager_->unregister_video_note(static_cast<const MessageVideoNote *>(content)->file_id,
message_full_id, source);
return td->transcription_manager_->unregister_voice(static_cast<const MessageVideoNote *>(content)->file_id,
content_type, message_full_id, source);
case MessageContentType::VoiceNote:
return td->voice_notes_manager_->unregister_voice_note(static_cast<const MessageVoiceNote *>(content)->file_id,
message_full_id, source);
return td->transcription_manager_->unregister_voice(static_cast<const MessageVoiceNote *>(content)->file_id,
content_type, message_full_id, source);
case MessageContentType::Poll:
return td->poll_manager_->unregister_poll(static_cast<const MessagePoll *>(content)->poll_id, message_full_id,
source);
@ -7588,28 +7591,4 @@ void update_used_hashtags(Td *td, const MessageContent *content) {
}
}
void recognize_message_content_speech(Td *td, const MessageContent *content, MessageFullId message_full_id,
Promise<Unit> &&promise) {
switch (content->get_type()) {
case MessageContentType::VideoNote:
return td->video_notes_manager_->recognize_speech(message_full_id, std::move(promise));
case MessageContentType::VoiceNote:
return td->voice_notes_manager_->recognize_speech(message_full_id, std::move(promise));
default:
return promise.set_error(Status::Error(400, "Invalid message specified"));
}
}
void rate_message_content_speech_recognition(Td *td, const MessageContent *content, MessageFullId message_full_id,
bool is_good, Promise<Unit> &&promise) {
switch (content->get_type()) {
case MessageContentType::VideoNote:
return td->video_notes_manager_->rate_speech_recognition(message_full_id, is_good, std::move(promise));
case MessageContentType::VoiceNote:
return td->voice_notes_manager_->rate_speech_recognition(message_full_id, is_good, std::move(promise));
default:
return promise.set_error(Status::Error(400, "Invalid message specified"));
}
}
} // namespace td

View File

@ -308,10 +308,4 @@ void on_dialog_used(TopDialogCategory category, DialogId dialog_id, int32 date);
void update_used_hashtags(Td *td, const MessageContent *content);
void recognize_message_content_speech(Td *td, const MessageContent *content, MessageFullId message_full_id,
Promise<Unit> &&promise);
void rate_message_content_speech_recognition(Td *td, const MessageContent *content, MessageFullId message_full_id,
bool is_good, Promise<Unit> &&promise);
} // namespace td

View File

@ -18281,7 +18281,7 @@ void MessagesManager::on_get_message_viewers(DialogId dialog_id, MessageViewers
void MessagesManager::translate_message_text(MessageFullId message_full_id, const string &to_language_code,
Promise<td_api::object_ptr<td_api::formattedText>> &&promise) {
auto m = get_message_force(message_full_id, "recognize_speech");
auto m = get_message_force(message_full_id, "translate_message_text");
if (m == nullptr) {
return promise.set_error(Status::Error(400, "Message not found"));
}
@ -18297,29 +18297,6 @@ void MessagesManager::translate_message_text(MessageFullId message_full_id, cons
std::move(promise));
}
void MessagesManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
auto m = get_message_force(message_full_id, "recognize_speech");
if (m == nullptr) {
return promise.set_error(Status::Error(400, "Message not found"));
}
auto message_id = message_full_id.get_message_id();
if (message_id.is_scheduled() || !message_id.is_server()) {
return promise.set_error(Status::Error(400, "Message must be sent already"));
}
recognize_message_content_speech(td_, m->content.get(), message_full_id, std::move(promise));
}
void MessagesManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise) {
auto m = get_message_force(message_full_id, "rate_speech_recognition");
if (m == nullptr) {
return promise.set_error(Status::Error(400, "Message not found"));
}
rate_message_content_speech_recognition(td_, m->content.get(), message_full_id, is_good, std::move(promise));
}
void MessagesManager::get_dialog_info_full(DialogId dialog_id, Promise<Unit> &&promise, const char *source) {
switch (dialog_id.get_type()) {
case DialogType::User:

View File

@ -670,10 +670,6 @@ class MessagesManager final : public Actor {
void translate_message_text(MessageFullId message_full_id, const string &to_language_code,
Promise<td_api::object_ptr<td_api::formattedText>> &&promise);
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
bool is_message_edited_recently(MessageFullId message_full_id, int32 seconds);
bool is_deleted_secret_chat(DialogId dialog_id) const;

View File

@ -4880,14 +4880,15 @@ void Td::on_request(uint64 id, td_api::translateMessageText &request) {
void Td::on_request(uint64 id, const td_api::recognizeSpeech &request) {
CHECK_IS_USER();
CREATE_OK_REQUEST_PROMISE();
messages_manager_->recognize_speech({DialogId(request.chat_id_), MessageId(request.message_id_)}, std::move(promise));
transcription_manager_->recognize_speech({DialogId(request.chat_id_), MessageId(request.message_id_)},
std::move(promise));
}
void Td::on_request(uint64 id, const td_api::rateSpeechRecognition &request) {
CHECK_IS_USER();
CREATE_OK_REQUEST_PROMISE();
messages_manager_->rate_speech_recognition({DialogId(request.chat_id_), MessageId(request.message_id_)},
request.is_good_, std::move(promise));
transcription_manager_->rate_speech_recognition({DialogId(request.chat_id_), MessageId(request.message_id_)},
request.is_good_, std::move(promise));
}
void Td::on_request(uint64 id, const td_api::getFile &request) {

View File

@ -9,8 +9,11 @@
#include "td/telegram/AuthManager.h"
#include "td/telegram/Global.h"
#include "td/telegram/logevent/LogEvent.h"
#include "td/telegram/MessagesManager.h"
#include "td/telegram/Td.h"
#include "td/telegram/TdDb.h"
#include "td/telegram/VideoNotesManager.h"
#include "td/telegram/VoiceNotesManager.h"
namespace td {
@ -87,6 +90,10 @@ TranscriptionManager::TranscriptionManager(Td *td, ActorShared<> parent) : td_(t
pending_audio_transcription_timeout_.set_callback_data(static_cast<void *>(td_));
}
TranscriptionManager::~TranscriptionManager() {
Scheduler::instance()->destroy_on_scheduler(G()->get_gc_scheduler_id(), voice_messages_, message_file_ids_);
}
void TranscriptionManager::tear_down() {
parent_.reset();
}
@ -171,6 +178,146 @@ TranscriptionManager::TrialParameters::get_update_speech_recognition_trial_objec
cooldown_until_);
}
void TranscriptionManager::register_voice(FileId file_id, MessageContentType content_type,
MessageFullId message_full_id, const char *source) {
if (td_->auth_manager_->is_bot() || message_full_id.get_message_id().is_scheduled() ||
!message_full_id.get_message_id().is_server() ||
message_full_id.get_dialog_id().get_type() == DialogType::SecretChat) {
return;
}
LOG(INFO) << "Register voice " << file_id << " from " << message_full_id << " from " << source;
CHECK(file_id.is_valid());
bool is_inserted = voice_messages_[file_id].emplace(message_full_id).second;
LOG_CHECK(is_inserted) << source << ' ' << file_id << ' ' << message_full_id;
is_inserted = message_file_ids_.emplace(message_full_id, FileInfo(content_type, file_id)).second;
CHECK(is_inserted);
}
void TranscriptionManager::unregister_voice(FileId file_id, MessageContentType content_type,
MessageFullId message_full_id, const char *source) {
if (td_->auth_manager_->is_bot() || message_full_id.get_message_id().is_scheduled() ||
!message_full_id.get_message_id().is_server() ||
message_full_id.get_dialog_id().get_type() == DialogType::SecretChat) {
return;
}
LOG(INFO) << "Unregister voice " << file_id << " from " << message_full_id << " from " << source;
CHECK(file_id.is_valid());
auto &message_full_ids = voice_messages_[file_id];
auto is_deleted = message_full_ids.erase(message_full_id) > 0;
LOG_CHECK(is_deleted) << source << ' ' << file_id << ' ' << message_full_id;
if (message_full_ids.empty()) {
voice_messages_.erase(file_id);
}
is_deleted = message_file_ids_.erase(message_full_id) > 0;
CHECK(is_deleted);
}
TranscriptionInfo *TranscriptionManager::get_transcription_info(const FileInfo &file_info, bool allow_creation) {
switch (file_info.first) {
case MessageContentType::VideoNote:
return td_->video_notes_manager_->get_video_note_transcription_info(file_info.second, allow_creation);
case MessageContentType::VoiceNote:
return td_->voice_notes_manager_->get_voice_note_transcription_info(file_info.second, allow_creation);
default:
UNREACHABLE();
return nullptr;
}
}
void TranscriptionManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
if (!td_->messages_manager_->have_message_force(message_full_id, "recognize_speech")) {
return promise.set_error(Status::Error(400, "Message not found"));
}
auto it = message_file_ids_.find(message_full_id);
if (it == message_file_ids_.end()) {
return promise.set_error(Status::Error(400, "Message can't be transcribed"));
}
auto *transcription_info = get_transcription_info(it->second, true);
auto handler = [actor_id = actor_id(this), file_info = it->second](
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &TranscriptionManager::on_transcribed_audio_update, file_info, true, std::move(r_update));
};
if (transcription_info->recognize_speech(td_, message_full_id, std::move(promise), std::move(handler))) {
on_transcription_updated(it->second.second);
}
}
void TranscriptionManager::on_transcribed_audio_update(
FileInfo file_info, bool is_initial,
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
if (G()->close_flag()) {
return;
}
auto *transcription_info = get_transcription_info(file_info, false);
CHECK(transcription_info != nullptr);
if (r_update.is_error()) {
auto promises = transcription_info->on_failed_transcription(r_update.move_as_error());
on_transcription_updated(file_info.second);
set_promises(promises);
return;
}
auto update = r_update.move_as_ok();
auto transcription_id = update->transcription_id_;
if (!update->pending_) {
auto promises = transcription_info->on_final_transcription(std::move(update->text_), transcription_id);
on_transcription_completed(file_info.second);
set_promises(promises);
} else {
auto is_changed = transcription_info->on_partial_transcription(std::move(update->text_), transcription_id);
if (is_changed) {
on_transcription_updated(file_info.second);
}
if (is_initial) {
subscribe_to_transcribed_audio_updates(
transcription_id, [actor_id = actor_id(this), file_info](
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &TranscriptionManager::on_transcribed_audio_update, file_info, false,
std::move(r_update));
});
}
}
}
void TranscriptionManager::on_transcription_updated(FileId file_id) {
auto it = voice_messages_.find(file_id);
if (it != voice_messages_.end()) {
for (const auto &message_full_id : it->second) {
td_->messages_manager_->on_external_update_message_content(message_full_id);
}
}
}
void TranscriptionManager::on_transcription_completed(FileId file_id) {
auto it = voice_messages_.find(file_id);
if (it != voice_messages_.end()) {
for (const auto &message_full_id : it->second) {
td_->messages_manager_->on_update_message_content(message_full_id);
}
}
}
void TranscriptionManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good,
Promise<Unit> &&promise) {
if (!td_->messages_manager_->have_message_force(message_full_id, "recognize_speech")) {
return promise.set_error(Status::Error(400, "Message not found"));
}
auto it = message_file_ids_.find(message_full_id);
if (it == message_file_ids_.end()) {
return promise.set_error(Status::Error(400, "Message can't be transcribed"));
}
const auto *transcription_info = get_transcription_info(it->second, false);
if (transcription_info == nullptr) {
return promise.set_value(Unit());
}
transcription_info->rate_speech_recognition(td_, message_full_id, is_good, std::move(promise));
}
void TranscriptionManager::subscribe_to_transcribed_audio_updates(int64 transcription_id,
TranscribedAudioHandler on_update) {
CHECK(transcription_id != 0);

View File

@ -6,17 +6,24 @@
//
#pragma once
#include "td/telegram/files/FileId.h"
#include "td/telegram/MessageContentType.h"
#include "td/telegram/MessageFullId.h"
#include "td/telegram/td_api.h"
#include "td/telegram/telegram_api.h"
#include "td/telegram/TranscriptionInfo.h"
#include "td/actor/actor.h"
#include "td/actor/MultiTimeout.h"
#include "td/utils/common.h"
#include "td/utils/FlatHashMap.h"
#include "td/utils/FlatHashSet.h"
#include "td/utils/Promise.h"
#include "td/utils/Status.h"
#include <functional>
#include <utility>
namespace td {
@ -25,12 +32,25 @@ class Td;
class TranscriptionManager final : public Actor {
public:
TranscriptionManager(Td *td, ActorShared<> parent);
TranscriptionManager(const TranscriptionManager &) = delete;
TranscriptionManager &operator=(const TranscriptionManager &) = delete;
TranscriptionManager(TranscriptionManager &&) = delete;
TranscriptionManager &operator=(TranscriptionManager &&) = delete;
~TranscriptionManager() final;
void on_update_trial_parameters(int32 weekly_number, int32 duration_max, int32 cooldown_until);
using TranscribedAudioHandler =
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
void register_voice(FileId file_id, MessageContentType content_type, MessageFullId message_full_id,
const char *source);
void unregister_voice(FileId file_id, MessageContentType content_type, MessageFullId message_full_id,
const char *source);
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
void on_transcription_completed(FileId file_id);
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
void on_update_transcribed_audio(telegram_api::object_ptr<telegram_api::updateTranscribedAudio> &&update);
@ -53,6 +73,19 @@ class TranscriptionManager final : public Actor {
td_api::object_ptr<td_api::updateSpeechRecognitionTrial> get_update_speech_recognition_trial_object() const;
using FileInfo = std::pair<MessageContentType, FileId>;
TranscriptionInfo *get_transcription_info(const FileInfo &file_info, bool allow_creation);
using TranscribedAudioHandler =
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
void on_transcribed_audio_update(FileInfo file_info, bool is_initial,
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
void on_transcription_updated(FileId file_id);
void on_pending_audio_transcription_failed(int64 transcription_id, Status &&error);
struct TrialParameters {
@ -81,6 +114,9 @@ class TranscriptionManager final : public Actor {
FlatHashMap<int64, TranscribedAudioHandler> pending_audio_transcriptions_;
MultiTimeout pending_audio_transcription_timeout_{"PendingAudioTranscriptionTimeout"};
FlatHashMap<FileId, FlatHashSet<MessageFullId, MessageFullIdHash>, FileIdHash> voice_messages_;
FlatHashMap<MessageFullId, FileInfo, MessageFullIdHash> message_file_ids_;
};
} // namespace td

View File

@ -18,8 +18,6 @@
#include "td/telegram/telegram_api.h"
#include "td/telegram/TranscriptionManager.h"
#include "td/actor/actor.h"
#include "td/utils/buffer.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
@ -44,6 +42,15 @@ int32 VideoNotesManager::get_video_note_duration(FileId file_id) const {
return video_note->duration;
}
TranscriptionInfo *VideoNotesManager::get_video_note_transcription_info(FileId file_id, bool allow_creation) {
auto video_note = get_video_note(file_id);
CHECK(video_note != nullptr);
if (video_note->transcription_info == nullptr && allow_creation) {
video_note->transcription_info = make_unique<TranscriptionInfo>();
}
return video_note->transcription_info.get();
}
tl_object_ptr<td_api::videoNote> VideoNotesManager::get_video_note_object(FileId file_id) const {
if (!file_id.is_valid()) {
return nullptr;
@ -89,7 +96,7 @@ FileId VideoNotesManager::on_get_video_note(unique_ptr<VideoNote> new_video_note
v->thumbnail = std::move(new_video_note->thumbnail);
}
if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_video_note->transcription_info))) {
on_video_note_transcription_completed(file_id);
td_->transcription_manager_->on_transcription_completed(file_id);
}
}
return file_id;
@ -170,129 +177,6 @@ void VideoNotesManager::create_video_note(FileId file_id, string minithumbnail,
on_get_video_note(std::move(v), replace);
}
void VideoNotesManager::register_video_note(FileId video_note_file_id, MessageFullId message_full_id,
const char *source) {
if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server() ||
td_->auth_manager_->is_bot()) {
return;
}
LOG(INFO) << "Register video note " << video_note_file_id << " from " << message_full_id << " from " << source;
CHECK(video_note_file_id.is_valid());
bool is_inserted = video_note_messages_[video_note_file_id].insert(message_full_id).second;
LOG_CHECK(is_inserted) << source << ' ' << video_note_file_id << ' ' << message_full_id;
is_inserted = message_video_notes_.emplace(message_full_id, video_note_file_id).second;
CHECK(is_inserted);
}
void VideoNotesManager::unregister_video_note(FileId video_note_file_id, MessageFullId message_full_id,
const char *source) {
if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server() ||
td_->auth_manager_->is_bot()) {
return;
}
LOG(INFO) << "Unregister video note " << video_note_file_id << " from " << message_full_id << " from " << source;
CHECK(video_note_file_id.is_valid());
auto &message_ids = video_note_messages_[video_note_file_id];
auto is_deleted = message_ids.erase(message_full_id) > 0;
LOG_CHECK(is_deleted) << source << ' ' << video_note_file_id << ' ' << message_full_id;
if (message_ids.empty()) {
video_note_messages_.erase(video_note_file_id);
}
is_deleted = message_video_notes_.erase(message_full_id) > 0;
CHECK(is_deleted);
}
void VideoNotesManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
auto it = message_video_notes_.find(message_full_id);
CHECK(it != message_video_notes_.end());
auto file_id = it->second;
auto video_note = get_video_note(file_id);
CHECK(video_note != nullptr);
if (video_note->transcription_info == nullptr) {
video_note->transcription_info = make_unique<TranscriptionInfo>();
}
auto handler = [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VideoNotesManager::on_transcribed_audio_update, file_id, true, std::move(r_update));
};
if (video_note->transcription_info->recognize_speech(td_, message_full_id, std::move(promise), std::move(handler))) {
on_video_note_transcription_updated(file_id);
}
}
void VideoNotesManager::on_transcribed_audio_update(
FileId file_id, bool is_initial, Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
if (G()->close_flag()) {
return;
}
auto video_note = get_video_note(file_id);
CHECK(video_note != nullptr);
CHECK(video_note->transcription_info != nullptr);
if (r_update.is_error()) {
auto promises = video_note->transcription_info->on_failed_transcription(r_update.move_as_error());
on_video_note_transcription_updated(file_id);
set_promises(promises);
return;
}
auto update = r_update.move_as_ok();
auto transcription_id = update->transcription_id_;
if (!update->pending_) {
auto promises = video_note->transcription_info->on_final_transcription(std::move(update->text_), transcription_id);
on_video_note_transcription_completed(file_id);
set_promises(promises);
} else {
auto is_changed =
video_note->transcription_info->on_partial_transcription(std::move(update->text_), transcription_id);
if (is_changed) {
on_video_note_transcription_updated(file_id);
}
if (is_initial) {
td_->transcription_manager_->subscribe_to_transcribed_audio_updates(
transcription_id, [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VideoNotesManager::on_transcribed_audio_update, file_id, false,
std::move(r_update));
});
}
}
}
void VideoNotesManager::on_video_note_transcription_updated(FileId file_id) {
auto it = video_note_messages_.find(file_id);
if (it != video_note_messages_.end()) {
for (const auto &message_full_id : it->second) {
td_->messages_manager_->on_external_update_message_content(message_full_id);
}
}
}
void VideoNotesManager::on_video_note_transcription_completed(FileId file_id) {
auto it = video_note_messages_.find(file_id);
if (it != video_note_messages_.end()) {
for (const auto &message_full_id : it->second) {
td_->messages_manager_->on_update_message_content(message_full_id);
}
}
}
void VideoNotesManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise) {
auto it = message_video_notes_.find(message_full_id);
CHECK(it != message_video_notes_.end());
auto file_id = it->second;
auto video_note = get_video_note(file_id);
CHECK(video_note != nullptr);
if (video_note->transcription_info == nullptr) {
return promise.set_value(Unit());
}
video_note->transcription_info->rate_speech_recognition(td_, message_full_id, is_good, std::move(promise));
}
SecretInputMedia VideoNotesManager::get_secret_input_media(FileId video_note_file_id,
tl_object_ptr<telegram_api::InputEncryptedFile> input_file,
BufferSlice thumbnail, int32 layer) const {

View File

@ -40,19 +40,13 @@ class VideoNotesManager final : public Actor {
int32 get_video_note_duration(FileId file_id) const;
TranscriptionInfo *get_video_note_transcription_info(FileId file_id, bool allow_creation);
tl_object_ptr<td_api::videoNote> get_video_note_object(FileId file_id) const;
void create_video_note(FileId file_id, string minithumbnail, PhotoSize thumbnail, int32 duration,
Dimensions dimensions, string waveform, bool replace);
void register_video_note(FileId video_note_file_id, MessageFullId message_full_id, const char *source);
void unregister_video_note(FileId video_note_file_id, MessageFullId message_full_id, const char *source);
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id,
tl_object_ptr<telegram_api::InputFile> input_file,
tl_object_ptr<telegram_api::InputFile> input_thumbnail) const;
@ -94,22 +88,12 @@ class VideoNotesManager final : public Actor {
FileId on_get_video_note(unique_ptr<VideoNote> new_video_note, bool replace);
void on_video_note_transcription_updated(FileId file_id);
void on_video_note_transcription_completed(FileId file_id);
void on_transcribed_audio_update(FileId file_id, bool is_initial,
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
void tear_down() final;
Td *td_;
ActorShared<> parent_;
WaitFreeHashMap<FileId, unique_ptr<VideoNote>, FileIdHash> video_notes_;
FlatHashMap<FileId, FlatHashSet<MessageFullId, MessageFullIdHash>, FileIdHash> video_note_messages_;
FlatHashMap<MessageFullId, FileId, MessageFullIdHash> message_video_notes_;
};
} // namespace td

View File

@ -26,8 +26,7 @@ VoiceNotesManager::VoiceNotesManager(Td *td, ActorShared<> parent) : td_(td), pa
}
VoiceNotesManager::~VoiceNotesManager() {
Scheduler::instance()->destroy_on_scheduler(G()->get_gc_scheduler_id(), voice_notes_, voice_note_messages_,
message_voice_notes_);
Scheduler::instance()->destroy_on_scheduler(G()->get_gc_scheduler_id(), voice_notes_);
}
void VoiceNotesManager::tear_down() {
@ -42,6 +41,15 @@ int32 VoiceNotesManager::get_voice_note_duration(FileId file_id) const {
return voice_note->duration;
}
TranscriptionInfo *VoiceNotesManager::get_voice_note_transcription_info(FileId file_id, bool allow_creation) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (voice_note->transcription_info == nullptr && allow_creation) {
voice_note->transcription_info = make_unique<TranscriptionInfo>();
}
return voice_note->transcription_info.get();
}
tl_object_ptr<td_api::voiceNote> VoiceNotesManager::get_voice_note_object(FileId file_id) const {
if (!file_id.is_valid()) {
return nullptr;
@ -76,7 +84,7 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr<VoiceNote> new_voice_note
v->waveform = std::move(new_voice_note->waveform);
}
if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_voice_note->transcription_info))) {
on_voice_note_transcription_completed(file_id);
td_->transcription_manager_->on_transcription_completed(file_id);
}
}
@ -134,129 +142,6 @@ void VoiceNotesManager::create_voice_note(FileId file_id, string mime_type, int3
on_get_voice_note(std::move(v), replace);
}
void VoiceNotesManager::register_voice_note(FileId voice_note_file_id, MessageFullId message_full_id,
const char *source) {
if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server() ||
td_->auth_manager_->is_bot()) {
return;
}
LOG(INFO) << "Register voice note " << voice_note_file_id << " from " << message_full_id << " from " << source;
CHECK(voice_note_file_id.is_valid());
bool is_inserted = voice_note_messages_[voice_note_file_id].insert(message_full_id).second;
LOG_CHECK(is_inserted) << source << ' ' << voice_note_file_id << ' ' << message_full_id;
is_inserted = message_voice_notes_.emplace(message_full_id, voice_note_file_id).second;
CHECK(is_inserted);
}
void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, MessageFullId message_full_id,
const char *source) {
if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server() ||
td_->auth_manager_->is_bot()) {
return;
}
LOG(INFO) << "Unregister voice note " << voice_note_file_id << " from " << message_full_id << " from " << source;
CHECK(voice_note_file_id.is_valid());
auto &message_ids = voice_note_messages_[voice_note_file_id];
auto is_deleted = message_ids.erase(message_full_id) > 0;
LOG_CHECK(is_deleted) << source << ' ' << voice_note_file_id << ' ' << message_full_id;
if (message_ids.empty()) {
voice_note_messages_.erase(voice_note_file_id);
}
is_deleted = message_voice_notes_.erase(message_full_id) > 0;
CHECK(is_deleted);
}
void VoiceNotesManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
auto it = message_voice_notes_.find(message_full_id);
CHECK(it != message_voice_notes_.end());
auto file_id = it->second;
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (voice_note->transcription_info == nullptr) {
voice_note->transcription_info = make_unique<TranscriptionInfo>();
}
auto handler = [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, true, std::move(r_update));
};
if (voice_note->transcription_info->recognize_speech(td_, message_full_id, std::move(promise), std::move(handler))) {
on_voice_note_transcription_updated(file_id);
}
}
void VoiceNotesManager::on_transcribed_audio_update(
FileId file_id, bool is_initial, Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
if (G()->close_flag()) {
return;
}
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(voice_note->transcription_info != nullptr);
if (r_update.is_error()) {
auto promises = voice_note->transcription_info->on_failed_transcription(r_update.move_as_error());
on_voice_note_transcription_updated(file_id);
set_promises(promises);
return;
}
auto update = r_update.move_as_ok();
auto transcription_id = update->transcription_id_;
if (!update->pending_) {
auto promises = voice_note->transcription_info->on_final_transcription(std::move(update->text_), transcription_id);
on_voice_note_transcription_completed(file_id);
set_promises(promises);
} else {
auto is_changed =
voice_note->transcription_info->on_partial_transcription(std::move(update->text_), transcription_id);
if (is_changed) {
on_voice_note_transcription_updated(file_id);
}
if (is_initial) {
td_->transcription_manager_->subscribe_to_transcribed_audio_updates(
transcription_id, [actor_id = actor_id(this),
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
send_closure(actor_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, false,
std::move(r_update));
});
}
}
}
void VoiceNotesManager::on_voice_note_transcription_updated(FileId file_id) {
auto it = voice_note_messages_.find(file_id);
if (it != voice_note_messages_.end()) {
for (const auto &message_full_id : it->second) {
td_->messages_manager_->on_external_update_message_content(message_full_id);
}
}
}
void VoiceNotesManager::on_voice_note_transcription_completed(FileId file_id) {
auto it = voice_note_messages_.find(file_id);
if (it != voice_note_messages_.end()) {
for (const auto &message_full_id : it->second) {
td_->messages_manager_->on_update_message_content(message_full_id);
}
}
}
void VoiceNotesManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise) {
auto it = message_voice_notes_.find(message_full_id);
CHECK(it != message_voice_notes_.end());
auto file_id = it->second;
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (voice_note->transcription_info == nullptr) {
return promise.set_value(Unit());
}
voice_note->transcription_info->rate_speech_recognition(td_, message_full_id, is_good, std::move(promise));
}
SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,
tl_object_ptr<telegram_api::InputEncryptedFile> input_file,
const string &caption, int32 layer) const {

View File

@ -37,18 +37,12 @@ class VoiceNotesManager final : public Actor {
int32 get_voice_note_duration(FileId file_id) const;
TranscriptionInfo *get_voice_note_transcription_info(FileId file_id, bool allow_creation);
tl_object_ptr<td_api::voiceNote> get_voice_note_object(FileId file_id) const;
void create_voice_note(FileId file_id, string mime_type, int32 duration, string waveform, bool replace);
void register_voice_note(FileId voice_note_file_id, MessageFullId message_full_id, const char *source);
void unregister_voice_note(FileId voice_note_file_id, MessageFullId message_full_id, const char *source);
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id,
tl_object_ptr<telegram_api::InputFile> input_file) const;
@ -83,22 +77,12 @@ class VoiceNotesManager final : public Actor {
FileId on_get_voice_note(unique_ptr<VoiceNote> new_voice_note, bool replace);
void on_voice_note_transcription_updated(FileId file_id);
void on_voice_note_transcription_completed(FileId file_id);
void on_transcribed_audio_update(FileId file_id, bool is_initial,
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
void tear_down() final;
Td *td_;
ActorShared<> parent_;
WaitFreeHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
FlatHashMap<FileId, FlatHashSet<MessageFullId, MessageFullIdHash>, FileIdHash> voice_note_messages_;
FlatHashMap<MessageFullId, FileId, MessageFullIdHash> message_voice_notes_;
};
} // namespace td