DownloadManager: WIP

This commit is contained in:
Arseny Smirnov 2022-02-25 23:18:07 +01:00
parent d4a4f5fb5d
commit 29f8f79b16
6 changed files with 487 additions and 4 deletions

View File

@ -314,6 +314,7 @@ set(TDLIB_SOURCE
td/telegram/Document.cpp
td/telegram/DocumentsManager.cpp
td/telegram/DownloadManager.cpp
td/telegram/DownloadsDb.cpp
td/telegram/DraftMessage.cpp
td/telegram/FileReferenceManager.cpp
td/telegram/files/FileBitmask.cpp

View File

@ -8,6 +8,10 @@
#include "td/utils/FlatHashMap.h"
#include "td/telegram/DownloadsDb.h"
#include "td/telegram/Global.h"
#include "td/telegram/TdDb.h"
namespace td {
class DownloadManagerImpl final : public DownloadManager {
@ -114,8 +118,10 @@ class DownloadManagerImpl final : public DownloadManager {
active_files_[file_id] = file_info;
callback_->start_file(file_info.internal_file_id, file_info.priority);
// TODO: add file to db
return Status::OK();
G()->td_db()->get_downloads_db_async()->add_download(
DownloadsDbDownload{callback_->get_unique_file_id(file_id),
callback_->get_file_source_serialized(file_source_id), search_by, 0, priority},
[](Result<Unit>) {});
}
void search(string query, bool only_active, bool only_completed, string offset, int32 limit,
@ -123,7 +129,10 @@ class DownloadManagerImpl final : public DownloadManager {
if (!callback_) {
return promise.set_error(Status::Error("TODO: code and message`"));
}
// TODO: do query
TRY_RESULT_PROMISE(promise, offset_int64, to_integer_safe<int64>(offset));
// TODO: only active, only completed
G()->td_db()->get_downloads_db_async()->get_downloads_fts(DownloadsDbFtsQuery{query, offset_int64, limit},
[](Result<Unit>) {});
return promise.set_value({});
}
@ -192,6 +201,10 @@ class DownloadManagerImpl final : public DownloadManager {
}
// TODO: ???
// TODO: load active files from db
auto downloads = G()->td_db()->get_downloads_db_sync()->get_active_downloads().move_as_ok();
for (auto &download : downloads.downloads) {
// ...
}
}
void tear_down() final {
callback_.reset();

334
td/telegram/DownloadsDb.cpp Normal file
View File

@ -0,0 +1,334 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/telegram/DownloadsDb.h"
#include "td/telegram/logevent/LogEvent.h"
#include "td/telegram/Version.h"
#include "td/db/SqliteConnectionSafe.h"
#include "td/db/SqliteDb.h"
#include "td/db/SqliteStatement.h"
#include "td/actor/actor.h"
#include "td/actor/PromiseFuture.h"
#include "td/actor/SchedulerLocalStorage.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/ScopeGuard.h"
#include "td/utils/Slice.h"
#include "td/utils/SliceBuilder.h"
#include "td/utils/StackAllocator.h"
#include "td/utils/StringBuilder.h"
#include "td/utils/Time.h"
#include "td/utils/tl_helpers.h"
#include "td/utils/unicode.h"
#include "td/utils/utf8.h"
#include <algorithm>
#include <array>
#include <iterator>
#include <limits>
#include <tuple>
#include <utility>
namespace td {
static constexpr int32 MESSAGES_DB_INDEX_COUNT = 30;
static constexpr int32 MESSAGES_DB_INDEX_COUNT_OLD = 9;
// NB: must happen inside a transaction
Status init_downloads_db(SqliteDb &db, int32 version) {
LOG(INFO) << "Init downloads database " << tag("version", version);
// Check if database exists
TRY_RESULT(has_table, db.has_table("downloads"));
if (!has_table) {
version = 0;
}
auto add_fts = [&db] {
TRY_STATUS(
db.exec("CREATE VIRTUAL TABLE IF NOT EXISTS downloads_fts USING fts5(search_text, content='downloads', "
"content_rowid='download_id', tokenize = \"unicode61 remove_diacritics 0 tokenchars '\a'\")"));
TRY_STATUS(
db.exec("CREATE TRIGGER IF NOT EXISTS trigger_downloads_fts_delete BEFORE DELETE ON downloads"
" BEGIN INSERT INTO downloads_fts(downloads_fts, rowid, search_text) VALUES(\'delete\', "
"OLD.download_id, OLD.search_text); END"));
TRY_STATUS(
db.exec("CREATE TRIGGER IF NOT EXISTS trigger_downloads_fts_insert AFTER INSERT ON downloads"
" BEGIN INSERT INTO downloads_fts(rowid, search_text) VALUES(NEW.download_id, NEW.search_text); END"));
// TODO: update?
return Status::OK();
};
if (version == 0) {
TRY_STATUS(
db.exec("CREATE TABLE IF NOT EXISTS downloads(download_id INT8 PRIMARY KEY, unique_file_id "
"BLOB UNIQUE, file_source BLOB, search_text STRING, date INT4, priority INT4)"));
// TODO: add indexes
// TRY_STATUS(
// db.exec("CREATE INDEX IF NOT EXISTS message_by_random_id ON messages (dialog_id, random_id) "
// "WHERE random_id IS NOT NULL"));
TRY_STATUS(add_fts());
version = current_db_version();
}
return Status::OK();
}
// NB: must happen inside a transaction
Status drop_downloads_db(SqliteDb &db, int32 version) {
LOG(WARNING) << "Drop downloads database " << tag("version", version)
<< tag("current_db_version", current_db_version());
return db.exec("DROP TABLE IF EXISTS downloads");
}
class DownloadsDbImpl final : public DownloadsDbSyncInterface {
public:
explicit DownloadsDbImpl(SqliteDb db) : db_(std::move(db)) {
init().ensure();
}
Status init() {
TRY_RESULT_ASSIGN(add_download_stmt_,
db_.get_statement("INSERT OR REPLACE INTO downloads VALUES(NULL, ?1, ?2, ?3, ?4, ?5)"));
TRY_RESULT_ASSIGN(
get_downloads_fts_stmt_,
db_.get_statement("SELECT download_id, unique_file_id, file_source, priority FROM downloads WHERE download_id "
"IN (SELECT rowid FROM downloads_fts WHERE downloads_fts MATCH ?1 AND rowid < ?2 "
"ORDER BY rowid DESC LIMIT ?3) ORDER BY download_id DESC"));
// LOG(ERROR) << get_message_stmt_.explain().ok();
// LOG(ERROR) << get_messages_from_notification_id_stmt.explain().ok();
// LOG(ERROR) << get_message_by_random_id_stmt_.explain().ok();
// LOG(ERROR) << get_message_by_unique_message_id_stmt_.explain().ok();
// LOG(ERROR) << get_expiring_messages_stmt_.explain().ok();
// LOG(ERROR) << get_expiring_messages_helper_stmt_.explain().ok();
// LOG(FATAL) << "EXPLAINED";
return Status::OK();
}
Result<DownloadsDbFtsResult> get_downloads_fts(DownloadsDbFtsQuery query) final {
SCOPE_EXIT {
get_downloads_fts_stmt_.reset();
};
auto &stmt = get_downloads_fts_stmt_;
stmt.bind_string(1, query.query).ensure();
stmt.bind_int64(2, query.offset).ensure();
stmt.bind_int32(3, query.limit).ensure();
DownloadsDbFtsResult result;
auto status = stmt.step();
if (status.is_error()) {
LOG(ERROR) << status;
return std::move(result);
}
while (stmt.has_row()) {
int64 download_id{stmt.view_int64(0)};
string unique_file_id{stmt.view_string(1).str()};
string file_source{stmt.view_string(2).str()};
int32 priority{stmt.view_int32(3)};
result.next_download_id = download_id;
result.downloads.push_back(DownloadsDbDownloadShort{std::move(unique_file_id), std::move(file_source), priority});
stmt.step().ensure();
}
return std::move(result);
}
Status begin_write_transaction() final {
return db_.begin_write_transaction();
}
Status commit_transaction() final {
return db_.commit_transaction();
}
Status add_download(DownloadsDbDownload download) override {
SCOPE_EXIT {
add_download_stmt_.reset();
};
auto &stmt = add_download_stmt_;
TRY_RESULT_ASSIGN(add_download_stmt_,
db_.get_statement("INSERT OR REPLACE INTO downloads VALUES(NULL, ?1, ?2, ?3, ?4, ?5)"));
stmt.bind_blob(1, download.unique_file_id).ensure();
stmt.bind_blob(2, download.file_source).ensure();
stmt.bind_string(3, download.search_text).ensure();
stmt.bind_int32(4, download.date).ensure();
stmt.bind_int32(5, download.priority).ensure();
stmt.step().ensure();
return Status();
}
Result<GetActiveDownloadsResult> get_active_downloads() override {
DownloadsDbFtsQuery query;
query.limit = 2000;
query.offset = uint64(1) << 60;
// TODO: optimize query
// TODO: only active
TRY_RESULT(ans, get_downloads_fts(query));
return GetActiveDownloadsResult{std::move(ans.downloads)};
}
private:
SqliteDb db_;
SqliteStatement add_download_stmt_;
SqliteStatement get_downloads_fts_stmt_;
};
std::shared_ptr<DownloadsDbSyncSafeInterface> create_downloads_db_sync(
std::shared_ptr<SqliteConnectionSafe> sqlite_connection) {
class DownloadsDbSyncSafe final : public DownloadsDbSyncSafeInterface {
public:
explicit DownloadsDbSyncSafe(std::shared_ptr<SqliteConnectionSafe> sqlite_connection)
: lsls_db_([safe_connection = std::move(sqlite_connection)] {
return make_unique<DownloadsDbImpl>(safe_connection->get().clone());
}) {
}
DownloadsDbSyncInterface &get() final {
return *lsls_db_.get();
}
private:
LazySchedulerLocalStorage<unique_ptr<DownloadsDbSyncInterface>> lsls_db_;
};
return std::make_shared<DownloadsDbSyncSafe>(std::move(sqlite_connection));
}
class DownloadsDbAsync final : public DownloadsDbAsyncInterface {
public:
DownloadsDbAsync(std::shared_ptr<DownloadsDbSyncSafeInterface> sync_db, int32 scheduler_id) {
impl_ = create_actor_on_scheduler<Impl>("DownloadsDbActor", scheduler_id, std::move(sync_db));
}
void add_download(DownloadsDbDownload query, Promise<> promise) final {
send_closure(impl_, &Impl::add_download, std::move(query), std::move(promise));
}
void get_active_downloads(Promise<GetActiveDownloadsResult> promise) final {
send_closure(impl_, &Impl::get_active_downloads, std::move(promise));
}
void get_downloads_fts(DownloadsDbFtsQuery query, Promise<DownloadsDbFtsResult> promise) final {
send_closure(impl_, &Impl::get_downloads_fts, std::move(query), std::move(promise));
}
void close(Promise<> promise) final {
send_closure_later(impl_, &Impl::close, std::move(promise));
}
void force_flush() final {
send_closure_later(impl_, &Impl::force_flush);
}
private:
class Impl final : public Actor {
public:
explicit Impl(std::shared_ptr<DownloadsDbSyncSafeInterface> sync_db_safe) : sync_db_safe_(std::move(sync_db_safe)) {
}
void add_download(DownloadsDbDownload query, Promise<> promise) {
add_write_query([this, query = std::move(query), promise = std::move(promise)](Unit) mutable {
on_write_result(std::move(promise), sync_db_->add_download(std::move(query)));
});
}
void get_downloads_fts(DownloadsDbFtsQuery query, Promise<DownloadsDbFtsResult> promise) {
add_read_query();
promise.set_result(sync_db_->get_downloads_fts(std::move(query)));
}
void get_active_downloads(Promise<> promise) {
add_read_query();
promise.set_result(sync_db_->get_active_downloads());
}
void close(Promise<> promise) {
do_flush();
sync_db_safe_.reset();
sync_db_ = nullptr;
promise.set_value(Unit());
stop();
}
void force_flush() {
LOG(INFO) << "DownloadsDb flushed";
do_flush();
}
private:
std::shared_ptr<DownloadsDbSyncSafeInterface> sync_db_safe_;
DownloadsDbSyncInterface *sync_db_ = nullptr;
static constexpr size_t MAX_PENDING_QUERIES_COUNT{50};
static constexpr double MAX_PENDING_QUERIES_DELAY{0.01};
//NB: order is important, destructor of pending_writes_ will change pending_write_results_
vector<std::pair<Promise<>, Status>> pending_write_results_;
vector<Promise<>> pending_writes_;
double wakeup_at_ = 0;
void on_write_result(Promise<> promise, Status status) {
// We are inside a transaction and don't know how to handle the error
status.ensure();
pending_write_results_.emplace_back(std::move(promise), std::move(status));
}
template <class F>
void add_write_query(F &&f) {
pending_writes_.push_back(PromiseCreator::lambda(std::forward<F>(f), PromiseCreator::Ignore()));
if (pending_writes_.size() > MAX_PENDING_QUERIES_COUNT) {
do_flush();
wakeup_at_ = 0;
} else if (wakeup_at_ == 0) {
wakeup_at_ = Time::now_cached() + MAX_PENDING_QUERIES_DELAY;
}
if (wakeup_at_ != 0) {
set_timeout_at(wakeup_at_);
}
}
void add_read_query() {
do_flush();
}
void do_flush() {
if (pending_writes_.empty()) {
return;
}
sync_db_->begin_write_transaction().ensure();
for (auto &query : pending_writes_) {
query.set_value(Unit());
}
sync_db_->commit_transaction().ensure();
pending_writes_.clear();
for (auto &p : pending_write_results_) {
p.first.set_result(std::move(p.second));
}
pending_write_results_.clear();
cancel_timeout();
}
void timeout_expired() final {
do_flush();
}
void start_up() final {
sync_db_ = &sync_db_safe_->get();
}
};
ActorOwn<Impl> impl_;
};
std::shared_ptr<DownloadsDbAsyncInterface> create_downloads_db_async(
std::shared_ptr<DownloadsDbSyncSafeInterface> sync_db, int32 scheduler_id) {
return std::make_shared<DownloadsDbAsync>(std::move(sync_db), scheduler_id);
}
} // namespace td

107
td/telegram/DownloadsDb.h Normal file
View File

@ -0,0 +1,107 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/telegram/DialogId.h"
#include "td/telegram/FullMessageId.h"
#include "td/telegram/MessageId.h"
#include "td/telegram/MessageSearchFilter.h"
#include "td/telegram/NotificationId.h"
#include "td/telegram/ServerMessageId.h"
#include "td/actor/PromiseFuture.h"
#include "td/utils/buffer.h"
#include "td/utils/common.h"
#include "td/utils/Status.h"
#include <memory>
#include <utility>
namespace td {
class SqliteConnectionSafe;
class SqliteDb;
struct DownloadsDbFtsQuery {
string query;
int64 offset{0};
int32 limit{0};
};
struct DownloadsDbDownloadShort {
string unique_file_id;
string file_source;
int32 priority;
};
struct DownloadsDbDownload {
string unique_file_id;
string file_source;
string search_text;
int32 date;
int32 priority;
};
struct GetActiveDownloadsResult {
std::vector<DownloadsDbDownloadShort> downloads;
};
struct DownloadsDbFtsResult {
vector<DownloadsDbDownloadShort> downloads;
int64 next_download_id{};
};
class DownloadsDbSyncInterface {
public:
DownloadsDbSyncInterface() = default;
DownloadsDbSyncInterface(const DownloadsDbSyncInterface &) = delete;
DownloadsDbSyncInterface &operator=(const DownloadsDbSyncInterface &) = delete;
virtual ~DownloadsDbSyncInterface() = default;
virtual Status add_download(DownloadsDbDownload) = 0;
virtual Result<GetActiveDownloadsResult> get_active_downloads() = 0;
virtual Result<DownloadsDbFtsResult> get_downloads_fts(DownloadsDbFtsQuery query) = 0;
virtual Status begin_write_transaction() = 0;
virtual Status commit_transaction() = 0;
};
class DownloadsDbSyncSafeInterface {
public:
DownloadsDbSyncSafeInterface() = default;
DownloadsDbSyncSafeInterface(const DownloadsDbSyncSafeInterface &) = delete;
DownloadsDbSyncSafeInterface &operator=(const DownloadsDbSyncSafeInterface &) = delete;
virtual ~DownloadsDbSyncSafeInterface() = default;
virtual DownloadsDbSyncInterface &get() = 0;
};
class DownloadsDbAsyncInterface {
public:
DownloadsDbAsyncInterface() = default;
DownloadsDbAsyncInterface(const DownloadsDbAsyncInterface &) = delete;
DownloadsDbAsyncInterface &operator=(const DownloadsDbAsyncInterface &) = delete;
virtual ~DownloadsDbAsyncInterface() = default;
virtual void add_download(DownloadsDbDownload, Promise<>) = 0;
virtual void get_active_downloads(Promise<GetActiveDownloadsResult>) = 0;
virtual void get_downloads_fts(DownloadsDbFtsQuery query, Promise<DownloadsDbFtsResult>) = 0;
virtual void close(Promise<> promise) = 0;
virtual void force_flush() = 0;
};
Status init_downloads_db(SqliteDb &db, int version) TD_WARN_UNUSED_RESULT;
Status drop_downloads_db(SqliteDb &db, int version) TD_WARN_UNUSED_RESULT;
std::shared_ptr<DownloadsDbSyncSafeInterface> create_downloads_db_sync(
std::shared_ptr<SqliteConnectionSafe> sqlite_connection);
std::shared_ptr<DownloadsDbAsyncInterface> create_downloads_db_async(
std::shared_ptr<DownloadsDbSyncSafeInterface> sync_db, int32 scheduler_id);
} // namespace td

View File

@ -7,6 +7,7 @@
#include "td/telegram/TdDb.h"
#include "td/telegram/DialogDb.h"
#include "td/telegram/DownloadsDb.h"
#include "td/telegram/files/FileDb.h"
#include "td/telegram/Global.h"
#include "td/telegram/logevent/LogEvent.h"
@ -202,6 +203,12 @@ DialogDbSyncInterface *TdDb::get_dialog_db_sync() {
DialogDbAsyncInterface *TdDb::get_dialog_db_async() {
return dialog_db_async_.get();
}
DownloadsDbSyncInterface *TdDb::get_downloads_db_sync() {
return &downloads_db_sync_safe_->get();
}
DownloadsDbAsyncInterface *TdDb::get_downloads_db_async() {
return downloads_db_async_.get();
}
CSlice TdDb::binlog_path() const {
return binlog_->get_path();
@ -265,6 +272,11 @@ void TdDb::do_close(Promise<> on_finished, bool destroy_flag) {
dialog_db_async_->close(mpas.get_promise());
}
downloads_db_sync_safe_.reset();
if (downloads_db_async_) {
downloads_db_async_->close(mpas.get_promise());
}
// binlog_pmc is dependent on binlog_ and anyway it doesn't support close_and_destroy
CHECK(binlog_pmc_.unique());
binlog_pmc_.reset();
@ -294,6 +306,7 @@ Status TdDb::init_sqlite(int32 scheduler_id, const TdParameters &parameters, con
bool use_file_db = parameters.use_file_db;
bool use_dialog_db = parameters.use_message_db;
bool use_message_db = parameters.use_message_db;
bool use_downloads_db = parameters.use_file_db;
if (!use_sqlite) {
unlink(sql_database_path).ignore();
return Status::OK();
@ -340,6 +353,12 @@ Status TdDb::init_sqlite(int32 scheduler_id, const TdParameters &parameters, con
TRY_STATUS(drop_file_db(db, user_version));
}
if (use_downloads_db) {
TRY_STATUS(init_downloads_db(db, user_version));
} else {
TRY_STATUS(drop_downloads_db(db, user_version));
}
// Update 'PRAGMA user_version'
auto db_version = current_db_version();
if (db_version != user_version) {
@ -375,7 +394,7 @@ Status TdDb::init_sqlite(int32 scheduler_id, const TdParameters &parameters, con
dialog_db_async_ = create_dialog_db_async(dialog_db_sync_safe_, scheduler_id);
}
if (use_message_db) {
if (use_downloads_db) {
messages_db_sync_safe_ = create_messages_db_sync(sql_connection_);
messages_db_async_ = create_messages_db_async(messages_db_sync_safe_, scheduler_id);
}

View File

@ -34,6 +34,9 @@ class FileDbInterface;
class MessagesDbSyncInterface;
class MessagesDbSyncSafeInterface;
class MessagesDbAsyncInterface;
class DownloadsDbSyncInterface;
class DownloadsDbSyncSafeInterface;
class DownloadsDbAsyncInterface;
class SqliteConnectionSafe;
class SqliteKeyValueSafe;
class SqliteKeyValueAsyncInterface;
@ -95,6 +98,9 @@ class TdDb {
DialogDbSyncInterface *get_dialog_db_sync();
DialogDbAsyncInterface *get_dialog_db_async();
DownloadsDbSyncInterface *get_downloads_db_sync();
DownloadsDbAsyncInterface *get_downloads_db_async();
void change_key(DbKey key, Promise<> promise);
void with_db_path(const std::function<void(CSlice)> &callback);
@ -113,6 +119,9 @@ class TdDb {
std::shared_ptr<MessagesDbSyncSafeInterface> messages_db_sync_safe_;
std::shared_ptr<MessagesDbAsyncInterface> messages_db_async_;
std::shared_ptr<DownloadsDbSyncSafeInterface> downloads_db_sync_safe_;
std::shared_ptr<DownloadsDbAsyncInterface> downloads_db_async_;
std::shared_ptr<DialogDbSyncSafeInterface> dialog_db_sync_safe_;
std::shared_ptr<DialogDbAsyncInterface> dialog_db_async_;