Move RateTranscribedAudioQuery to TranscriptionInfo.cpp.

This commit is contained in:
levlam 2022-10-19 21:22:57 +03:00
parent 5fe3a7ca94
commit bd78d57e56
3 changed files with 59 additions and 40 deletions

View File

@ -6,8 +6,51 @@
//
#include "td/telegram/TranscriptionInfo.h"
#include "td/telegram/DialogId.h"
#include "td/telegram/MessagesManager.h"
#include "td/telegram/Td.h"
#include "td/utils/buffer.h"
#include "td/utils/logging.h"
namespace td {
class RateTranscribedAudioQuery final : public Td::ResultHandler {
Promise<Unit> promise_;
DialogId dialog_id_;
public:
explicit RateTranscribedAudioQuery(Promise<Unit> &&promise) : promise_(std::move(promise)) {
}
void send(FullMessageId full_message_id, int64 transcription_id, bool is_good) {
dialog_id_ = full_message_id.get_dialog_id();
auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read);
if (input_peer == nullptr) {
return on_error(Status::Error(400, "Can't access the chat"));
}
send_query(G()->net_query_creator().create(telegram_api::messages_rateTranscribedAudio(
std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get(), transcription_id,
is_good)));
}
void on_result(BufferSlice packet) final {
auto result_ptr = fetch_result<telegram_api::messages_rateTranscribedAudio>(packet);
if (result_ptr.is_error()) {
return on_error(result_ptr.move_as_error());
}
bool result = result_ptr.ok();
LOG(INFO) << "Receive result for RateTranscribedAudioQuery: " << result;
promise_.set_value(Unit());
}
void on_error(Status status) final {
td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "RateTranscribedAudioQuery");
promise_.set_error(std::move(status));
}
};
bool TranscriptionInfo::start_recognize_speech(Promise<Unit> &&promise) {
if (is_transcribed_) {
promise.set_value(Unit());
@ -61,6 +104,15 @@ vector<Promise<Unit>> TranscriptionInfo::on_failed_transcription(Status &&error)
return promises;
}
void TranscriptionInfo::rate_speech_recognition(Td *td, FullMessageId full_message_id, bool is_good,
Promise<Unit> &&promise) const {
if (!is_transcribed_) {
return promise.set_value(Unit());
}
CHECK(transcription_id_ != 0);
td->create_handler<RateTranscribedAudioQuery>(std::move(promise))->send(full_message_id, transcription_id_, is_good);
}
unique_ptr<TranscriptionInfo> TranscriptionInfo::copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info) {
if (info == nullptr || !info->is_transcribed_) {
return nullptr;

View File

@ -6,6 +6,7 @@
//
#pragma once
#include "td/telegram/FullMessageId.h"
#include "td/telegram/td_api.h"
#include "td/utils/common.h"
@ -14,6 +15,8 @@
namespace td {
class Td;
class TranscriptionInfo {
bool is_transcribed_ = false;
int64 transcription_id_ = 0;
@ -40,6 +43,8 @@ class TranscriptionInfo {
vector<Promise<Unit>> on_failed_transcription(Status &&error);
void rate_speech_recognition(Td *td, FullMessageId full_message_id, bool is_good, Promise<Unit> &&promise) const;
static unique_ptr<TranscriptionInfo> copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info);
static bool update_from(unique_ptr<TranscriptionInfo> &old_info, unique_ptr<TranscriptionInfo> &&new_info);

View File

@ -59,42 +59,6 @@ class TranscribeAudioQuery final : public Td::ResultHandler {
}
};
class RateTranscribedAudioQuery final : public Td::ResultHandler {
Promise<Unit> promise_;
DialogId dialog_id_;
public:
explicit RateTranscribedAudioQuery(Promise<Unit> &&promise) : promise_(std::move(promise)) {
}
void send(FullMessageId full_message_id, int64 transcription_id, bool is_good) {
dialog_id_ = full_message_id.get_dialog_id();
auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read);
if (input_peer == nullptr) {
return on_error(Status::Error(400, "Can't access the chat"));
}
send_query(G()->net_query_creator().create(telegram_api::messages_rateTranscribedAudio(
std::move(input_peer), full_message_id.get_message_id().get_server_message_id().get(), transcription_id,
is_good)));
}
void on_result(BufferSlice packet) final {
auto result_ptr = fetch_result<telegram_api::messages_rateTranscribedAudio>(packet);
if (result_ptr.is_error()) {
return on_error(result_ptr.move_as_error());
}
bool result = result_ptr.ok();
LOG(INFO) << "Receive result for RateTranscribedAudioQuery: " << result;
promise_.set_value(Unit());
}
void on_error(Status status) final {
td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "RateTranscribedAudioQuery");
promise_.set_error(std::move(status));
}
};
VoiceNotesManager::VoiceNotesManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) {
voice_note_transcription_timeout_.set_callback(on_voice_note_transcription_timeout_callback);
voice_note_transcription_timeout_.set_callback_data(static_cast<void *>(this));
@ -349,12 +313,10 @@ void VoiceNotesManager::rate_speech_recognition(FullMessageId full_message_id, b
auto file_id = it->second;
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (voice_note->transcription_info == nullptr || !voice_note->transcription_info->is_transcribed()) {
if (voice_note->transcription_info == nullptr) {
return promise.set_value(Unit());
}
auto transcription_id = voice_note->transcription_info->get_transcription_id();
CHECK(transcription_id != 0);
td_->create_handler<RateTranscribedAudioQuery>(std::move(promise))->send(full_message_id, transcription_id, is_good);
voice_note->transcription_info->rate_speech_recognition(td_, full_message_id, is_good, std::move(promise));
}
SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,