tdlight/td/telegram/TranscriptionInfo.cpp
2023-01-01 00:28:08 +03:00

208 lines
7.5 KiB
C++

//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/telegram/TranscriptionInfo.h"
#include "td/telegram/AccessRights.h"
#include "td/telegram/DialogId.h"
#include "td/telegram/Global.h"
#include "td/telegram/MessagesManager.h"
#include "td/telegram/Td.h"
#include "td/telegram/telegram_api.h"
#include "td/utils/buffer.h"
#include "td/utils/logging.h"
namespace td {
class TranscribeAudioQuery final : public Td::ResultHandler {
DialogId dialog_id_;
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)> handler_;
public:
void send(FullMessageId full_message_id,
std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)> &&handler) {
dialog_id_ = full_message_id.get_dialog_id();
handler_ = std::move(handler);
auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read);
if (input_peer == nullptr) {
return on_error(Status::Error(400, "Can't access the chat"));
}
send_query(G()->net_query_creator().create(telegram_api::messages_transcribeAudio(
std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get())));
}
void on_result(BufferSlice packet) final {
auto result_ptr = fetch_result<telegram_api::messages_transcribeAudio>(packet);
if (result_ptr.is_error()) {
return on_error(result_ptr.move_as_error());
}
auto result = result_ptr.move_as_ok();
LOG(INFO) << "Receive result for TranscribeAudioQuery: " << to_string(result);
if (result->transcription_id_ == 0) {
return on_error(Status::Error(500, "Receive no recognition identifier"));
}
auto update = telegram_api::make_object<telegram_api::updateTranscribedAudio>();
update->text_ = std::move(result->text_);
update->transcription_id_ = result->transcription_id_;
update->pending_ = result->pending_;
handler_(std::move(update));
}
void on_error(Status status) final {
td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "TranscribeAudioQuery");
handler_(std::move(status));
}
};
class RateTranscribedAudioQuery final : public Td::ResultHandler {
Promise<Unit> promise_;
DialogId dialog_id_;
public:
explicit RateTranscribedAudioQuery(Promise<Unit> &&promise) : promise_(std::move(promise)) {
}
void send(FullMessageId full_message_id, int64 transcription_id, bool is_good) {
dialog_id_ = full_message_id.get_dialog_id();
auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read);
if (input_peer == nullptr) {
return on_error(Status::Error(400, "Can't access the chat"));
}
send_query(G()->net_query_creator().create(telegram_api::messages_rateTranscribedAudio(
std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get(), transcription_id,
is_good)));
}
void on_result(BufferSlice packet) final {
auto result_ptr = fetch_result<telegram_api::messages_rateTranscribedAudio>(packet);
if (result_ptr.is_error()) {
return on_error(result_ptr.move_as_error());
}
bool result = result_ptr.ok();
LOG(INFO) << "Receive result for RateTranscribedAudioQuery: " << result;
promise_.set_value(Unit());
}
void on_error(Status status) final {
td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "RateTranscribedAudioQuery");
promise_.set_error(std::move(status));
}
};
// Queues promise until speech recognition of the message finishes.
// If the message is already transcribed, the promise is fulfilled immediately.
// Only the first queued promise starts a TranscribeAudioQuery; later ones wait for the same request,
// and all of them are handed back together by on_final_transcription or on_failed_transcription.
// Returns true exactly when a new query was sent.
bool TranscriptionInfo::recognize_speech(
    Td *td, FullMessageId full_message_id, Promise<Unit> &&promise,
    std::function<void(Result<telegram_api::object_ptr<telegram_api::updateTranscribedAudio>>)> &&handler) {
  if (is_transcribed_) {
    promise.set_value(Unit());
    return false;
  }

  speech_recognition_queries_.push_back(std::move(promise));
  bool is_first_query = speech_recognition_queries_.size() == 1;
  if (!is_first_query) {
    // A request is already in flight
    return false;
  }

  // A new attempt discards the error left by a previous failed attempt
  last_transcription_error_ = Status::OK();
  td->create_handler<TranscribeAudioQuery>()->send(full_message_id, std::move(handler));
  return true;
}
// Records the final transcription text and marks the info as transcribed.
// transcription_id must be non-zero and must match any identifier seen in earlier partial results.
// Returns every promise queued by recognize_speech, leaving the queue empty; at least one must be pending.
vector<Promise<Unit>> TranscriptionInfo::on_final_transcription(string &&text, int64 transcription_id) {
  CHECK(!is_transcribed_);
  CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id);
  CHECK(transcription_id != 0);

  is_transcribed_ = true;
  transcription_id_ = transcription_id;
  text_ = std::move(text);
  last_transcription_error_ = Status::OK();

  CHECK(!speech_recognition_queries_.empty());
  vector<Promise<Unit>> finished_queries;
  finished_queries.swap(speech_recognition_queries_);
  return finished_queries;
}
// Stores an intermediate transcription result while recognition is still in progress.
// transcription_id must be non-zero and consistent with any previously stored identifier.
// Returns whether the stored text differs from the newly received one.
bool TranscriptionInfo::on_partial_transcription(string &&text, int64 transcription_id) {
  CHECK(!is_transcribed_);
  CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id);
  CHECK(transcription_id != 0);

  // Compare before moving, because text is consumed below
  const bool is_same_text = text_ == text;

  transcription_id_ = transcription_id;
  text_ = std::move(text);
  last_transcription_error_ = Status::OK();
  return !is_same_text;
}
// Resets any partial result and remembers error for get_speech_recognition_result_object.
// Returns every promise queued by recognize_speech, leaving the queue empty; at least one must be pending.
vector<Promise<Unit>> TranscriptionInfo::on_failed_transcription(Status &&error) {
  CHECK(!is_transcribed_);

  text_.clear();
  transcription_id_ = 0;
  last_transcription_error_ = std::move(error);

  CHECK(!speech_recognition_queries_.empty());
  vector<Promise<Unit>> failed_queries;
  failed_queries.swap(speech_recognition_queries_);
  return failed_queries;
}
// Sends the user's rating (is_good) of the final transcription to the server.
// If nothing has been transcribed yet, there is nothing to rate and promise succeeds immediately.
void TranscriptionInfo::rate_speech_recognition(Td *td, FullMessageId full_message_id, bool is_good,
                                                Promise<Unit> &&promise) const {
  if (is_transcribed_) {
    CHECK(transcription_id_ != 0);
    td->create_handler<RateTranscribedAudioQuery>(std::move(promise))->send(full_message_id, transcription_id_, is_good);
    return;
  }
  promise.set_value(Unit());
}
// Returns a new TranscriptionInfo carrying only the final result of info: is_transcribed_,
// transcription_id_ and text_.
// Returns nullptr if info is absent or not transcribed yet.
unique_ptr<TranscriptionInfo> TranscriptionInfo::copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info) {
  const bool has_final_text = info != nullptr && info->is_transcribed_;
  if (!has_final_text) {
    return nullptr;
  }

  auto transcribed_copy = make_unique<TranscriptionInfo>();
  transcribed_copy->text_ = info->text_;
  transcribed_copy->transcription_id_ = info->transcription_id_;
  transcribed_copy->is_transcribed_ = true;
  return transcribed_copy;
}
// Replaces old_info with new_info when new_info holds a final transcription.
// old_info is kept when it already has a transcription identifier or has queued recognition requests.
// Returns true if old_info was replaced.
bool TranscriptionInfo::update_from(unique_ptr<TranscriptionInfo> &old_info, unique_ptr<TranscriptionInfo> &&new_info) {
  if (new_info == nullptr || !new_info->is_transcribed_) {
    return false;
  }
  CHECK(new_info->transcription_id_ != 0);
  CHECK(new_info->last_transcription_error_.is_ok());
  CHECK(new_info->speech_recognition_queries_.empty());

  if (old_info != nullptr) {
    bool is_old_info_in_use =
        old_info->transcription_id_ != 0 || !old_info->speech_recognition_queries_.empty();
    if (is_old_info_in_use) {
      return false;
    }
    // A transcribed info always has a non-zero identifier, so it can't reach this point
    CHECK(!old_info->is_transcribed_);
  }

  old_info = std::move(new_info);
  return true;
}
// Converts the current state to a td_api::SpeechRecognitionResult, in order of precedence:
// the final text, a pending result with the partial text, or the last error.
// Returns nullptr if none of these applies.
td_api::object_ptr<td_api::SpeechRecognitionResult> TranscriptionInfo::get_speech_recognition_result_object() const {
  if (!is_transcribed_) {
    if (speech_recognition_queries_.empty()) {
      if (!last_transcription_error_.is_error()) {
        return nullptr;
      }
      auto error = td_api::make_object<td_api::error>(last_transcription_error_.code(),
                                                       last_transcription_error_.message().str());
      return td_api::make_object<td_api::speechRecognitionResultError>(std::move(error));
    }
    // text_ holds whatever partial text has been received so far (possibly empty)
    return td_api::make_object<td_api::speechRecognitionResultPending>(text_);
  }
  return td_api::make_object<td_api::speechRecognitionResultText>(text_);
}
} // namespace td