Add class TranscriptionInfo.

This commit is contained in:
levlam 2022-10-19 20:43:30 +03:00
parent 02ed5e4f1e
commit 5fe3a7ca94
7 changed files with 228 additions and 79 deletions

View File

@ -454,6 +454,7 @@ set(TDLIB_SOURCE
td/telegram/ThemeManager.cpp
td/telegram/TopDialogCategory.cpp
td/telegram/TopDialogManager.cpp
td/telegram/TranscriptionInfo.cpp
td/telegram/UpdatesManager.cpp
td/telegram/Usernames.cpp
td/telegram/Venue.cpp
@ -718,6 +719,7 @@ set(TDLIB_SOURCE
td/telegram/ThemeManager.h
td/telegram/TopDialogCategory.h
td/telegram/TopDialogManager.h
td/telegram/TranscriptionInfo.h
td/telegram/UniqueId.h
td/telegram/UpdatesManager.h
td/telegram/UserId.h
@ -769,6 +771,7 @@ set(TDLIB_SOURCE
td/telegram/SendCodeHelper.hpp
td/telegram/StickerSetId.hpp
td/telegram/StickersManager.hpp
td/telegram/TranscriptionInfo.hpp
td/telegram/VideoNotesManager.hpp
td/telegram/VideosManager.hpp
td/telegram/VoiceNotesManager.hpp

View File

@ -0,0 +1,108 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/telegram/TranscriptionInfo.h"
namespace td {
bool TranscriptionInfo::start_recognize_speech(Promise<Unit> &&promise) {
if (is_transcribed_) {
promise.set_value(Unit());
return false;
}
speech_recognition_queries_.push_back(std::move(promise));
if (speech_recognition_queries_.size() == 1) {
last_transcription_error_ = Status::OK();
return true;
}
return false;
}
vector<Promise<Unit>> TranscriptionInfo::on_final_transcription(string &&text, int64 transcription_id) {
CHECK(!is_transcribed_);
CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id);
CHECK(transcription_id != 0);
transcription_id_ = transcription_id;
is_transcribed_ = true;
text_ = std::move(text);
last_transcription_error_ = Status::OK();
CHECK(!speech_recognition_queries_.empty());
auto promises = std::move(speech_recognition_queries_);
speech_recognition_queries_.clear();
return std::move(promises);
}
bool TranscriptionInfo::on_partial_transcription(string &&text, int64 transcription_id) {
CHECK(!is_transcribed_);
CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id);
CHECK(transcription_id != 0);
bool is_changed = text_ != text;
transcription_id_ = transcription_id;
text_ = std::move(text);
last_transcription_error_ = Status::OK();
return is_changed;
}
vector<Promise<Unit>> TranscriptionInfo::on_failed_transcription(Status &&error) {
CHECK(!is_transcribed_);
transcription_id_ = 0;
text_.clear();
last_transcription_error_ = std::move(error);
CHECK(!speech_recognition_queries_.empty());
auto promises = std::move(speech_recognition_queries_);
speech_recognition_queries_.clear();
return promises;
}
unique_ptr<TranscriptionInfo> TranscriptionInfo::copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info) {
if (info == nullptr || !info->is_transcribed_) {
return nullptr;
}
auto result = make_unique<TranscriptionInfo>();
result->is_transcribed_ = true;
result->transcription_id_ = info->transcription_id_;
result->text_ = info->text_;
return result;
}
bool TranscriptionInfo::update_from(unique_ptr<TranscriptionInfo> &old_info, unique_ptr<TranscriptionInfo> &&new_info) {
if (new_info == nullptr || !new_info->is_transcribed_) {
return false;
}
CHECK(new_info->transcription_id_ != 0);
CHECK(new_info->last_transcription_error_.is_ok());
CHECK(new_info->speech_recognition_queries_.empty());
if (old_info == nullptr) {
old_info = std::move(new_info);
return true;
}
if (old_info->transcription_id_ != 0 || !old_info->speech_recognition_queries_.empty()) {
return false;
}
CHECK(!old_info->is_transcribed_);
old_info = std::move(new_info);
return true;
}
td_api::object_ptr<td_api::SpeechRecognitionResult> TranscriptionInfo::get_speech_recognition_result_object() const {
if (is_transcribed_) {
return td_api::make_object<td_api::speechRecognitionResultText>(text_);
}
if (!speech_recognition_queries_.empty()) {
return td_api::make_object<td_api::speechRecognitionResultPending>(text_);
}
if (last_transcription_error_.is_error()) {
return td_api::make_object<td_api::speechRecognitionResultError>(td_api::make_object<td_api::error>(
last_transcription_error_.code(), last_transcription_error_.message().str()));
}
return nullptr;
}
} // namespace td

View File

@ -0,0 +1,56 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/telegram/td_api.h"
#include "td/utils/common.h"
#include "td/utils/Promise.h"
#include "td/utils/Status.h"
namespace td {
class TranscriptionInfo {
bool is_transcribed_ = false;
int64 transcription_id_ = 0;
string text_;
// temporary state
Status last_transcription_error_;
vector<Promise<Unit>> speech_recognition_queries_;
public:
bool is_transcribed() const {
return is_transcribed_;
}
int64 get_transcription_id() const {
return transcription_id_;
}
bool start_recognize_speech(Promise<Unit> &&promise);
vector<Promise<Unit>> on_final_transcription(string &&text, int64 transcription_id);
bool on_partial_transcription(string &&text, int64 transcription_id);
vector<Promise<Unit>> on_failed_transcription(Status &&error);
static unique_ptr<TranscriptionInfo> copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info);
static bool update_from(unique_ptr<TranscriptionInfo> &old_info, unique_ptr<TranscriptionInfo> &&new_info);
td_api::object_ptr<td_api::SpeechRecognitionResult> get_speech_recognition_result_object() const;
template <class StorerT>
void store(StorerT &storer) const;
template <class ParserT>
void parse(ParserT &parser);
};
} // namespace td

View File

@ -0,0 +1,31 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/telegram/TranscriptionInfo.h"
#include "td/utils/common.h"
#include "td/utils/tl_helpers.h"
namespace td {
template <class StorerT>
void TranscriptionInfo::store(StorerT &storer) const {
CHECK(is_transcribed());
td::store(transcription_id_, storer);
td::store(text_, storer);
}
template <class ParserT>
void TranscriptionInfo::parse(ParserT &parser) {
is_transcribed_ = true;
td::parse(transcription_id_, parser);
td::parse(text_, parser);
CHECK(transcription_id_ != 0);
}
} // namespace td

View File

@ -136,22 +136,9 @@ tl_object_ptr<td_api::voiceNote> VoiceNotesManager::get_voice_note_object(FileId
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
auto speech_recognition_result = [this, voice_note]() -> td_api::object_ptr<td_api::SpeechRecognitionResult> {
if (voice_note->is_transcribed) {
return td_api::make_object<td_api::speechRecognitionResultText>(voice_note->text);
}
if (speech_recognition_queries_.count(voice_note->file_id) != 0) {
return td_api::make_object<td_api::speechRecognitionResultPending>(voice_note->text);
}
if (voice_note->last_transcription_error.is_error()) {
return td_api::make_object<td_api::speechRecognitionResultError>(
td_api::make_object<td_api::error>(voice_note->last_transcription_error.error().code(),
voice_note->last_transcription_error.error().message().str()));
}
return nullptr;
}();
auto speech_recognition_result = voice_note->transcription_info == nullptr
? nullptr
: voice_note->transcription_info->get_speech_recognition_result_object();
return make_tl_object<td_api::voiceNote>(voice_note->duration, voice_note->waveform, voice_note->mime_type,
std::move(speech_recognition_result),
td_->file_manager_->get_file_object(file_id));
@ -175,13 +162,7 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr<VoiceNote> new_voice_note
v->duration = new_voice_note->duration;
v->waveform = std::move(new_voice_note->waveform);
}
if (new_voice_note->is_transcribed && v->transcription_id == 0) {
CHECK(!v->is_transcribed);
CHECK(new_voice_note->transcription_id != 0);
v->is_transcribed = true;
v->transcription_id = new_voice_note->transcription_id;
v->text = std::move(new_voice_note->text);
v->last_transcription_error = Status::OK();
if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_voice_note->transcription_info))) {
on_voice_note_transcription_completed(file_id);
}
}
@ -207,10 +188,7 @@ FileId VoiceNotesManager::dup_voice_note(FileId new_id, FileId old_id) {
new_voice_note->mime_type = old_voice_note->mime_type;
new_voice_note->duration = old_voice_note->duration;
new_voice_note->waveform = old_voice_note->waveform;
if (old_voice_note->is_transcribed) {
new_voice_note->is_transcribed = old_voice_note->is_transcribed;
new_voice_note->text = old_voice_note->text;
}
new_voice_note->transcription_info = TranscriptionInfo::copy_if_transcribed(old_voice_note->transcription_info);
return new_id;
}
@ -278,14 +256,11 @@ void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise<
auto file_id = it->second;
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (voice_note->is_transcribed) {
return promise.set_value(Unit());
if (voice_note->transcription_info == nullptr) {
voice_note->transcription_info = make_unique<TranscriptionInfo>();
}
auto &queries = speech_recognition_queries_[file_id];
queries.push_back(std::move(promise));
if (queries.size() == 1) {
if (voice_note->transcription_info->start_recognize_speech(std::move(promise))) {
td_->create_handler<TranscribeAudioQuery>()->send(file_id, full_message_id);
voice_note->last_transcription_error = Status::OK();
on_voice_note_transcription_updated(file_id);
}
}
@ -294,25 +269,13 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text,
bool is_final) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(!voice_note->is_transcribed);
CHECK(voice_note->transcription_id == 0 || voice_note->transcription_id == transcription_id);
CHECK(transcription_id != 0);
bool is_changed = voice_note->is_transcribed != is_final || voice_note->text != text;
voice_note->transcription_id = transcription_id;
voice_note->is_transcribed = is_final;
voice_note->text = std::move(text);
voice_note->last_transcription_error = Status::OK();
CHECK(voice_note->transcription_info != nullptr);
if (is_final) {
auto it = speech_recognition_queries_.find(file_id);
CHECK(it != speech_recognition_queries_.end());
CHECK(!it->second.empty());
auto promises = std::move(it->second);
speech_recognition_queries_.erase(it);
auto promises = voice_note->transcription_info->on_final_transcription(std::move(text), transcription_id);
on_voice_note_transcription_completed(file_id);
set_promises(promises);
} else {
auto is_changed = voice_note->transcription_info->on_partial_transcription(std::move(text), transcription_id);
if (is_changed) {
on_voice_note_transcription_updated(file_id);
}
@ -330,19 +293,9 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text,
void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) {
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
CHECK(!voice_note->is_transcribed);
CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_id) == 0);
voice_note->transcription_id = 0;
voice_note->text.clear();
voice_note->last_transcription_error = error.clone();
auto it = speech_recognition_queries_.find(file_id);
CHECK(it != speech_recognition_queries_.end());
CHECK(!it->second.empty());
auto promises = std::move(it->second);
speech_recognition_queries_.erase(it);
CHECK(voice_note->transcription_info != nullptr);
CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_info->get_transcription_id()) == 0);
auto promises = voice_note->transcription_info->on_failed_transcription(error.clone());
on_voice_note_transcription_updated(file_id);
fail_promises(promises, std::move(error));
}
@ -396,12 +349,12 @@ void VoiceNotesManager::rate_speech_recognition(FullMessageId full_message_id, b
auto file_id = it->second;
auto voice_note = get_voice_note(file_id);
CHECK(voice_note != nullptr);
if (!voice_note->is_transcribed) {
if (voice_note->transcription_info == nullptr || !voice_note->transcription_info->is_transcribed()) {
return promise.set_value(Unit());
}
CHECK(voice_note->transcription_id != 0);
td_->create_handler<RateTranscribedAudioQuery>(std::move(promise))
->send(full_message_id, voice_note->transcription_id, is_good);
auto transcription_id = voice_note->transcription_info->get_transcription_id();
CHECK(transcription_id != 0);
td_->create_handler<RateTranscribedAudioQuery>(std::move(promise))->send(full_message_id, transcription_id, is_good);
}
SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,

View File

@ -11,6 +11,7 @@
#include "td/telegram/SecretInputMedia.h"
#include "td/telegram/td_api.h"
#include "td/telegram/telegram_api.h"
#include "td/telegram/TranscriptionInfo.h"
#include "td/actor/actor.h"
#include "td/actor/MultiTimeout.h"
@ -79,11 +80,8 @@ class VoiceNotesManager final : public Actor {
public:
string mime_type;
int32 duration = 0;
bool is_transcribed = false;
string waveform;
int64 transcription_id = 0;
string text;
Status last_transcription_error;
unique_ptr<TranscriptionInfo> transcription_info;
FileId file_id;
};
@ -109,7 +107,6 @@ class VoiceNotesManager final : public Actor {
WaitFreeHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
FlatHashMap<FileId, vector<Promise<Unit>>, FileIdHash> speech_recognition_queries_;
FlatHashMap<int64, FileId> pending_voice_note_transcription_queries_;
MultiTimeout voice_note_transcription_timeout_{"VoiceNoteTranscriptionTimeout"};

View File

@ -9,6 +9,7 @@
#include "td/telegram/VoiceNotesManager.h"
#include "td/telegram/files/FileId.hpp"
#include "td/telegram/TranscriptionInfo.hpp"
#include "td/telegram/Version.h"
#include "td/utils/common.h"
@ -23,11 +24,12 @@ void VoiceNotesManager::store_voice_note(FileId file_id, StorerT &storer) const
bool has_mime_type = !voice_note->mime_type.empty();
bool has_duration = voice_note->duration != 0;
bool has_waveform = !voice_note->waveform.empty();
bool is_transcribed = voice_note->transcription_info != nullptr && voice_note->transcription_info->is_transcribed();
BEGIN_STORE_FLAGS();
STORE_FLAG(has_mime_type);
STORE_FLAG(has_duration);
STORE_FLAG(has_waveform);
STORE_FLAG(voice_note->is_transcribed);
STORE_FLAG(is_transcribed);
END_STORE_FLAGS();
if (has_mime_type) {
store(voice_note->mime_type, storer);
@ -38,9 +40,8 @@ void VoiceNotesManager::store_voice_note(FileId file_id, StorerT &storer) const
if (has_waveform) {
store(voice_note->waveform, storer);
}
if (voice_note->is_transcribed) {
store(voice_note->transcription_id, storer);
store(voice_note->text, storer);
if (is_transcribed) {
store(voice_note->transcription_info, storer);
}
store(file_id, storer);
}
@ -51,18 +52,19 @@ FileId VoiceNotesManager::parse_voice_note(ParserT &parser) {
bool has_mime_type;
bool has_duration;
bool has_waveform;
bool is_transcribed;
if (parser.version() >= static_cast<int32>(Version::AddVoiceNoteFlags)) {
BEGIN_PARSE_FLAGS();
PARSE_FLAG(has_mime_type);
PARSE_FLAG(has_duration);
PARSE_FLAG(has_waveform);
PARSE_FLAG(voice_note->is_transcribed);
PARSE_FLAG(is_transcribed);
END_PARSE_FLAGS();
} else {
has_mime_type = true;
has_duration = true;
has_waveform = true;
voice_note->is_transcribed = false;
is_transcribed = false;
}
if (has_mime_type) {
parse(voice_note->mime_type, parser);
@ -73,9 +75,8 @@ FileId VoiceNotesManager::parse_voice_note(ParserT &parser) {
if (has_waveform) {
parse(voice_note->waveform, parser);
}
if (voice_note->is_transcribed) {
parse(voice_note->transcription_id, parser);
parse(voice_note->text, parser);
if (is_transcribed) {
parse(voice_note->transcription_info, parser);
}
parse(voice_note->file_id, parser);
if (parser.get_error() != nullptr || !voice_note->file_id.is_valid()) {