Add class TranscriptionInfo.
This commit is contained in:
parent
02ed5e4f1e
commit
5fe3a7ca94
@ -454,6 +454,7 @@ set(TDLIB_SOURCE
|
||||
td/telegram/ThemeManager.cpp
|
||||
td/telegram/TopDialogCategory.cpp
|
||||
td/telegram/TopDialogManager.cpp
|
||||
td/telegram/TranscriptionInfo.cpp
|
||||
td/telegram/UpdatesManager.cpp
|
||||
td/telegram/Usernames.cpp
|
||||
td/telegram/Venue.cpp
|
||||
@ -718,6 +719,7 @@ set(TDLIB_SOURCE
|
||||
td/telegram/ThemeManager.h
|
||||
td/telegram/TopDialogCategory.h
|
||||
td/telegram/TopDialogManager.h
|
||||
td/telegram/TranscriptionInfo.h
|
||||
td/telegram/UniqueId.h
|
||||
td/telegram/UpdatesManager.h
|
||||
td/telegram/UserId.h
|
||||
@ -769,6 +771,7 @@ set(TDLIB_SOURCE
|
||||
td/telegram/SendCodeHelper.hpp
|
||||
td/telegram/StickerSetId.hpp
|
||||
td/telegram/StickersManager.hpp
|
||||
td/telegram/TranscriptionInfo.hpp
|
||||
td/telegram/VideoNotesManager.hpp
|
||||
td/telegram/VideosManager.hpp
|
||||
td/telegram/VoiceNotesManager.hpp
|
||||
|
108
td/telegram/TranscriptionInfo.cpp
Normal file
108
td/telegram/TranscriptionInfo.cpp
Normal file
@ -0,0 +1,108 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#include "td/telegram/TranscriptionInfo.h"
|
||||
|
||||
namespace td {
|
||||
|
||||
// Registers a speech recognition request.
// If the final transcription is already known, the promise is fulfilled immediately.
// Otherwise the promise is queued until the transcription finishes or fails.
// Returns true only for the first queued request, after resetting the last error;
// returns false in every other case.
bool TranscriptionInfo::start_recognize_speech(Promise<Unit> &&promise) {
  if (is_transcribed_) {
    // nothing to wait for
    promise.set_value(Unit());
    return false;
  }

  const bool is_first_query = speech_recognition_queries_.empty();
  speech_recognition_queries_.push_back(std::move(promise));
  if (!is_first_query) {
    // an earlier request is already queued
    return false;
  }

  last_transcription_error_ = Status::OK();
  return true;
}
|
||||
|
||||
// Stores the final transcription text and marks the info as transcribed.
// Preconditions (enforced with CHECK): the info isn't transcribed yet, transcription_id is non-zero
// and matches any previously stored identifier, and at least one promise is pending.
// Returns the pending promises without completing them; the internal queue is left empty.
vector<Promise<Unit>> TranscriptionInfo::on_final_transcription(string &&text, int64 transcription_id) {
  CHECK(!is_transcribed_);
  CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id);
  CHECK(transcription_id != 0);
  transcription_id_ = transcription_id;
  is_transcribed_ = true;
  text_ = std::move(text);
  last_transcription_error_ = Status::OK();

  CHECK(!speech_recognition_queries_.empty());
  auto promises = std::move(speech_recognition_queries_);
  // bring the moved-from queue back to a well-defined empty state
  speech_recognition_queries_.clear();

  // return the local by name: wrapping it in std::move would only disable copy elision
  return promises;
}
|
||||
|
||||
// Stores an intermediate transcription text while recognition is still in progress.
// Preconditions (enforced with CHECK): the info isn't transcribed yet and transcription_id
// is non-zero and matches any previously stored identifier.
// Returns true if the stored text differs from the previous one.
bool TranscriptionInfo::on_partial_transcription(string &&text, int64 transcription_id) {
  CHECK(!is_transcribed_);
  CHECK(transcription_id_ == 0 || transcription_id_ == transcription_id);
  CHECK(transcription_id != 0);

  // compare before the new text replaces the old one
  const bool has_new_text = text_ != text;

  last_transcription_error_ = Status::OK();
  text_ = std::move(text);
  transcription_id_ = transcription_id;

  return has_new_text;
}
|
||||
|
||||
// Records a failed recognition attempt: drops the transcription identifier and any partial text
// and remembers the error.
// Preconditions (enforced with CHECK): the info isn't transcribed and at least one promise is pending.
// Returns the pending promises without completing them; the internal queue is left empty.
vector<Promise<Unit>> TranscriptionInfo::on_failed_transcription(Status &&error) {
  CHECK(!is_transcribed_);
  transcription_id_ = 0;
  text_.clear();
  last_transcription_error_ = std::move(error);

  CHECK(!speech_recognition_queries_.empty());
  // exchange the queue with an empty vector, so that the member ends up empty
  vector<Promise<Unit>> pending_promises;
  pending_promises.swap(speech_recognition_queries_);
  return pending_promises;
}
|
||||
|
||||
// Creates a copy of the given info, but only if it holds a final transcription.
// Only the transcription identifier and the text are copied; the error and the pending promises
// of the copy stay default-initialized.
// Returns nullptr if info is null or isn't transcribed.
unique_ptr<TranscriptionInfo> TranscriptionInfo::copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info) {
  const bool has_transcription = info != nullptr && info->is_transcribed_;
  if (!has_transcription) {
    return nullptr;
  }

  auto copy = make_unique<TranscriptionInfo>();
  copy->text_ = info->text_;
  copy->transcription_id_ = info->transcription_id_;
  copy->is_transcribed_ = true;
  return copy;
}
|
||||
|
||||
// Replaces old_info with new_info if new_info holds a final transcription and old_info
// has neither a transcription identifier nor pending recognition requests.
// Returns true if old_info was replaced (new_info is moved from in that case), false otherwise.
bool TranscriptionInfo::update_from(unique_ptr<TranscriptionInfo> &old_info, unique_ptr<TranscriptionInfo> &&new_info) {
  if (new_info == nullptr || !new_info->is_transcribed_) {
    // nothing to take over
    return false;
  }
  CHECK(new_info->transcription_id_ != 0);
  CHECK(new_info->last_transcription_error_.is_ok());
  CHECK(new_info->speech_recognition_queries_.empty());

  if (old_info != nullptr) {
    const bool is_idle = old_info->transcription_id_ == 0 && old_info->speech_recognition_queries_.empty();
    if (!is_idle) {
      // the existing info already has an identifier or is being recognized; keep it
      return false;
    }
    CHECK(!old_info->is_transcribed_);
  }

  old_info = std::move(new_info);
  return true;
}
|
||||
|
||||
// Builds the td_api representation of the current recognition state:
//  - speechRecognitionResultText with the text if the transcription is final;
//  - speechRecognitionResultPending with the current text if recognition requests are queued;
//  - speechRecognitionResultError if the last attempt failed;
//  - nullptr otherwise.
td_api::object_ptr<td_api::SpeechRecognitionResult> TranscriptionInfo::get_speech_recognition_result_object() const {
  if (is_transcribed_) {
    return td_api::make_object<td_api::speechRecognitionResultText>(text_);
  }

  const bool is_pending = !speech_recognition_queries_.empty();
  if (is_pending) {
    return td_api::make_object<td_api::speechRecognitionResultPending>(text_);
  }

  if (!last_transcription_error_.is_error()) {
    // recognition was never requested, or there is nothing to report
    return nullptr;
  }

  auto error_object = td_api::make_object<td_api::error>(last_transcription_error_.code(),
                                                         last_transcription_error_.message().str());
  return td_api::make_object<td_api::speechRecognitionResultError>(std::move(error_object));
}
|
||||
|
||||
} // namespace td
|
56
td/telegram/TranscriptionInfo.h
Normal file
56
td/telegram/TranscriptionInfo.h
Normal file
@ -0,0 +1,56 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "td/telegram/td_api.h"
|
||||
|
||||
#include "td/utils/common.h"
|
||||
#include "td/utils/Promise.h"
|
||||
#include "td/utils/Status.h"
|
||||
|
||||
namespace td {
|
||||
|
||||
class TranscriptionInfo {
|
||||
bool is_transcribed_ = false;
|
||||
int64 transcription_id_ = 0;
|
||||
string text_;
|
||||
|
||||
// temporary state
|
||||
Status last_transcription_error_;
|
||||
vector<Promise<Unit>> speech_recognition_queries_;
|
||||
|
||||
public:
|
||||
bool is_transcribed() const {
|
||||
return is_transcribed_;
|
||||
}
|
||||
|
||||
int64 get_transcription_id() const {
|
||||
return transcription_id_;
|
||||
}
|
||||
|
||||
bool start_recognize_speech(Promise<Unit> &&promise);
|
||||
|
||||
vector<Promise<Unit>> on_final_transcription(string &&text, int64 transcription_id);
|
||||
|
||||
bool on_partial_transcription(string &&text, int64 transcription_id);
|
||||
|
||||
vector<Promise<Unit>> on_failed_transcription(Status &&error);
|
||||
|
||||
static unique_ptr<TranscriptionInfo> copy_if_transcribed(const unique_ptr<TranscriptionInfo> &info);
|
||||
|
||||
static bool update_from(unique_ptr<TranscriptionInfo> &old_info, unique_ptr<TranscriptionInfo> &&new_info);
|
||||
|
||||
td_api::object_ptr<td_api::SpeechRecognitionResult> get_speech_recognition_result_object() const;
|
||||
|
||||
template <class StorerT>
|
||||
void store(StorerT &storer) const;
|
||||
|
||||
template <class ParserT>
|
||||
void parse(ParserT &parser);
|
||||
};
|
||||
|
||||
} // namespace td
|
31
td/telegram/TranscriptionInfo.hpp
Normal file
31
td/telegram/TranscriptionInfo.hpp
Normal file
@ -0,0 +1,31 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "td/telegram/TranscriptionInfo.h"
|
||||
|
||||
#include "td/utils/common.h"
|
||||
#include "td/utils/tl_helpers.h"
|
||||
|
||||
namespace td {
|
||||
|
||||
template <class StorerT>
|
||||
void TranscriptionInfo::store(StorerT &storer) const {
|
||||
CHECK(is_transcribed());
|
||||
td::store(transcription_id_, storer);
|
||||
td::store(text_, storer);
|
||||
}
|
||||
|
||||
template <class ParserT>
|
||||
void TranscriptionInfo::parse(ParserT &parser) {
|
||||
is_transcribed_ = true;
|
||||
td::parse(transcription_id_, parser);
|
||||
td::parse(text_, parser);
|
||||
CHECK(transcription_id_ != 0);
|
||||
}
|
||||
|
||||
} // namespace td
|
@ -136,22 +136,9 @@ tl_object_ptr<td_api::voiceNote> VoiceNotesManager::get_voice_note_object(FileId
|
||||
|
||||
auto voice_note = get_voice_note(file_id);
|
||||
CHECK(voice_note != nullptr);
|
||||
|
||||
auto speech_recognition_result = [this, voice_note]() -> td_api::object_ptr<td_api::SpeechRecognitionResult> {
|
||||
if (voice_note->is_transcribed) {
|
||||
return td_api::make_object<td_api::speechRecognitionResultText>(voice_note->text);
|
||||
}
|
||||
if (speech_recognition_queries_.count(voice_note->file_id) != 0) {
|
||||
return td_api::make_object<td_api::speechRecognitionResultPending>(voice_note->text);
|
||||
}
|
||||
if (voice_note->last_transcription_error.is_error()) {
|
||||
return td_api::make_object<td_api::speechRecognitionResultError>(
|
||||
td_api::make_object<td_api::error>(voice_note->last_transcription_error.error().code(),
|
||||
voice_note->last_transcription_error.error().message().str()));
|
||||
}
|
||||
return nullptr;
|
||||
}();
|
||||
|
||||
auto speech_recognition_result = voice_note->transcription_info == nullptr
|
||||
? nullptr
|
||||
: voice_note->transcription_info->get_speech_recognition_result_object();
|
||||
return make_tl_object<td_api::voiceNote>(voice_note->duration, voice_note->waveform, voice_note->mime_type,
|
||||
std::move(speech_recognition_result),
|
||||
td_->file_manager_->get_file_object(file_id));
|
||||
@ -175,13 +162,7 @@ FileId VoiceNotesManager::on_get_voice_note(unique_ptr<VoiceNote> new_voice_note
|
||||
v->duration = new_voice_note->duration;
|
||||
v->waveform = std::move(new_voice_note->waveform);
|
||||
}
|
||||
if (new_voice_note->is_transcribed && v->transcription_id == 0) {
|
||||
CHECK(!v->is_transcribed);
|
||||
CHECK(new_voice_note->transcription_id != 0);
|
||||
v->is_transcribed = true;
|
||||
v->transcription_id = new_voice_note->transcription_id;
|
||||
v->text = std::move(new_voice_note->text);
|
||||
v->last_transcription_error = Status::OK();
|
||||
if (TranscriptionInfo::update_from(v->transcription_info, std::move(new_voice_note->transcription_info))) {
|
||||
on_voice_note_transcription_completed(file_id);
|
||||
}
|
||||
}
|
||||
@ -207,10 +188,7 @@ FileId VoiceNotesManager::dup_voice_note(FileId new_id, FileId old_id) {
|
||||
new_voice_note->mime_type = old_voice_note->mime_type;
|
||||
new_voice_note->duration = old_voice_note->duration;
|
||||
new_voice_note->waveform = old_voice_note->waveform;
|
||||
if (old_voice_note->is_transcribed) {
|
||||
new_voice_note->is_transcribed = old_voice_note->is_transcribed;
|
||||
new_voice_note->text = old_voice_note->text;
|
||||
}
|
||||
new_voice_note->transcription_info = TranscriptionInfo::copy_if_transcribed(old_voice_note->transcription_info);
|
||||
return new_id;
|
||||
}
|
||||
|
||||
@ -278,14 +256,11 @@ void VoiceNotesManager::recognize_speech(FullMessageId full_message_id, Promise<
|
||||
auto file_id = it->second;
|
||||
auto voice_note = get_voice_note(file_id);
|
||||
CHECK(voice_note != nullptr);
|
||||
if (voice_note->is_transcribed) {
|
||||
return promise.set_value(Unit());
|
||||
if (voice_note->transcription_info == nullptr) {
|
||||
voice_note->transcription_info = make_unique<TranscriptionInfo>();
|
||||
}
|
||||
auto &queries = speech_recognition_queries_[file_id];
|
||||
queries.push_back(std::move(promise));
|
||||
if (queries.size() == 1) {
|
||||
if (voice_note->transcription_info->start_recognize_speech(std::move(promise))) {
|
||||
td_->create_handler<TranscribeAudioQuery>()->send(file_id, full_message_id);
|
||||
voice_note->last_transcription_error = Status::OK();
|
||||
on_voice_note_transcription_updated(file_id);
|
||||
}
|
||||
}
|
||||
@ -294,25 +269,13 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text,
|
||||
bool is_final) {
|
||||
auto voice_note = get_voice_note(file_id);
|
||||
CHECK(voice_note != nullptr);
|
||||
CHECK(!voice_note->is_transcribed);
|
||||
CHECK(voice_note->transcription_id == 0 || voice_note->transcription_id == transcription_id);
|
||||
CHECK(transcription_id != 0);
|
||||
bool is_changed = voice_note->is_transcribed != is_final || voice_note->text != text;
|
||||
voice_note->transcription_id = transcription_id;
|
||||
voice_note->is_transcribed = is_final;
|
||||
voice_note->text = std::move(text);
|
||||
voice_note->last_transcription_error = Status::OK();
|
||||
|
||||
CHECK(voice_note->transcription_info != nullptr);
|
||||
if (is_final) {
|
||||
auto it = speech_recognition_queries_.find(file_id);
|
||||
CHECK(it != speech_recognition_queries_.end());
|
||||
CHECK(!it->second.empty());
|
||||
auto promises = std::move(it->second);
|
||||
speech_recognition_queries_.erase(it);
|
||||
|
||||
auto promises = voice_note->transcription_info->on_final_transcription(std::move(text), transcription_id);
|
||||
on_voice_note_transcription_completed(file_id);
|
||||
set_promises(promises);
|
||||
} else {
|
||||
auto is_changed = voice_note->transcription_info->on_partial_transcription(std::move(text), transcription_id);
|
||||
if (is_changed) {
|
||||
on_voice_note_transcription_updated(file_id);
|
||||
}
|
||||
@ -330,19 +293,9 @@ void VoiceNotesManager::on_voice_note_transcribed(FileId file_id, string &&text,
|
||||
void VoiceNotesManager::on_voice_note_transcription_failed(FileId file_id, Status &&error) {
|
||||
auto voice_note = get_voice_note(file_id);
|
||||
CHECK(voice_note != nullptr);
|
||||
CHECK(!voice_note->is_transcribed);
|
||||
CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_id) == 0);
|
||||
|
||||
voice_note->transcription_id = 0;
|
||||
voice_note->text.clear();
|
||||
voice_note->last_transcription_error = error.clone();
|
||||
|
||||
auto it = speech_recognition_queries_.find(file_id);
|
||||
CHECK(it != speech_recognition_queries_.end());
|
||||
CHECK(!it->second.empty());
|
||||
auto promises = std::move(it->second);
|
||||
speech_recognition_queries_.erase(it);
|
||||
|
||||
CHECK(voice_note->transcription_info != nullptr);
|
||||
CHECK(pending_voice_note_transcription_queries_.count(voice_note->transcription_info->get_transcription_id()) == 0);
|
||||
auto promises = voice_note->transcription_info->on_failed_transcription(error.clone());
|
||||
on_voice_note_transcription_updated(file_id);
|
||||
fail_promises(promises, std::move(error));
|
||||
}
|
||||
@ -396,12 +349,12 @@ void VoiceNotesManager::rate_speech_recognition(FullMessageId full_message_id, b
|
||||
auto file_id = it->second;
|
||||
auto voice_note = get_voice_note(file_id);
|
||||
CHECK(voice_note != nullptr);
|
||||
if (!voice_note->is_transcribed) {
|
||||
if (voice_note->transcription_info == nullptr || !voice_note->transcription_info->is_transcribed()) {
|
||||
return promise.set_value(Unit());
|
||||
}
|
||||
CHECK(voice_note->transcription_id != 0);
|
||||
td_->create_handler<RateTranscribedAudioQuery>(std::move(promise))
|
||||
->send(full_message_id, voice_note->transcription_id, is_good);
|
||||
auto transcription_id = voice_note->transcription_info->get_transcription_id();
|
||||
CHECK(transcription_id != 0);
|
||||
td_->create_handler<RateTranscribedAudioQuery>(std::move(promise))->send(full_message_id, transcription_id, is_good);
|
||||
}
|
||||
|
||||
SecretInputMedia VoiceNotesManager::get_secret_input_media(FileId voice_note_file_id,
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "td/telegram/SecretInputMedia.h"
|
||||
#include "td/telegram/td_api.h"
|
||||
#include "td/telegram/telegram_api.h"
|
||||
#include "td/telegram/TranscriptionInfo.h"
|
||||
|
||||
#include "td/actor/actor.h"
|
||||
#include "td/actor/MultiTimeout.h"
|
||||
@ -79,11 +80,8 @@ class VoiceNotesManager final : public Actor {
|
||||
public:
|
||||
string mime_type;
|
||||
int32 duration = 0;
|
||||
bool is_transcribed = false;
|
||||
string waveform;
|
||||
int64 transcription_id = 0;
|
||||
string text;
|
||||
Status last_transcription_error;
|
||||
unique_ptr<TranscriptionInfo> transcription_info;
|
||||
|
||||
FileId file_id;
|
||||
};
|
||||
@ -109,7 +107,6 @@ class VoiceNotesManager final : public Actor {
|
||||
|
||||
WaitFreeHashMap<FileId, unique_ptr<VoiceNote>, FileIdHash> voice_notes_;
|
||||
|
||||
FlatHashMap<FileId, vector<Promise<Unit>>, FileIdHash> speech_recognition_queries_;
|
||||
FlatHashMap<int64, FileId> pending_voice_note_transcription_queries_;
|
||||
MultiTimeout voice_note_transcription_timeout_{"VoiceNoteTranscriptionTimeout"};
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "td/telegram/VoiceNotesManager.h"
|
||||
|
||||
#include "td/telegram/files/FileId.hpp"
|
||||
#include "td/telegram/TranscriptionInfo.hpp"
|
||||
#include "td/telegram/Version.h"
|
||||
|
||||
#include "td/utils/common.h"
|
||||
@ -23,11 +24,12 @@ void VoiceNotesManager::store_voice_note(FileId file_id, StorerT &storer) const
|
||||
bool has_mime_type = !voice_note->mime_type.empty();
|
||||
bool has_duration = voice_note->duration != 0;
|
||||
bool has_waveform = !voice_note->waveform.empty();
|
||||
bool is_transcribed = voice_note->transcription_info != nullptr && voice_note->transcription_info->is_transcribed();
|
||||
BEGIN_STORE_FLAGS();
|
||||
STORE_FLAG(has_mime_type);
|
||||
STORE_FLAG(has_duration);
|
||||
STORE_FLAG(has_waveform);
|
||||
STORE_FLAG(voice_note->is_transcribed);
|
||||
STORE_FLAG(is_transcribed);
|
||||
END_STORE_FLAGS();
|
||||
if (has_mime_type) {
|
||||
store(voice_note->mime_type, storer);
|
||||
@ -38,9 +40,8 @@ void VoiceNotesManager::store_voice_note(FileId file_id, StorerT &storer) const
|
||||
if (has_waveform) {
|
||||
store(voice_note->waveform, storer);
|
||||
}
|
||||
if (voice_note->is_transcribed) {
|
||||
store(voice_note->transcription_id, storer);
|
||||
store(voice_note->text, storer);
|
||||
if (is_transcribed) {
|
||||
store(voice_note->transcription_info, storer);
|
||||
}
|
||||
store(file_id, storer);
|
||||
}
|
||||
@ -51,18 +52,19 @@ FileId VoiceNotesManager::parse_voice_note(ParserT &parser) {
|
||||
bool has_mime_type;
|
||||
bool has_duration;
|
||||
bool has_waveform;
|
||||
bool is_transcribed;
|
||||
if (parser.version() >= static_cast<int32>(Version::AddVoiceNoteFlags)) {
|
||||
BEGIN_PARSE_FLAGS();
|
||||
PARSE_FLAG(has_mime_type);
|
||||
PARSE_FLAG(has_duration);
|
||||
PARSE_FLAG(has_waveform);
|
||||
PARSE_FLAG(voice_note->is_transcribed);
|
||||
PARSE_FLAG(is_transcribed);
|
||||
END_PARSE_FLAGS();
|
||||
} else {
|
||||
has_mime_type = true;
|
||||
has_duration = true;
|
||||
has_waveform = true;
|
||||
voice_note->is_transcribed = false;
|
||||
is_transcribed = false;
|
||||
}
|
||||
if (has_mime_type) {
|
||||
parse(voice_note->mime_type, parser);
|
||||
@ -73,9 +75,8 @@ FileId VoiceNotesManager::parse_voice_note(ParserT &parser) {
|
||||
if (has_waveform) {
|
||||
parse(voice_note->waveform, parser);
|
||||
}
|
||||
if (voice_note->is_transcribed) {
|
||||
parse(voice_note->transcription_id, parser);
|
||||
parse(voice_note->text, parser);
|
||||
if (is_transcribed) {
|
||||
parse(voice_note->transcription_info, parser);
|
||||
}
|
||||
parse(voice_note->file_id, parser);
|
||||
if (parser.get_error() != nullptr || !voice_note->file_id.is_valid()) {
|
||||
|
Loading…
Reference in New Issue
Block a user