Move speech recognition to TranscriptionManager.
This commit is contained in:
parent
ab39c96b2c
commit
ec109dfd4a
@ -7059,7 +7059,7 @@ translateText text:formattedText to_language_code:string = FormattedText;
|
||||
//-"st", "sn", "sd", "si", "sk", "sl", "so", "es", "su", "sw", "sv", "tl", "tg", "ta", "tt", "te", "th", "tr", "tk", "uk", "ur", "ug", "uz", "vi", "cy", "xh", "yi", "ji", "yo", "zu"
|
||||
translateMessageText chat_id:int53 message_id:int53 to_language_code:string = FormattedText;
|
||||
|
||||
//@description Recognizes speech in a video note or a voice note message. The message must be successfully sent and must not be scheduled
|
||||
//@description Recognizes speech in a video note or a voice note message. The message must be successfully sent, must not be scheduled, and must be from a non-secret chat
|
||||
//@chat_id Identifier of the chat to which the message belongs
|
||||
//@message_id Identifier of the message
|
||||
recognizeSpeech chat_id:int53 message_id:int53 = Ok;
|
||||
|
@ -79,6 +79,7 @@
|
||||
#include "td/telegram/Td.h"
|
||||
#include "td/telegram/telegram_api.h"
|
||||
#include "td/telegram/TopDialogManager.h"
|
||||
#include "td/telegram/TranscriptionManager.h"
|
||||
#include "td/telegram/UserId.h"
|
||||
#include "td/telegram/Venue.h"
|
||||
#include "td/telegram/Version.h"
|
||||
@ -4885,7 +4886,8 @@ static CustomEmojiId get_custom_emoji_id(const FormattedText &text) {
|
||||
|
||||
void register_message_content(Td *td, const MessageContent *content, MessageFullId message_full_id,
|
||||
const char *source) {
|
||||
switch (content->get_type()) {
|
||||
auto content_type = content->get_type();
|
||||
switch (content_type) {
|
||||
case MessageContentType::Text: {
|
||||
auto text = static_cast<const MessageText *>(content);
|
||||
if (text->web_page_id.is_valid()) {
|
||||
@ -4897,11 +4899,11 @@ void register_message_content(Td *td, const MessageContent *content, MessageFull
|
||||
return;
|
||||
}
|
||||
case MessageContentType::VideoNote:
|
||||
return td->video_notes_manager_->register_video_note(static_cast<const MessageVideoNote *>(content)->file_id,
|
||||
message_full_id, source);
|
||||
return td->transcription_manager_->register_voice(static_cast<const MessageVideoNote *>(content)->file_id,
|
||||
content_type, message_full_id, source);
|
||||
case MessageContentType::VoiceNote:
|
||||
return td->voice_notes_manager_->register_voice_note(static_cast<const MessageVoiceNote *>(content)->file_id,
|
||||
message_full_id, source);
|
||||
return td->transcription_manager_->register_voice(static_cast<const MessageVoiceNote *>(content)->file_id,
|
||||
content_type, message_full_id, source);
|
||||
case MessageContentType::Poll:
|
||||
return td->poll_manager_->register_poll(static_cast<const MessagePoll *>(content)->poll_id, message_full_id,
|
||||
source);
|
||||
@ -5005,7 +5007,8 @@ void reregister_message_content(Td *td, const MessageContent *old_content, const
|
||||
|
||||
void unregister_message_content(Td *td, const MessageContent *content, MessageFullId message_full_id,
|
||||
const char *source) {
|
||||
switch (content->get_type()) {
|
||||
auto content_type = content->get_type();
|
||||
switch (content_type) {
|
||||
case MessageContentType::Text: {
|
||||
auto text = static_cast<const MessageText *>(content);
|
||||
if (text->web_page_id.is_valid()) {
|
||||
@ -5017,11 +5020,11 @@ void unregister_message_content(Td *td, const MessageContent *content, MessageFu
|
||||
return;
|
||||
}
|
||||
case MessageContentType::VideoNote:
|
||||
return td->video_notes_manager_->unregister_video_note(static_cast<const MessageVideoNote *>(content)->file_id,
|
||||
message_full_id, source);
|
||||
return td->transcription_manager_->unregister_voice(static_cast<const MessageVideoNote *>(content)->file_id,
|
||||
content_type, message_full_id, source);
|
||||
case MessageContentType::VoiceNote:
|
||||
return td->voice_notes_manager_->unregister_voice_note(static_cast<const MessageVoiceNote *>(content)->file_id,
|
||||
message_full_id, source);
|
||||
return td->transcription_manager_->unregister_voice(static_cast<const MessageVoiceNote *>(content)->file_id,
|
||||
content_type, message_full_id, source);
|
||||
case MessageContentType::Poll:
|
||||
return td->poll_manager_->unregister_poll(static_cast<const MessagePoll *>(content)->poll_id, message_full_id,
|
||||
source);
|
||||
@ -7588,28 +7591,4 @@ void update_used_hashtags(Td *td, const MessageContent *content) {
|
||||
}
|
||||
}
|
||||
|
||||
void recognize_message_content_speech(Td *td, const MessageContent *content, MessageFullId message_full_id,
|
||||
Promise<Unit> &&promise) {
|
||||
switch (content->get_type()) {
|
||||
case MessageContentType::VideoNote:
|
||||
return td->video_notes_manager_->recognize_speech(message_full_id, std::move(promise));
|
||||
case MessageContentType::VoiceNote:
|
||||
return td->voice_notes_manager_->recognize_speech(message_full_id, std::move(promise));
|
||||
default:
|
||||
return promise.set_error(Status::Error(400, "Invalid message specified"));
|
||||
}
|
||||
}
|
||||
|
||||
void rate_message_content_speech_recognition(Td *td, const MessageContent *content, MessageFullId message_full_id,
|
||||
bool is_good, Promise<Unit> &&promise) {
|
||||
switch (content->get_type()) {
|
||||
case MessageContentType::VideoNote:
|
||||
return td->video_notes_manager_->rate_speech_recognition(message_full_id, is_good, std::move(promise));
|
||||
case MessageContentType::VoiceNote:
|
||||
return td->voice_notes_manager_->rate_speech_recognition(message_full_id, is_good, std::move(promise));
|
||||
default:
|
||||
return promise.set_error(Status::Error(400, "Invalid message specified"));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace td
|
||||
|
@ -308,10 +308,4 @@ void on_dialog_used(TopDialogCategory category, DialogId dialog_id, int32 date);
|
||||
|
||||
void update_used_hashtags(Td *td, const MessageContent *content);
|
||||
|
||||
void recognize_message_content_speech(Td *td, const MessageContent *content, MessageFullId message_full_id,
|
||||
Promise<Unit> &&promise);
|
||||
|
||||
void rate_message_content_speech_recognition(Td *td, const MessageContent *content, MessageFullId message_full_id,
|
||||
bool is_good, Promise<Unit> &&promise);
|
||||
|
||||
} // namespace td
|
||||
|
@ -18281,7 +18281,7 @@ void MessagesManager::on_get_message_viewers(DialogId dialog_id, MessageViewers
|
||||
|
||||
void MessagesManager::translate_message_text(MessageFullId message_full_id, const string &to_language_code,
|
||||
Promise<td_api::object_ptr<td_api::formattedText>> &&promise) {
|
||||
auto m = get_message_force(message_full_id, "recognize_speech");
|
||||
auto m = get_message_force(message_full_id, "translate_message_text");
|
||||
if (m == nullptr) {
|
||||
return promise.set_error(Status::Error(400, "Message not found"));
|
||||
}
|
||||
@ -18297,29 +18297,6 @@ void MessagesManager::translate_message_text(MessageFullId message_full_id, cons
|
||||
std::move(promise));
|
||||
}
|
||||
|
||||
void MessagesManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
|
||||
auto m = get_message_force(message_full_id, "recognize_speech");
|
||||
if (m == nullptr) {
|
||||
return promise.set_error(Status::Error(400, "Message not found"));
|
||||
}
|
||||
|
||||
auto message_id = message_full_id.get_message_id();
|
||||
if (message_id.is_scheduled() || !message_id.is_server()) {
|
||||
return promise.set_error(Status::Error(400, "Message must be sent already"));
|
||||
}
|
||||
|
||||
recognize_message_content_speech(td_, m->content.get(), message_full_id, std::move(promise));
|
||||
}
|
||||
|
||||
void MessagesManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise) {
|
||||
auto m = get_message_force(message_full_id, "rate_speech_recognition");
|
||||
if (m == nullptr) {
|
||||
return promise.set_error(Status::Error(400, "Message not found"));
|
||||
}
|
||||
|
||||
rate_message_content_speech_recognition(td_, m->content.get(), message_full_id, is_good, std::move(promise));
|
||||
}
|
||||
|
||||
void MessagesManager::get_dialog_info_full(DialogId dialog_id, Promise<Unit> &&promise, const char *source) {
|
||||
switch (dialog_id.get_type()) {
|
||||
case DialogType::User:
|
||||
|
@ -670,10 +670,6 @@ class MessagesManager final : public Actor {
|
||||
void translate_message_text(MessageFullId message_full_id, const string &to_language_code,
|
||||
Promise<td_api::object_ptr<td_api::formattedText>> &&promise);
|
||||
|
||||
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
|
||||
|
||||
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
|
||||
|
||||
bool is_message_edited_recently(MessageFullId message_full_id, int32 seconds);
|
||||
|
||||
bool is_deleted_secret_chat(DialogId dialog_id) const;
|
||||
|
@ -4880,14 +4880,15 @@ void Td::on_request(uint64 id, td_api::translateMessageText &request) {
|
||||
void Td::on_request(uint64 id, const td_api::recognizeSpeech &request) {
|
||||
CHECK_IS_USER();
|
||||
CREATE_OK_REQUEST_PROMISE();
|
||||
messages_manager_->recognize_speech({DialogId(request.chat_id_), MessageId(request.message_id_)}, std::move(promise));
|
||||
transcription_manager_->recognize_speech({DialogId(request.chat_id_), MessageId(request.message_id_)},
|
||||
std::move(promise));
|
||||
}
|
||||
|
||||
void Td::on_request(uint64 id, const td_api::rateSpeechRecognition &request) {
|
||||
CHECK_IS_USER();
|
||||
CREATE_OK_REQUEST_PROMISE();
|
||||
messages_manager_->rate_speech_recognition({DialogId(request.chat_id_), MessageId(request.message_id_)},
|
||||
request.is_good_, std::move(promise));
|
||||
transcription_manager_->rate_speech_recognition({DialogId(request.chat_id_), MessageId(request.message_id_)},
|
||||
request.is_good_, std::move(promise));
|
||||
}
|
||||
|
||||
void Td::on_request(uint64 id, const td_api::getFile &request) {
|
||||
|
@ -9,8 +9,11 @@
|
||||
#include "td/telegram/AuthManager.h"
|
||||
#include "td/telegram/Global.h"
|
||||
#include "td/telegram/logevent/LogEvent.h"
|
||||
#include "td/telegram/MessagesManager.h"
|
||||
#include "td/telegram/Td.h"
|
||||
#include "td/telegram/TdDb.h"
|
||||
#include "td/telegram/VideoNotesManager.h"
|
||||
#include "td/telegram/VoiceNotesManager.h"
|
||||
|
||||
namespace td {
|
||||
|
||||
@ -87,6 +90,10 @@ TranscriptionManager::TranscriptionManager(Td *td, ActorShared<> parent) : td_(t
|
||||
pending_audio_transcription_timeout_.set_callback_data(static_cast<void *>(td_));
|
||||
}
|
||||
|
||||
// Passes both message <-> file index maps to Scheduler::destroy_on_scheduler, targeting the scheduler
// identified by G()->get_gc_scheduler_id(), instead of destroying them inline here.
TranscriptionManager::~TranscriptionManager() {
  const auto gc_scheduler_id = G()->get_gc_scheduler_id();
  Scheduler::instance()->destroy_on_scheduler(gc_scheduler_id, voice_messages_, message_file_ids_);
}
|
||||
|
||||
// Resets the ActorShared<> reference to the parent that was passed to the constructor.
void TranscriptionManager::tear_down() {
|
||||
parent_.reset();
|
||||
}
|
||||
@ -171,6 +178,146 @@ TranscriptionManager::TrialParameters::get_update_speech_recognition_trial_objec
|
||||
cooldown_until_);
|
||||
}
|
||||
|
||||
// Records that the message message_full_id contains the voice or video note file_id.
// The pair is stored in both directions: file -> set of messages, and message -> (content type, file).
// Nothing is recorded for bots, for scheduled or non-server messages, and for messages in secret chats.
// Registering the same message a second time fails a CHECK.
void TranscriptionManager::register_voice(FileId file_id, MessageContentType content_type,
                                          MessageFullId message_full_id, const char *source) {
  if (td_->auth_manager_->is_bot()) {
    return;
  }
  if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server()) {
    return;
  }
  if (message_full_id.get_dialog_id().get_type() == DialogType::SecretChat) {
    return;
  }

  LOG(INFO) << "Register voice " << file_id << " from " << message_full_id << " from " << source;
  CHECK(file_id.is_valid());

  // file -> messages direction
  auto &file_messages = voice_messages_[file_id];
  bool is_inserted = file_messages.emplace(message_full_id).second;
  LOG_CHECK(is_inserted) << source << ' ' << file_id << ' ' << message_full_id;

  // message -> file direction
  auto insert_result = message_file_ids_.emplace(message_full_id, FileInfo(content_type, file_id));
  is_inserted = insert_result.second;
  CHECK(is_inserted);
}
|
||||
|
||||
// Removes the association between the message message_full_id and the voice or video note file_id
// from both index maps. The same kinds of messages that register_voice ignores are ignored here.
// Unregistering a message that is not currently registered fails a CHECK.
// content_type is accepted for symmetry with register_voice but is not read.
void TranscriptionManager::unregister_voice(FileId file_id, MessageContentType content_type,
                                            MessageFullId message_full_id, const char *source) {
  if (td_->auth_manager_->is_bot()) {
    return;
  }
  if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server()) {
    return;
  }
  if (message_full_id.get_dialog_id().get_type() == DialogType::SecretChat) {
    return;
  }

  LOG(INFO) << "Unregister voice " << file_id << " from " << message_full_id << " from " << source;
  CHECK(file_id.is_valid());

  // file -> messages direction; drop the whole entry once its last message is gone
  auto &file_messages = voice_messages_[file_id];
  auto is_deleted = file_messages.erase(message_full_id) > 0;
  LOG_CHECK(is_deleted) << source << ' ' << file_id << ' ' << message_full_id;
  if (file_messages.empty()) {
    voice_messages_.erase(file_id);
  }

  // message -> file direction
  is_deleted = message_file_ids_.erase(message_full_id) > 0;
  CHECK(is_deleted);
}
|
||||
|
||||
// Returns the TranscriptionInfo of the file described by file_info, delegating to the manager
// that corresponds to the content type (video note or voice note).
// allow_creation is forwarded unchanged to that manager.
// Any other content type is a programming error.
TranscriptionInfo *TranscriptionManager::get_transcription_info(const FileInfo &file_info, bool allow_creation) {
  const auto content_type = file_info.first;
  const auto file_id = file_info.second;
  if (content_type == MessageContentType::VideoNote) {
    return td_->video_notes_manager_->get_video_note_transcription_info(file_id, allow_creation);
  }
  if (content_type == MessageContentType::VoiceNote) {
    return td_->voice_notes_manager_->get_voice_note_transcription_info(file_id, allow_creation);
  }
  UNREACHABLE();
  return nullptr;
}
|
||||
|
||||
// Starts speech recognition for the voice or video note contained in the given message.
// Fails with error 400 if the message can't be found or was never registered through register_voice.
// The promise is handed over to TranscriptionInfo::recognize_speech; results arrive asynchronously
// through on_transcribed_audio_update.
void TranscriptionManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
  if (!td_->messages_manager_->have_message_force(message_full_id, "recognize_speech")) {
    return promise.set_error(Status::Error(400, "Message not found"));
  }

  auto file_it = message_file_ids_.find(message_full_id);
  if (file_it == message_file_ids_.end()) {
    return promise.set_error(Status::Error(400, "Message can't be transcribed"));
  }
  const FileInfo &registered_file = file_it->second;

  // create the transcription info on demand, because this is the point where recognition begins
  auto *transcription_info = get_transcription_info(registered_file, true);

  // the callback keeps its own copy of the file description and reports back to this actor
  auto on_first_update = [actor_id = actor_id(this), file_info = registered_file](
                             Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
    send_closure(actor_id, &TranscriptionManager::on_transcribed_audio_update, file_info, true, std::move(r_update));
  };

  const bool is_changed =
      transcription_info->recognize_speech(td_, message_full_id, std::move(promise), std::move(on_first_update));
  if (is_changed) {
    on_transcription_updated(registered_file.second);
  }
}
|
||||
|
||||
// Handles the result of a transcription request or a subsequent transcription update for file_info.
//   - error: marks the transcription as failed, notifies the affected messages and completes the promises;
//   - pending update: stores the partial text and, for the first update only (is_initial),
//     subscribes to further updates with the same transcription identifier;
//   - final update: stores the final text, notifies the affected messages and completes the promises.
// Does nothing while the client is closing.
void TranscriptionManager::on_transcribed_audio_update(
    FileInfo file_info, bool is_initial,
    Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
  if (G()->close_flag()) {
    return;
  }

  // the info must already exist, so creation is not allowed here
  auto *transcription_info = get_transcription_info(file_info, false);
  CHECK(transcription_info != nullptr);

  if (r_update.is_error()) {
    auto promises = transcription_info->on_failed_transcription(r_update.move_as_error());
    on_transcription_updated(file_info.second);
    set_promises(promises);
    return;
  }

  auto update = r_update.move_as_ok();
  auto transcription_id = update->transcription_id_;

  if (update->pending_) {
    if (transcription_info->on_partial_transcription(std::move(update->text_), transcription_id)) {
      on_transcription_updated(file_info.second);
    }
    if (is_initial) {
      // follow-up updates are delivered back to this method with is_initial == false
      subscribe_to_transcribed_audio_updates(
          transcription_id,
          [actor_id = actor_id(this),
           file_info](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_next_update) {
            send_closure(actor_id, &TranscriptionManager::on_transcribed_audio_update, file_info, false,
                         std::move(r_next_update));
          });
    }
    return;
  }

  auto promises = transcription_info->on_final_transcription(std::move(update->text_), transcription_id);
  on_transcription_completed(file_info.second);
  set_promises(promises);
}
|
||||
|
||||
// Calls MessagesManager::on_external_update_message_content for every message registered for file_id.
// Does nothing if no message is registered for the file.
void TranscriptionManager::on_transcription_updated(FileId file_id) {
  auto messages_it = voice_messages_.find(file_id);
  if (messages_it == voice_messages_.end()) {
    return;
  }
  auto *messages_manager = td_->messages_manager_.get();
  for (const auto &message_full_id : messages_it->second) {
    messages_manager->on_external_update_message_content(message_full_id);
  }
}
|
||||
|
||||
// Calls MessagesManager::on_update_message_content for every message registered for file_id.
// Does nothing if no message is registered for the file.
void TranscriptionManager::on_transcription_completed(FileId file_id) {
  auto messages_it = voice_messages_.find(file_id);
  if (messages_it == voice_messages_.end()) {
    return;
  }
  auto *messages_manager = td_->messages_manager_.get();
  for (const auto &message_full_id : messages_it->second) {
    messages_manager->on_update_message_content(message_full_id);
  }
}
|
||||
|
||||
// Forwards the user's rating (is_good) of the speech recognition result for the voice or video note
// contained in the given message.
// Fails with error 400 if the message can't be found or was never registered through register_voice.
// Succeeds immediately, without doing anything, if the note has no transcription information.
void TranscriptionManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good,
                                                   Promise<Unit> &&promise) {
  // pass this method's own name as the source tag; it used to say "recognize_speech" after a copy-paste
  if (!td_->messages_manager_->have_message_force(message_full_id, "rate_speech_recognition")) {
    return promise.set_error(Status::Error(400, "Message not found"));
  }

  auto it = message_file_ids_.find(message_full_id);
  if (it == message_file_ids_.end()) {
    return promise.set_error(Status::Error(400, "Message can't be transcribed"));
  }

  // rating must not create transcription information as a side effect
  const auto *transcription_info = get_transcription_info(it->second, false);
  if (transcription_info == nullptr) {
    return promise.set_value(Unit());
  }
  transcription_info->rate_speech_recognition(td_, message_full_id, is_good, std::move(promise));
}
|
||||
|
||||
void TranscriptionManager::subscribe_to_transcribed_audio_updates(int64 transcription_id,
|
||||
TranscribedAudioHandler on_update) {
|
||||
CHECK(transcription_id != 0);
|
||||
|
@ -6,17 +6,24 @@
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "td/telegram/files/FileId.h"
|
||||
#include "td/telegram/MessageContentType.h"
|
||||
#include "td/telegram/MessageFullId.h"
|
||||
#include "td/telegram/td_api.h"
|
||||
#include "td/telegram/telegram_api.h"
|
||||
#include "td/telegram/TranscriptionInfo.h"
|
||||
|
||||
#include "td/actor/actor.h"
|
||||
#include "td/actor/MultiTimeout.h"
|
||||
|
||||
#include "td/utils/common.h"
|
||||
#include "td/utils/FlatHashMap.h"
|
||||
#include "td/utils/FlatHashSet.h"
|
||||
#include "td/utils/Promise.h"
|
||||
#include "td/utils/Status.h"
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
namespace td {
|
||||
|
||||
@ -25,12 +32,25 @@ class Td;
|
||||
class TranscriptionManager final : public Actor {
|
||||
public:
|
||||
TranscriptionManager(Td *td, ActorShared<> parent);
|
||||
TranscriptionManager(const TranscriptionManager &) = delete;
|
||||
TranscriptionManager &operator=(const TranscriptionManager &) = delete;
|
||||
TranscriptionManager(TranscriptionManager &&) = delete;
|
||||
TranscriptionManager &operator=(TranscriptionManager &&) = delete;
|
||||
~TranscriptionManager() final;
|
||||
|
||||
void on_update_trial_parameters(int32 weekly_number, int32 duration_max, int32 cooldown_until);
|
||||
|
||||
using TranscribedAudioHandler =
|
||||
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
|
||||
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
|
||||
void register_voice(FileId file_id, MessageContentType content_type, MessageFullId message_full_id,
|
||||
const char *source);
|
||||
|
||||
void unregister_voice(FileId file_id, MessageContentType content_type, MessageFullId message_full_id,
|
||||
const char *source);
|
||||
|
||||
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
|
||||
|
||||
void on_transcription_completed(FileId file_id);
|
||||
|
||||
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
|
||||
|
||||
void on_update_transcribed_audio(telegram_api::object_ptr<telegram_api::updateTranscribedAudio> &&update);
|
||||
|
||||
@ -53,6 +73,19 @@ class TranscriptionManager final : public Actor {
|
||||
|
||||
td_api::object_ptr<td_api::updateSpeechRecognitionTrial> get_update_speech_recognition_trial_object() const;
|
||||
|
||||
using FileInfo = std::pair<MessageContentType, FileId>;
|
||||
|
||||
TranscriptionInfo *get_transcription_info(const FileInfo &file_info, bool allow_creation);
|
||||
|
||||
using TranscribedAudioHandler =
|
||||
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)>;
|
||||
void subscribe_to_transcribed_audio_updates(int64 transcription_id, TranscribedAudioHandler on_update);
|
||||
|
||||
void on_transcribed_audio_update(FileInfo file_info, bool is_initial,
|
||||
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
|
||||
|
||||
void on_transcription_updated(FileId file_id);
|
||||
|
||||
void on_pending_audio_transcription_failed(int64 transcription_id, Status &&error);
|
||||
|
||||
struct TrialParameters {
|
||||
@ -81,6 +114,9 @@ class TranscriptionManager final : public Actor {
|
||||
|
||||
FlatHashMap<int64, TranscribedAudioHandler> pending_audio_transcriptions_;
|
||||
MultiTimeout pending_audio_transcription_timeout_{"PendingAudioTranscriptionTimeout"};
|
||||
|
||||
FlatHashMap<FileId, FlatHashSet<MessageFullId, MessageFullIdHash>, FileIdHash> voice_messages_;
|
||||
FlatHashMap<MessageFullId, FileInfo, MessageFullIdHash> message_file_ids_;
|
||||
};
|
||||
|
||||
} // namespace td
|
||||
|
@ -18,8 +18,6 @@
|
||||
#include "td/telegram/telegram_api.h"
|
||||
#include "td/telegram/TranscriptionManager.h"
|
||||
|
||||
#include "td/actor/actor.h"
|
||||
|
||||
#include "td/utils/buffer.h"
|
||||
#include "td/utils/logging.h"
|
||||
#include "td/utils/misc.h"
|
||||
@ -44,6 +42,15 @@ int32 VideoNotesManager::get_video_note_duration(FileId file_id) const {
|
||||
return video_note->duration;
|
||||
}
|
||||
|
||||
// Returns the transcription information stored for the video note file_id.
// If none is stored yet, an empty one is created when allow_creation is true; otherwise nullptr is returned.
// The video note itself must be known to the manager.
TranscriptionInfo *VideoNotesManager::get_video_note_transcription_info(FileId file_id, bool allow_creation) {
  auto video_note = get_video_note(file_id);
  CHECK(video_note != nullptr);
  auto &stored_info = video_note->transcription_info;
  if (allow_creation && stored_info == nullptr) {
    stored_info = make_unique<TranscriptionInfo>();
  }
  return stored_info.get();
}
|
||||
|
||||
tl_object_ptr<td_api::videoNote> VideoNotesManager::get_video_note_object(FileId file_id) const {
|
||||
if (!file_id.is_valid()) {
|
||||
return nullptr;
|
||||
@ -89,7 +96,7 @@ FileId VideoNotesManager::on_get_video_note(unique_ptr<VideoNote> new_video_note
|
||||
v->thumbnail = std::move(new_video_note->thumbnail);
|
||||
}
|
||||
if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_video_note->transcription_info))) {
|
||||
on_video_note_transcription_completed(file_id);
|
||||
td_->transcription_manager_->on_transcription_completed(file_id);
|
||||
}
|
||||
}
|
||||
return file_id;
|
||||
@ -170,129 +177,6 @@ void VideoNotesManager::create_video_note(FileId file_id, string minithumbnail,
|
||||
on_get_video_note(std::move(v), replace);
|
||||
}
|
||||
|
||||
void VideoNotesManager::register_video_note(FileId video_note_file_id, MessageFullId message_full_id,
|
||||
const char *source) {
|
||||
if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server() ||
|
||||
td_->auth_manager_->is_bot()) {
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Register video note " << video_note_file_id << " from " << message_full_id << " from " << source;
|
||||
CHECK(video_note_file_id.is_valid());
|
||||
bool is_inserted = video_note_messages_[video_note_file_id].insert(message_full_id).second;
|
||||
LOG_CHECK(is_inserted) << source << ' ' << video_note_file_id << ' ' << message_full_id;
|
||||
is_inserted = message_video_notes_.emplace(message_full_id, video_note_file_id).second;
|
||||
CHECK(is_inserted);
|
||||
}
|
||||
|
||||
void VideoNotesManager::unregister_video_note(FileId video_note_file_id, MessageFullId message_full_id,
|
||||
const char *source) {
|
||||
if (message_full_id.get_message_id().is_scheduled() || !message_full_id.get_message_id().is_server() ||
|
||||
td_->auth_manager_->is_bot()) {
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Unregister video note " << video_note_file_id << " from " << message_full_id << " from " << source;
|
||||
CHECK(video_note_file_id.is_valid());
|
||||
auto &message_ids = video_note_messages_[video_note_file_id];
|
||||
auto is_deleted = message_ids.erase(message_full_id) > 0;
|
||||
LOG_CHECK(is_deleted) << source << ' ' << video_note_file_id << ' ' << message_full_id;
|
||||
if (message_ids.empty()) {
|
||||
video_note_messages_.erase(video_note_file_id);
|
||||
}
|
||||
is_deleted = message_video_notes_.erase(message_full_id) > 0;
|
||||
CHECK(is_deleted);
|
||||
}
|
||||
|
||||
void VideoNotesManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
|
||||
auto it = message_video_notes_.find(message_full_id);
|
||||
CHECK(it != message_video_notes_.end());
|
||||
|
||||
auto file_id = it->second;
|
||||
auto video_note = get_video_note(file_id);
|
||||
CHECK(video_note != nullptr);
|
||||
if (video_note->transcription_info == nullptr) {
|
||||
video_note->transcription_info = make_unique<TranscriptionInfo>();
|
||||
}
|
||||
|
||||
auto handler = [actor_id = actor_id(this),
|
||||
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
|
||||
send_closure(actor_id, &VideoNotesManager::on_transcribed_audio_update, file_id, true, std::move(r_update));
|
||||
};
|
||||
if (video_note->transcription_info->recognize_speech(td_, message_full_id, std::move(promise), std::move(handler))) {
|
||||
on_video_note_transcription_updated(file_id);
|
||||
}
|
||||
}
|
||||
|
||||
void VideoNotesManager::on_transcribed_audio_update(
|
||||
FileId file_id, bool is_initial, Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
|
||||
if (G()->close_flag()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto video_note = get_video_note(file_id);
|
||||
CHECK(video_note != nullptr);
|
||||
CHECK(video_note->transcription_info != nullptr);
|
||||
|
||||
if (r_update.is_error()) {
|
||||
auto promises = video_note->transcription_info->on_failed_transcription(r_update.move_as_error());
|
||||
on_video_note_transcription_updated(file_id);
|
||||
set_promises(promises);
|
||||
return;
|
||||
}
|
||||
auto update = r_update.move_as_ok();
|
||||
auto transcription_id = update->transcription_id_;
|
||||
if (!update->pending_) {
|
||||
auto promises = video_note->transcription_info->on_final_transcription(std::move(update->text_), transcription_id);
|
||||
on_video_note_transcription_completed(file_id);
|
||||
set_promises(promises);
|
||||
} else {
|
||||
auto is_changed =
|
||||
video_note->transcription_info->on_partial_transcription(std::move(update->text_), transcription_id);
|
||||
if (is_changed) {
|
||||
on_video_note_transcription_updated(file_id);
|
||||
}
|
||||
|
||||
if (is_initial) {
|
||||
td_->transcription_manager_->subscribe_to_transcribed_audio_updates(
|
||||
transcription_id, [actor_id = actor_id(this),
|
||||
file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
|
||||
send_closure(actor_id, &VideoNotesManager::on_transcribed_audio_update, file_id, false,
|
||||
std::move(r_update));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void VideoNotesManager::on_video_note_transcription_updated(FileId file_id) {
|
||||
auto it = video_note_messages_.find(file_id);
|
||||
if (it != video_note_messages_.end()) {
|
||||
for (const auto &message_full_id : it->second) {
|
||||
td_->messages_manager_->on_external_update_message_content(message_full_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void VideoNotesManager::on_video_note_transcription_completed(FileId file_id) {
|
||||
auto it = video_note_messages_.find(file_id);
|
||||
if (it != video_note_messages_.end()) {
|
||||
for (const auto &message_full_id : it->second) {
|
||||
td_->messages_manager_->on_update_message_content(message_full_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void VideoNotesManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise) {
|
||||
auto it = message_video_notes_.find(message_full_id);
|
||||
CHECK(it != message_video_notes_.end());
|
||||
|
||||
auto file_id = it->second;
|
||||
auto video_note = get_video_note(file_id);
|
||||
CHECK(video_note != nullptr);
|
||||
if (video_note->transcription_info == nullptr) {
|
||||
return promise.set_value(Unit());
|
||||
}
|
||||
video_note->transcription_info->rate_speech_recognition(td_, message_full_id, is_good, std::move(promise));
|
||||
}
|
||||
|
||||
SecretInputMedia VideoNotesManager::get_secret_input_media(FileId video_note_file_id,
|
||||
tl_object_ptr<telegram_api::InputEncryptedFile> input_file,
|
||||
BufferSlice thumbnail, int32 layer) const {
|
||||
|
@ -40,19 +40,13 @@ class VideoNotesManager final : public Actor {
|
||||
|
||||
int32 get_video_note_duration(FileId file_id) const;
|
||||
|
||||
TranscriptionInfo *get_video_note_transcription_info(FileId file_id, bool allow_creation);
|
||||
|
||||
tl_object_ptr<td_api::videoNote> get_video_note_object(FileId file_id) const;
|
||||
|
||||
void create_video_note(FileId file_id, string minithumbnail, PhotoSize thumbnail, int32 duration,
|
||||
Dimensions dimensions, string waveform, bool replace);
|
||||
|
||||
void register_video_note(FileId video_note_file_id, MessageFullId message_full_id, const char *source);
|
||||
|
||||
void unregister_video_note(FileId video_note_file_id, MessageFullId message_full_id, const char *source);
|
||||
|
||||
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
|
||||
|
||||
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
|
||||
|
||||
tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id,
|
||||
tl_object_ptr<telegram_api::InputFile> input_file,
|
||||
tl_object_ptr<telegram_api::InputFile> input_thumbnail) const;
|
||||
@ -94,22 +88,12 @@ class VideoNotesManager final : public Actor {
|
||||
|
||||
FileId on_get_video_note(unique_ptr<VideoNote> new_video_note, bool replace);
|
||||
|
||||
void on_video_note_transcription_updated(FileId file_id);
|
||||
|
||||
void on_video_note_transcription_completed(FileId file_id);
|
||||
|
||||
void on_transcribed_audio_update(FileId file_id, bool is_initial,
|
||||
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
|
||||
|
||||
void tear_down() final;
|
||||
|
||||
Td *td_;
|
||||
ActorShared<> parent_;
|
||||
|
||||
WaitFreeHashMap<FileId, unique_ptr<VideoNote>, FileIdHash> video_notes_;
|
||||
|
||||
FlatHashMap<FileId, FlatHashSet<MessageFullId, MessageFullIdHash>, FileIdHash> video_note_messages_;
|
||||
FlatHashMap<MessageFullId, FileId, MessageFullIdHash> message_video_notes_;
|
||||
};
|
||||
|
||||
} // namespace td
|
||||
|
@ -26,8 +26,7 @@ VoiceNotesManager::VoiceNotesManager(Td *td, ActorShared<> parent) : td_(td), pa
|
||||
}
|
||||
|
||||
VoiceNotesManager::~VoiceNotesManager() {
|
||||
Scheduler::instance()->destroy_on_scheduler(G()->get_gc_scheduler_id(), voice_notes_, voice_note_messages_,
|
||||
message_voice_notes_);
|
||||
Scheduler::instance()->destroy_on_scheduler(G()->get_gc_scheduler_id(), voice_notes_);
|
||||
}
|
||||
|
||||
void VoiceNotesManager::tear_down() {
|
||||
@ -42,6 +41,15 @@ int32 VoiceNotesManager::get_voice_note_duration(FileId file_id) const {
|
||||
return voice_note->duration;
|
||||
}
|
||||
|
||||
// Returns the transcription information stored for the voice note file_id.
// If none is stored yet, an empty one is created when allow_creation is true; otherwise nullptr is returned.
// The voice note itself must be known to the manager.
TranscriptionInfo *VoiceNotesManager::get_voice_note_transcription_info(FileId file_id, bool allow_creation) {
  auto voice_note = get_voice_note(file_id);
  CHECK(voice_note != nullptr);
  auto &stored_info = voice_note->transcription_info;
  if (allow_creation && stored_info == nullptr) {
    stored_info = make_unique<TranscriptionInfo>();
  }
  return stored_info.get();
}
|
||||
|
||||
tl_object_ptr<td_api::voiceNote> VoiceNotesManager::get_voice_note_object(FileId file_id) const {
|
||||
if (!file_id.is_valid()) {
|
||||
return nullptr;
|
||||
@ -76,7 +84,7 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr<VoiceNote> new_voice_note
|
||||
v->waveform = std::move(new_voice_note->waveform);
|
||||
}
|
||||
if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_voice_note->transcription_info))) {
|
||||
on_voice_note_transcription_completed(file_id);
|
||||
td_->transcription_manager_->on_transcription_completed(file_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,129 +142,6 @@ void VoiceNotesManager::create_voice_note(FileId file_id, string mime_type, int3
|
||||
on_get_voice_note(std::move(v), replace);
|
||||
}
|
||||
|
||||
// Records that message message_full_id contains the voice note voice_note_file_id by adding
// an entry to both voice_note_messages_ and message_voice_notes_.
// Does nothing when the message identifier is scheduled or is not a server identifier,
// or when td_->auth_manager_->is_bot() returns true.
// The pair must not have been registered before (enforced by LOG_CHECK/CHECK).
// source is only written to the log and to the check-failure message.
void VoiceNotesManager::register_voice_note(FileId voice_note_file_id, MessageFullId message_full_id,
                                            const char *source) {
  auto message_id = message_full_id.get_message_id();
  const bool can_be_registered =
      !message_id.is_scheduled() && message_id.is_server() && !td_->auth_manager_->is_bot();
  if (!can_be_registered) {
    return;
  }

  LOG(INFO) << "Register voice note " << voice_note_file_id << " from " << message_full_id << " from " << source;
  CHECK(voice_note_file_id.is_valid());

  // file -> messages direction
  auto &message_ids = voice_note_messages_[voice_note_file_id];
  bool is_inserted = message_ids.insert(message_full_id).second;
  LOG_CHECK(is_inserted) << source << ' ' << voice_note_file_id << ' ' << message_full_id;

  // message -> file direction
  is_inserted = message_voice_notes_.emplace(message_full_id, voice_note_file_id).second;
  CHECK(is_inserted);
}
|
||||
|
||||
// Removes the association between message message_full_id and voice note voice_note_file_id
// from both voice_note_messages_ and message_voice_notes_.
// Does nothing when the message identifier is scheduled or is not a server identifier,
// or when td_->auth_manager_->is_bot() returns true.
// The pair must be currently registered (enforced by LOG_CHECK/CHECK).
// source is only written to the log and to the check-failure message.
void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, MessageFullId message_full_id,
                                              const char *source) {
  auto message_id = message_full_id.get_message_id();
  const bool can_be_registered =
      !message_id.is_scheduled() && message_id.is_server() && !td_->auth_manager_->is_bot();
  if (!can_be_registered) {
    return;
  }

  LOG(INFO) << "Unregister voice note " << voice_note_file_id << " from " << message_full_id << " from " << source;
  CHECK(voice_note_file_id.is_valid());

  // file -> messages direction; drop the whole entry once its last message is gone
  auto &message_ids = voice_note_messages_[voice_note_file_id];
  bool is_deleted = message_ids.erase(message_full_id) != 0;
  LOG_CHECK(is_deleted) << source << ' ' << voice_note_file_id << ' ' << message_full_id;
  if (message_ids.empty()) {
    voice_note_messages_.erase(voice_note_file_id);
  }

  // message -> file direction
  is_deleted = message_voice_notes_.erase(message_full_id) != 0;
  CHECK(is_deleted);
}
|
||||
|
||||
// Starts speech recognition for the voice note contained in message message_full_id.
// The message must be present in message_voice_notes_ and its voice note must exist (both enforced
// by CHECK). A TranscriptionInfo is created for the voice note if it has none, and the request is
// delegated to TranscriptionInfo::recognize_speech together with promise and a callback that forwards
// the resulting update to on_transcribed_audio_update with is_initial == true.
// If the delegate returns true, on_voice_note_transcription_updated is called for the file.
void VoiceNotesManager::recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise) {
  auto it = message_voice_notes_.find(message_full_id);
  CHECK(it != message_voice_notes_.end());

  const auto file_id = it->second;
  auto voice_note = get_voice_note(file_id);
  CHECK(voice_note != nullptr);

  auto &transcription_info = voice_note->transcription_info;
  if (transcription_info == nullptr) {
    transcription_info = make_unique<TranscriptionInfo>();
  }

  auto on_update = [self_id = actor_id(this),
                    file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> result) {
    send_closure(self_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, true, std::move(result));
  };
  const bool need_notify =
      transcription_info->recognize_speech(td_, message_full_id, std::move(promise), std::move(on_update));
  if (need_notify) {
    on_voice_note_transcription_updated(file_id);
  }
}
|
||||
|
||||
// Handles the result of a transcription request or a later transcription update for the voice note file_id.
// Ignored completely when G()->close_flag() is set. The voice note and its TranscriptionInfo must exist
// (enforced by CHECK).
//  - Error result: the error is passed to on_failed_transcription, the messages are notified through
//    on_voice_note_transcription_updated, and the returned promises are passed to set_promises.
//  - Non-pending update: the text is passed to on_final_transcription, the messages are notified through
//    on_voice_note_transcription_completed, and the returned promises are passed to set_promises.
//  - Pending update: the text is passed to on_partial_transcription; the messages are notified only if it
//    reports a change. When is_initial is true, a subscription for further updates with the same
//    transcription identifier is registered; those updates arrive here with is_initial == false.
void VoiceNotesManager::on_transcribed_audio_update(
    FileId file_id, bool is_initial, Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update) {
  if (G()->close_flag()) {
    return;
  }

  auto voice_note = get_voice_note(file_id);
  CHECK(voice_note != nullptr);
  CHECK(voice_note->transcription_info != nullptr);
  auto *transcription_info = voice_note->transcription_info.get();

  if (r_update.is_error()) {
    auto promises = transcription_info->on_failed_transcription(r_update.move_as_error());
    on_voice_note_transcription_updated(file_id);
    set_promises(promises);
    return;
  }

  auto update = r_update.move_as_ok();
  auto transcription_id = update->transcription_id_;

  if (!update->pending_) {
    auto promises = transcription_info->on_final_transcription(std::move(update->text_), transcription_id);
    on_voice_note_transcription_completed(file_id);
    set_promises(promises);
    return;
  }

  if (transcription_info->on_partial_transcription(std::move(update->text_), transcription_id)) {
    on_voice_note_transcription_updated(file_id);
  }

  if (!is_initial) {
    return;
  }
  auto on_next_update = [self_id = actor_id(this),
                         file_id](Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> result) {
    send_closure(self_id, &VoiceNotesManager::on_transcribed_audio_update, file_id, false, std::move(result));
  };
  td_->transcription_manager_->subscribe_to_transcribed_audio_updates(transcription_id, std::move(on_next_update));
}
|
||||
|
||||
// Calls MessagesManager::on_external_update_message_content for every message registered in
// voice_note_messages_ for the voice note file_id; does nothing if no message is registered.
void VoiceNotesManager::on_voice_note_transcription_updated(FileId file_id) {
  auto it = voice_note_messages_.find(file_id);
  if (it == voice_note_messages_.end()) {
    return;
  }

  const auto &message_full_ids = it->second;
  for (const auto &message_full_id : message_full_ids) {
    td_->messages_manager_->on_external_update_message_content(message_full_id);
  }
}
|
||||
|
||||
// Calls MessagesManager::on_update_message_content for every message registered in
// voice_note_messages_ for the voice note file_id; does nothing if no message is registered.
void VoiceNotesManager::on_voice_note_transcription_completed(FileId file_id) {
  auto it = voice_note_messages_.find(file_id);
  if (it == voice_note_messages_.end()) {
    return;
  }

  const auto &message_full_ids = it->second;
  for (const auto &message_full_id : message_full_ids) {
    td_->messages_manager_->on_update_message_content(message_full_id);
  }
}
|
||||
|
||||
// Rates the speech recognition result of the voice note contained in message message_full_id.
// The message must be present in message_voice_notes_ and its voice note must exist (both enforced
// by CHECK). If the voice note has no TranscriptionInfo, promise is completed with Unit() right away;
// otherwise the request is delegated to TranscriptionInfo::rate_speech_recognition.
void VoiceNotesManager::rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise) {
  auto it = message_voice_notes_.find(message_full_id);
  CHECK(it != message_voice_notes_.end());

  auto voice_note = get_voice_note(it->second);
  CHECK(voice_note != nullptr);

  auto &transcription_info = voice_note->transcription_info;
  if (transcription_info != nullptr) {
    transcription_info->rate_speech_recognition(td_, message_full_id, is_good, std::move(promise));
  } else {
    promise.set_value(Unit());
  }
}
|
||||
|
||||
SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,
|
||||
tl_object_ptr<telegram_api::InputEncryptedFile> input_file,
|
||||
const string &caption, int32 layer) const {
|
||||
|
@ -37,18 +37,12 @@ class VoiceNotesManager final : public Actor {
|
||||
|
||||
int32 get_voice_note_duration(FileId file_id) const;
|
||||
|
||||
TranscriptionInfo *get_voice_note_transcription_info(FileId file_id, bool allow_creation);
|
||||
|
||||
tl_object_ptr<td_api::voiceNote> get_voice_note_object(FileId file_id) const;
|
||||
|
||||
void create_voice_note(FileId file_id, string mime_type, int32 duration, string waveform, bool replace);
|
||||
|
||||
void register_voice_note(FileId voice_note_file_id, MessageFullId message_full_id, const char *source);
|
||||
|
||||
void unregister_voice_note(FileId voice_note_file_id, MessageFullId message_full_id, const char *source);
|
||||
|
||||
void recognize_speech(MessageFullId message_full_id, Promise<Unit> &&promise);
|
||||
|
||||
void rate_speech_recognition(MessageFullId message_full_id, bool is_good, Promise<Unit> &&promise);
|
||||
|
||||
tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id,
|
||||
tl_object_ptr<telegram_api::InputFile> input_file) const;
|
||||
|
||||
@ -83,22 +77,12 @@ class VoiceNotesManager final : public Actor {
|
||||
|
||||
FileId on_get_voice_note(unique_ptr<VoiceNote> new_voice_note, bool replace);
|
||||
|
||||
void on_voice_note_transcription_updated(FileId file_id);
|
||||
|
||||
void on_voice_note_transcription_completed(FileId file_id);
|
||||
|
||||
void on_transcribed_audio_update(FileId file_id, bool is_initial,
|
||||
Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>> r_update);
|
||||
|
||||
void tear_down() final;
|
||||
|
||||
Td *td_;
|
||||
ActorShared<> parent_;
|
||||
|
||||
WaitFreeHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
|
||||
|
||||
FlatHashMap<FileId, FlatHashSet<MessageFullId, MessageFullIdHash>, FileIdHash> voice_note_messages_;
|
||||
FlatHashMap<MessageFullId, FileId, MessageFullIdHash> message_voice_notes_;
|
||||
};
|
||||
|
||||
} // namespace td
|
||||
|
Loading…
Reference in New Issue
Block a user