Add td_api::recognizeSpeech (speech recognition for voice note messages).

What the patch does:
  * td_api.tl: declares recognizeSpeech chat_id:int53 message_id:int53 = Ok.
  * Td.h / Td.cpp: adds the request handler; it is guarded by CHECK_IS_USER() and forwards the
    (chat_id, message_id) pair to VoiceNotesManager::recognize_speech.
  * VoiceNotesManager: keeps a message -> voice note file map next to the existing
    file -> messages map, sends one TranscribeAudioQuery per file while collecting every caller's
    promise in speech_recognition_queries_, stores the received text on the VoiceNote and calls
    MessagesManager::on_external_update_message_content for every message registered for the file.
  * cli.cpp: adds the "rs" command, which takes a chat identifier and a message identifier.

Review notes:
  * If the server marks the result as pending, on_voice_note_transcribed() keeps the queued
    promises and nothing in this patch completes them later; further recognizeSpeech calls for the
    same file are only appended to that queue. A follow-up has to deliver the final text.
  * on_voice_note_transcribed() and on_voice_note_transcription_failed() CHECK that the note is
    still untranscribed, while on_get_voice_note() may mark it transcribed on its own. Confirm
    that this cannot happen while a TranscribeAudioQuery is in flight.
  * Line breaks, indentation and template arguments were rebuilt for this copy; the indentation
    of unchanged context lines is best effort, so verify against the tree before applying.

diff --git a/td/generate/scheme/td_api.tl b/td/generate/scheme/td_api.tl
index b5695a26a..094f323e5 100644
--- a/td/generate/scheme/td_api.tl
+++ b/td/generate/scheme/td_api.tl
@@ -4818,6 +4818,11 @@ getMessageLinkInfo url:string = MessageLinkInfo;
 //@to_language_code A two-letter ISO 639-1 language code of the language to which the message is translated
 translateText text:string from_language_code:string to_language_code:string = Text;
 
+//@description Recognizes speech in a voice note message. The message must be successfully sent and must not be scheduled
+//@chat_id Identifier of the chat to which the message belongs
+//@message_id Identifier of the message
+recognizeSpeech chat_id:int53 message_id:int53 = Ok;
+
 //@description Returns list of message sender identifiers, which can be used to send messages in a chat @chat_id Chat identifier
 getChatAvailableMessageSenders chat_id:int53 = MessageSenders;
 
diff --git a/td/telegram/Td.cpp b/td/telegram/Td.cpp
index 0a55576ff..cd1bd1b25 100644
--- a/td/telegram/Td.cpp
+++ b/td/telegram/Td.cpp
@@ -4797,6 +4797,13 @@ void Td::on_request(uint64 id, td_api::translateText &request) {
                                     std::move(promise));
 }
 
+void Td::on_request(uint64 id, const td_api::recognizeSpeech &request) {
+  CHECK_IS_USER();
+  CREATE_OK_REQUEST_PROMISE();
+  voice_notes_manager_->recognize_speech({DialogId(request.chat_id_), MessageId(request.message_id_)},
+                                         std::move(promise));
+}
+
 void Td::on_request(uint64 id, const td_api::getFile &request) {
   send_closure(actor_id(this), &Td::send_result, id, file_manager_->get_file_object(FileId(request.file_id_, 0)));
 }
diff --git a/td/telegram/Td.h b/td/telegram/Td.h
index 409e071e6..8896928b9 100644
--- a/td/telegram/Td.h
+++ b/td/telegram/Td.h
@@ -532,6 +532,8 @@ class Td final : public Actor {
 
   void on_request(uint64 id, td_api::translateText &request);
 
+  void on_request(uint64 id, const td_api::recognizeSpeech &request);
+
   void on_request(uint64 id, const td_api::getFile &request);
 
   void on_request(uint64 id, td_api::getRemoteFile &request);
diff --git a/td/telegram/VoiceNotesManager.cpp b/td/telegram/VoiceNotesManager.cpp
index bdaeb17c8..861fc40bd 100644
--- a/td/telegram/VoiceNotesManager.cpp
+++ b/td/telegram/VoiceNotesManager.cpp
@@ -8,6 +8,8 @@
 
 #include "td/telegram/Dimensions.h"
 #include "td/telegram/files/FileManager.h"
+#include "td/telegram/Global.h"
+#include "td/telegram/MessagesManager.h"
 #include "td/telegram/secret_api.h"
 #include "td/telegram/Td.h"
 #include "td/telegram/td_api.h"
@@ -16,10 +18,46 @@
 #include "td/utils/buffer.h"
 #include "td/utils/logging.h"
 #include "td/utils/misc.h"
-#include "td/utils/Status.h"
 
 namespace td {
 
+class TranscribeAudioQuery final : public Td::ResultHandler {
+  DialogId dialog_id_;
+  FileId file_id_;
+
+ public:
+  void send(FileId file_id, FullMessageId full_message_id) {
+    dialog_id_ = full_message_id.get_dialog_id();
+    file_id_ = file_id;
+    auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read);
+    if (input_peer == nullptr) {
+      return on_error(Status::Error(400, "Can't access the chat"));
+    }
+    send_query(G()->net_query_creator().create(telegram_api::messages_transcribeAudio(
+        std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get())));
+  }
+
+  void on_result(BufferSlice packet) final {
+    auto result_ptr = fetch_result<telegram_api::messages_transcribeAudio>(packet);
+    if (result_ptr.is_error()) {
+      return on_error(result_ptr.move_as_error());
+    }
+
+    auto result = result_ptr.move_as_ok();
+    LOG(INFO) << "Receive result for TranscribeAudioQuery: " << to_string(result);
+    if (result->transcription_id_ == 0) {
+      return on_error(Status::Error(500, "Receive no recognition identifier"));
+    }
+    td_->voice_notes_manager_->on_voice_note_transcribed(file_id_, std::move(result->text_), result->transcription_id_,
+                                                         !result->pending_);
+  }
+
+  void on_error(Status status) final {
+    td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "TranscribeAudioQuery");
+    td_->voice_notes_manager_->on_voice_note_transcription_failed(file_id_, std::move(status));
+  }
+};
+
 VoiceNotesManager::VoiceNotesManager(Td *td) : td_(td) {
 }
 
@@ -63,11 +101,29 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr<VoiceNote> new_voice_note
       v->duration = new_voice_note->duration;
       v->waveform = new_voice_note->waveform;
     }
+    if (new_voice_note->is_transcribed && v->transcription_id == 0) {
+      CHECK(!v->is_transcribed);
+      CHECK(new_voice_note->transcription_id != 0);
+      v->is_transcribed = true;
+      v->transcription_id = new_voice_note->transcription_id;
+      v->text = std::move(new_voice_note->text);
+      on_voice_note_transcription_updated(file_id);
+    }
   }
 
   return file_id;
 }
 
+VoiceNotesManager::VoiceNote *VoiceNotesManager::get_voice_note(FileId file_id) {
+  auto voice_note = voice_notes_.find(file_id);
+  if (voice_note == voice_notes_.end()) {
+    return nullptr;
+  }
+
+  CHECK(voice_note->second->file_id == file_id);
+  return voice_note->second.get();
+}
+
 const VoiceNotesManager::VoiceNote *VoiceNotesManager::get_voice_note(FileId file_id) const {
   auto voice_note = voice_notes_.find(file_id);
   if (voice_note == voice_notes_.end()) {
@@ -137,8 +193,8 @@ void VoiceNotesManager::register_voice_note(FileId voice_note_file_id, FullMessa
   LOG(INFO) << "Register voice note " << voice_note_file_id << " from " << full_message_id << " from " << source;
   bool is_inserted = voice_note_messages_[voice_note_file_id].insert(full_message_id).second;
   LOG_CHECK(is_inserted) << source << ' ' << voice_note_file_id << ' ' << full_message_id;
-  auto voice_note = get_voice_note(voice_note_file_id);
-  CHECK(voice_note != nullptr);
+  is_inserted = message_voice_notes_.emplace(full_message_id, voice_note_file_id).second;
+  CHECK(is_inserted);
 }
 
 void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, FullMessageId full_message_id,
@@ -153,6 +209,80 @@ void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, FullMes
   if (message_ids.empty()) {
     voice_note_messages_.erase(voice_note_file_id);
   }
+  is_deleted = message_voice_notes_.erase(full_message_id);
+  CHECK(is_deleted);
+}
+
+void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise<Unit> &&promise) {
+  if (!td_->messages_manager_->have_message_force(full_message_id, "recognize_speech")) {
+    return promise.set_error(Status::Error(400, "Message not found"));
+  }
+
+  auto it = message_voice_notes_.find(full_message_id);
+  if (it == message_voice_notes_.end()) {
+    return promise.set_error(Status::Error(400, "Invalid message specified"));
+  }
+
+  auto file_id = it->second;
+  auto voice_note = get_voice_note(file_id);
+  CHECK(voice_note != nullptr);
+  if (voice_note->is_transcribed) {
+    return promise.set_value(Unit());
+  }
+  auto &queries = speech_recognition_queries_[file_id];
+  queries.push_back(std::move(promise));
+  if (queries.size() == 1) {
+    td_->create_handler<TranscribeAudioQuery>()->send(file_id, full_message_id);
+  }
+}
+
+void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id,
+                                                  bool is_final) {
+  auto voice_note = get_voice_note(file_id);
+  CHECK(voice_note != nullptr);
+  CHECK(!voice_note->is_transcribed);
+  CHECK(voice_note->transcription_id == 0);
+  voice_note->transcription_id = transcription_id;
+  voice_note->is_transcribed = is_final;
+  voice_note->text = std::move(text);
+
+  if (!voice_note->text.empty() || is_final) {
+    on_voice_note_transcription_updated(file_id);
+  }
+
+  if (is_final) {
+    auto it = speech_recognition_queries_.find(file_id);
+    CHECK(it != speech_recognition_queries_.end());
+    CHECK(!it->second.empty());
+    auto promises = std::move(it->second);
+    speech_recognition_queries_.erase(it);
+
+    set_promises(promises);
+  }
+}
+
+void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) {
+  auto voice_note = get_voice_note(file_id);
+  CHECK(voice_note != nullptr);
+  CHECK(!voice_note->is_transcribed);
+  CHECK(voice_note->transcription_id == 0);
+
+  auto it = speech_recognition_queries_.find(file_id);
+  CHECK(it != speech_recognition_queries_.end());
+  CHECK(!it->second.empty());
+  auto promises = std::move(it->second);
+  speech_recognition_queries_.erase(it);
+
+  fail_promises(promises, std::move(error));
+}
+
+void VoiceNotesManager::on_voice_note_transcription_updated(FileId file_id) {
+  auto it = voice_note_messages_.find(file_id);
+  if (it != voice_note_messages_.end()) {
+    for (const auto &full_message_id : it->second) {
+      td_->messages_manager_->on_external_update_message_content(full_message_id);
+    }
+  }
 }
 
 SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,
diff --git a/td/telegram/VoiceNotesManager.h b/td/telegram/VoiceNotesManager.h
index 9f610480b..baf148de3 100644
--- a/td/telegram/VoiceNotesManager.h
+++ b/td/telegram/VoiceNotesManager.h
@@ -13,9 +13,12 @@
 #include "td/telegram/telegram_api.h"
 #include "td/telegram/Version.h"
 
+#include "td/actor/PromiseFuture.h"
+
 #include "td/utils/common.h"
 #include "td/utils/FlatHashMap.h"
 #include "td/utils/FlatHashSet.h"
+#include "td/utils/Status.h"
 
 namespace td {
 
@@ -35,6 +38,12 @@ class VoiceNotesManager {
 
   void unregister_voice_note(FileId voice_note_file_id, FullMessageId full_message_id, const char *source);
 
+  void recognize_speech(FullMessageId full_message_id, Promise<Unit> &&promise);
+
+  void on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, bool is_final);
+
+  void on_voice_note_transcription_failed(FileId file_id, Status &&error);
+
   tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id,
                                                           tl_object_ptr<telegram_api::InputFile> input_file) const;
 
@@ -65,14 +74,21 @@ class VoiceNotesManager {
     FileId file_id;
   };
 
+  VoiceNote *get_voice_note(FileId file_id);
+
   const VoiceNote *get_voice_note(FileId file_id) const;
 
   FileId on_get_voice_note(unique_ptr<VoiceNote> new_voice_note, bool replace);
 
+  void on_voice_note_transcription_updated(FileId file_id);
+
   Td *td_;
   FlatHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
 
+  FlatHashMap<FileId, vector<Promise<Unit>>, FileIdHash> speech_recognition_queries_;
+
   FlatHashMap<FileId, FlatHashSet<FullMessageId, FullMessageIdHash>, FileIdHash> voice_note_messages_;
+  FlatHashMap<FullMessageId, FileId, FullMessageIdHash> message_voice_notes_;
 };
 
 }  // namespace td
diff --git a/td/telegram/cli.cpp b/td/telegram/cli.cpp
index f9d382b7c..65e364269 100644
--- a/td/telegram/cli.cpp
+++ b/td/telegram/cli.cpp
@@ -2889,6 +2889,11 @@ class CliClient final : public Actor {
       string to_language_code;
       get_args(args, text, from_language_code, to_language_code);
       send_request(td_api::make_object<td_api::translateText>(text, from_language_code, to_language_code));
+    } else if (op == "rs") {
+      ChatId chat_id;
+      MessageId message_id;
+      get_args(args, chat_id, message_id);
+      send_request(td_api::make_object<td_api::recognizeSpeech>(chat_id, message_id));
     } else if (op == "gf" || op == "GetFile") {
       FileId file_id;
       get_args(args, file_id);