Add td_api::recognizeSpeech

This commit is contained in:
levlam 2022-05-25 21:18:40 +03:00
parent de4d3e7620
commit 151654eeea
6 changed files with 168 additions and 3 deletions

View File

@ -4818,6 +4818,11 @@ getMessageLinkInfo url:string = MessageLinkInfo;
//@to_language_code A two-letter ISO 639-1 language code of the language to which the message is translated //@to_language_code A two-letter ISO 639-1 language code of the language to which the message is translated
translateText text:string from_language_code:string to_language_code:string = Text; translateText text:string from_language_code:string to_language_code:string = Text;
//@description Recognizes speech in a voice note message. The message must be successfully sent and must not be scheduled
//@chat_id Identifier of the chat to which the message belongs
//@message_id Identifier of the message
recognizeSpeech chat_id:int53 message_id:int53 = Ok;
//@description Returns list of message sender identifiers, which can be used to send messages in a chat @chat_id Chat identifier //@description Returns list of message sender identifiers, which can be used to send messages in a chat @chat_id Chat identifier
getChatAvailableMessageSenders chat_id:int53 = MessageSenders; getChatAvailableMessageSenders chat_id:int53 = MessageSenders;

View File

@ -4797,6 +4797,13 @@ void Td::on_request(uint64 id, td_api::translateText &request) {
std::move(promise)); std::move(promise));
} }
void Td::on_request(uint64 id, const td_api::recognizeSpeech &request) {
CHECK_IS_USER();
CREATE_OK_REQUEST_PROMISE();
voice_notes_manager_->recognize_speech({DialogId(request.chat_id_), MessageId(request.message_id_)},
std::move(promise));
}
void Td::on_request(uint64 id, const td_api::getFile &request) { void Td::on_request(uint64 id, const td_api::getFile &request) {
send_closure(actor_id(this), &Td::send_result, id, file_manager_->get_file_object(FileId(request.file_id_, 0))); send_closure(actor_id(this), &Td::send_result, id, file_manager_->get_file_object(FileId(request.file_id_, 0)));
} }

View File

@ -532,6 +532,8 @@ class Td final : public Actor {
void on_request(uint64 id, td_api::translateText &request); void on_request(uint64 id, td_api::translateText &request);
void on_request(uint64 id, const td_api::recognizeSpeech &request);
void on_request(uint64 id, const td_api::getFile &request); void on_request(uint64 id, const td_api::getFile &request);
void on_request(uint64 id, td_api::getRemoteFile &request); void on_request(uint64 id, td_api::getRemoteFile &request);

View File

@ -8,6 +8,8 @@
#include "td/telegram/Dimensions.h" #include "td/telegram/Dimensions.h"
#include "td/telegram/files/FileManager.h" #include "td/telegram/files/FileManager.h"
#include "td/telegram/Global.h"
#include "td/telegram/MessagesManager.h"
#include "td/telegram/secret_api.h" #include "td/telegram/secret_api.h"
#include "td/telegram/Td.h" #include "td/telegram/Td.h"
#include "td/telegram/td_api.h" #include "td/telegram/td_api.h"
@ -16,10 +18,46 @@
#include "td/utils/buffer.h" #include "td/utils/buffer.h"
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/misc.h" #include "td/utils/misc.h"
#include "td/utils/Status.h"
namespace td { namespace td {
class TranscribeAudioQuery final : public Td::ResultHandler {
DialogId dialog_id_;
FileId file_id_;
public:
void send(FileId file_id, FullMessageId full_message_id) {
dialog_id_ = full_message_id.get_dialog_id();
file_id_ = file_id;
auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read);
if (input_peer == nullptr) {
return on_error(Status::Error(400, "Can't access the chat"));
}
send_query(G()->net_query_creator().create(telegram_api::messages_transcribeAudio(
std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get())));
}
void on_result(BufferSlice packet) final {
auto result_ptr = fetch_result<telegram_api::messages_transcribeAudio>(packet);
if (result_ptr.is_error()) {
return on_error(result_ptr.move_as_error());
}
auto result = result_ptr.move_as_ok();
LOG(INFO) << "Receive result for TranscribeAudioQuery: " << to_string(result);
if (result->transcription_id_ == 0) {
return on_error(Status::Error(500, "Receive no recognition identifier"));
}
td_->voice_notes_manager_->on_voice_note_transcribed(file_id_, std::move(result->text_), result->transcription_id_,
!result->pending_);
}
void on_error(Status status) final {
td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "TranscribeAudioQuery");
td_->voice_notes_manager_->on_voice_note_transcription_failed(file_id_, std::move(status));
}
};
VoiceNotesManager::VoiceNotesManager(Td *td) : td_(td) { VoiceNotesManager::VoiceNotesManager(Td *td) : td_(td) {
} }
@ -63,11 +101,29 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr<VoiceNote> new_voice_note
v->duration = new_voice_note->duration; v->duration = new_voice_note->duration;
v->waveform = new_voice_note->waveform; v->waveform = new_voice_note->waveform;
} }
if (new_voice_note->is_transcribed && v->transcription_id == 0) {
CHECK(!v->is_transcribed);
CHECK(new_voice_note->transcription_id != 0);
v->is_transcribed = true;
v->transcription_id = new_voice_note->transcription_id;
v->text = std::move(new_voice_note->text);
on_voice_note_transcription_updated(file_id);
}
} }
return file_id; return file_id;
} }
VoiceNotesManager::VoiceNote *VoiceNotesManager::get_voice_note(FileId file_id) {
auto voice_note = voice_notes_.find(file_id);
if (voice_note == voice_notes_.end()) {
return nullptr;
}
CHECK(voice_note->second->file_id == file_id);
return voice_note->second.get();
}
const VoiceNotesManager::VoiceNote *VoiceNotesManager::get_voice_note(FileId file_id) const { const VoiceNotesManager::VoiceNote *VoiceNotesManager::get_voice_note(FileId file_id) const {
auto voice_note = voice_notes_.find(file_id); auto voice_note = voice_notes_.find(file_id);
if (voice_note == voice_notes_.end()) { if (voice_note == voice_notes_.end()) {
@ -137,8 +193,8 @@ void VoiceNotesManager::register_voice_note(FileId voice_note_file_id, FullMessa
LOG(INFO) << "Register voice note " << voice_note_file_id << " from " << full_message_id << " from " << source; LOG(INFO) << "Register voice note " << voice_note_file_id << " from " << full_message_id << " from " << source;
bool is_inserted = voice_note_messages_[voice_note_file_id].insert(full_message_id).second; bool is_inserted = voice_note_messages_[voice_note_file_id].insert(full_message_id).second;
LOG_CHECK(is_inserted) << source << ' ' << voice_note_file_id << ' ' << full_message_id; LOG_CHECK(is_inserted) << source << ' ' << voice_note_file_id << ' ' << full_message_id;
auto voice_note = get_voice_note(voice_note_file_id); is_inserted = message_voice_notes_.emplace(full_message_id, voice_note_file_id).second;
CHECK(voice_note != nullptr); CHECK(is_inserted);
} }
void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, FullMessageId full_message_id, void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, FullMessageId full_message_id,
@ -153,6 +209,80 @@ void VoiceNotesManager::unregister_voice_note(FileId voice_note_file_id, FullMes
if (message_ids.empty()) { if (message_ids.empty()) {
voice_note_messages_.erase(voice_note_file_id); voice_note_messages_.erase(voice_note_file_id);
} }
is_deleted = message_voice_notes_.erase(full_message_id);
CHECK(is_deleted);
}
void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise<Unit> &&promise) {
if (!td_->messages_manager_->have_message_force(full_message_id, "recognize_speech")) {
return promise.set_error(Status::Error(400, "Message not found"));
}
auto it = message_voice_notes_.find(full_message_id);
if (it == message_voice_notes_.end()) {
return promise.set_error(Status::Error(400, "Invalid message specified"));
}
auto file_id = it->second;
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (voice_note->is_transcribed) {
return promise.set_value(Unit());
}
auto &queries = speech_recognition_queries_[file_id];
queries.push_back(std::move(promise));
if (queries.size() == 1) {
td_->create_handler<TranscribeAudioQuery>()->send(file_id, full_message_id);
}
}
void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id,
bool is_final) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(!voice_note->is_transcribed);
CHECK(voice_note->transcription_id == 0);
voice_note->transcription_id = transcription_id;
voice_note->is_transcribed = is_final;
voice_note->text = std::move(text);
if (!voice_note->text.empty() || is_final) {
on_voice_note_transcription_updated(file_id);
}
if (is_final) {
auto it = speech_recognition_queries_.find(file_id);
CHECK(it != speech_recognition_queries_.end());
CHECK(!it->second.empty());
auto promises = std::move(it->second);
speech_recognition_queries_.erase(it);
set_promises(promises);
}
}
void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(!voice_note->is_transcribed);
CHECK(voice_note->transcription_id == 0);
auto it = speech_recognition_queries_.find(file_id);
CHECK(it != speech_recognition_queries_.end());
CHECK(!it->second.empty());
auto promises = std::move(it->second);
speech_recognition_queries_.erase(it);
fail_promises(promises, std::move(error));
}
void VoiceNotesManager::on_voice_note_transcription_updated(FileId file_id) {
auto it = voice_note_messages_.find(file_id);
if (it != voice_note_messages_.end()) {
for (const auto &full_message_id : it->second) {
td_->messages_manager_->on_external_update_message_content(full_message_id);
}
}
} }
SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id, SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,

View File

@ -13,9 +13,12 @@
#include "td/telegram/telegram_api.h" #include "td/telegram/telegram_api.h"
#include "td/telegram/Version.h" #include "td/telegram/Version.h"
#include "td/actor/PromiseFuture.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/FlatHashMap.h" #include "td/utils/FlatHashMap.h"
#include "td/utils/FlatHashSet.h" #include "td/utils/FlatHashSet.h"
#include "td/utils/Status.h"
namespace td { namespace td {
@ -35,6 +38,12 @@ class VoiceNotesManager {
void unregister_voice_note(FileId voice_note_file_id, FullMessageId full_message_id, const char *source); void unregister_voice_note(FileId voice_note_file_id, FullMessageId full_message_id, const char *source);
void recognize_speech(FullMessageId full_message_id, Promise<Unit> &&promise);
void on_voice_note_transcribed(FileId file_id, string &&text, int64 transcription_id, bool is_final);
void on_voice_note_transcription_failed(FileId file_id, Status &&error);
tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id, tl_object_ptr<telegram_api::InputMedia> get_input_media(FileId file_id,
tl_object_ptr<telegram_api::InputFile> input_file) const; tl_object_ptr<telegram_api::InputFile> input_file) const;
@ -65,14 +74,21 @@ class VoiceNotesManager {
FileId file_id; FileId file_id;
}; };
VoiceNote *get_voice_note(FileId file_id);
const VoiceNote *get_voice_note(FileId file_id) const; const VoiceNote *get_voice_note(FileId file_id) const;
FileId on_get_voice_note(unique_ptr<VoiceNote> new_voice_note, bool replace); FileId on_get_voice_note(unique_ptr<VoiceNote> new_voice_note, bool replace);
void on_voice_note_transcription_updated(FileId file_id);
Td *td_; Td *td_;
FlatHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_; FlatHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
FlatHashMap<FileId, vector<Promise<Unit>>, FileIdHash> speech_recognition_queries_;
FlatHashMap<FileId, FlatHashSet<FullMessageId, FullMessageIdHash>, FileIdHash> voice_note_messages_; FlatHashMap<FileId, FlatHashSet<FullMessageId, FullMessageIdHash>, FileIdHash> voice_note_messages_;
FlatHashMap<FullMessageId, FileId, FullMessageIdHash> message_voice_notes_;
}; };
} // namespace td } // namespace td

View File

@ -2889,6 +2889,11 @@ class CliClient final : public Actor {
string to_language_code; string to_language_code;
get_args(args, text, from_language_code, to_language_code); get_args(args, text, from_language_code, to_language_code);
send_request(td_api::make_object<td_api::translateText>(text, from_language_code, to_language_code)); send_request(td_api::make_object<td_api::translateText>(text, from_language_code, to_language_code));
} else if (op == "rs") {
ChatId chat_id;
MessageId message_id;
get_args(args, chat_id, message_id);
send_request(td_api::make_object<td_api::recognizeSpeech>(chat_id, message_id));
} else if (op == "gf" || op == "GetFile") { } else if (op == "gf" || op == "GetFile") {
FileId file_id; FileId file_id;
get_args(args, file_id); get_args(args, file_id);