// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/telegram/TranscriptionInfo.h" #include "td/telegram/AccessRights.h" #include "td/telegram/DialogId.h" #include "td/telegram/Global.h" #include "td/telegram/MessagesManager.h" #include "td/telegram/Td.h" #include "td/telegram/telegram_api.h" #include "td/utils/buffer.h" #include "td/utils/logging.h" namespace td { class TranscribeAudioQuery final : public Td::ResultHandler { DialogId dialog_id_; std::function>)> handler_; public: void send(FullMessageId full_message_id, std::function>)> &&handler) { dialog_id_ = full_message_id.get_dialog_id(); handler_ = std::move(handler); auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read); if (input_peer == nullptr) { return on_error(Status::Error(400, "Can't access the chat")); } send_query(G()->net_query_creator().create(telegram_api::messages_transcribeAudio( std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get()))); } void on_result(BufferSlice packet) final { auto result_ptr = fetch_result(packet); if (result_ptr.is_error()) { return on_error(result_ptr.move_as_error()); } auto result = result_ptr.move_as_ok(); LOG(INFO) << "Receive result for TranscribeAudioQuery: " << to_string(result); if (result->transcription_id_ == 0) { return on_error(Status::Error(500, "Receive no recognition identifier")); } auto update = telegram_api::make_object(); update->text_ = std::move(result->text_); update->transcription_id_ = result->transcription_id_; update->pending_ = result->pending_; handler_(std::move(update)); } void on_error(Status status) final { td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "TranscribeAudioQuery"); handler_(std::move(status)); } }; class RateTranscribedAudioQuery final : public Td::ResultHandler { Promise promise_; DialogId dialog_id_; public: explicit RateTranscribedAudioQuery(Promise &&promise) : promise_(std::move(promise)) { } void send(FullMessageId full_message_id, int64 transcription_id, bool is_good) { dialog_id_ = full_message_id.get_dialog_id(); auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read); if (input_peer == nullptr) { return on_error(Status::Error(400, "Can't access the chat")); } send_query(G()->net_query_creator().create(telegram_api::messages_rateTranscribedAudio( std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get(), transcription_id, is_good))); } void on_result(BufferSlice packet) final { auto result_ptr = fetch_result(packet); if (result_ptr.is_error()) { return on_error(result_ptr.move_as_error()); } bool result = result_ptr.ok(); LOG(INFO) << "Receive result for RateTranscribedAudioQuery: " << result; promise_.set_value(Unit()); } void on_error(Status status) final { td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "RateTranscribedAudioQuery"); promise_.set_error(std::move(status)); } }; bool TranscriptionInfo::recognize_speech( Td *td, FullMessageId full_message_id, Promise &&promise, std::function>)> &&handler) { if (is_transcribed_) { promise.set_value(Unit()); return false; } speech_recognition_queries_.push_back(std::move(promise)); if (speech_recognition_queries_.size() == 1) { last_transcription_error_ = Status::OK(); td->create_handler()->send(full_message_id, std::move(handler)); return true; } return false; } vector> TranscriptionInfo::on_final_transcription(string &&text, int64 transcription_id) { CHECK(!is_transcribed_); CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id); CHECK(transcription_id != 0); transcription_id_ = transcription_id; is_transcribed_ = true; text_ = std::move(text); last_transcription_error_ = Status::OK(); CHECK(!speech_recognition_queries_.empty()); auto promises = std::move(speech_recognition_queries_); speech_recognition_queries_.clear(); return promises; } bool TranscriptionInfo::on_partial_transcription(string &&text, int64 transcription_id) { CHECK(!is_transcribed_); CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id); CHECK(transcription_id != 0); bool is_changed = text_ != text; transcription_id_ = transcription_id; text_ = std::move(text); last_transcription_error_ = Status::OK(); return is_changed; } vector> TranscriptionInfo::on_failed_transcription(Status &&error) { CHECK(!is_transcribed_); transcription_id_ = 0; text_.clear(); last_transcription_error_ = std::move(error); CHECK(!speech_recognition_queries_.empty()); auto promises = std::move(speech_recognition_queries_); speech_recognition_queries_.clear(); return promises; } void TranscriptionInfo::rate_speech_recognition(Td *td, FullMessageId full_message_id, bool is_good, Promise &&promise) const { if (!is_transcribed_) { return promise.set_value(Unit()); } CHECK(transcription_id_ != 0); td->create_handler(std::move(promise))->send(full_message_id, transcription_id_, is_good); } unique_ptr TranscriptionInfo::copy_if_transcribed(const unique_ptr &info) { if (info == nullptr || !info->is_transcribed_) { return nullptr; } auto result = make_unique(); result->is_transcribed_ = true; result->transcription_id_ = info->transcription_id_; result->text_ = info->text_; return result; } bool TranscriptionInfo::update_from(unique_ptr &old_info, unique_ptr &&new_info) { if (new_info == nullptr || !new_info->is_transcribed_) { return false; } CHECK(new_info->transcription_id_ != 0); CHECK(new_info->last_transcription_error_.is_ok()); CHECK(new_info->speech_recognition_queries_.empty()); if (old_info == nullptr) { old_info = std::move(new_info); return true; } if (old_info->transcription_id_ != 0 || !old_info->speech_recognition_queries_.empty()) { return false; } CHECK(!old_info->is_transcribed_); old_info = std::move(new_info); return true; } td_api::object_ptr TranscriptionInfo::get_speech_recognition_result_object() const { if (is_transcribed_) { return td_api::make_object(text_); } if (!speech_recognition_queries_.empty()) { return td_api::make_object(text_); } if (last_transcription_error_.is_error()) { return td_api::make_object(td_api::make_object( last_transcription_error_.code(), last_transcription_error_.message().str())); } return nullptr; } } // namespace td