tdlight/td/telegram/MessageThreadDb.cpp
2023-01-01 00:28:08 +03:00

344 lines
12 KiB
C++

//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/telegram/MessageThreadDb.h"
#include "td/telegram/Version.h"
#include "td/db/SqliteConnectionSafe.h"
#include "td/db/SqliteDb.h"
#include "td/db/SqliteStatement.h"
#include "td/actor/actor.h"
#include "td/actor/SchedulerLocalStorage.h"
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/ScopeGuard.h"
#include "td/utils/Time.h"
namespace td {
// NB: must happen inside a transaction
// Makes sure the "threads" table and its ordering index exist.
//
// `version` is the database version recorded for the existing data. If the table is missing, or the
// recorded version is newer than current_db_version(), the schema is (re)created from scratch. Tables
// of an equal or older known version are left untouched (no migrations exist yet).
Status init_message_thread_db(SqliteDb &db, int32 version) {
  LOG(INFO) << "Init message thread database " << tag("version", version);

  // Check if database exists
  TRY_RESULT(has_table, db.has_table("threads"));
  if (!has_table) {
    version = 0;
  }

  // Data from a version this build doesn't know about can't be trusted; drop it and start over
  if (version > current_db_version()) {
    TRY_STATUS(drop_message_thread_db(db, version));
    version = 0;
  }

  if (version == 0) {
    LOG(INFO) << "Create new message thread database";
    TRY_STATUS(
        db.exec("CREATE TABLE IF NOT EXISTS threads (dialog_id INT8, thread_id INT8, thread_order INT8, data BLOB, "
                "PRIMARY KEY (dialog_id, thread_id))"));
    // Serves get_message_threads(), which pages a dialog's threads by thread_order
    TRY_STATUS(
        db.exec("CREATE INDEX IF NOT EXISTS dialog_threads_by_thread_order ON threads (dialog_id, thread_order)"));
  }
  return Status::OK();
}
// NB: must happen inside a transaction
// Removes the whole "threads" table. A warning is logged when the dropped data was recorded with a
// database version newer than the one this build supports.
Status drop_message_thread_db(SqliteDb &db, int version) {
  const auto current_version = current_db_version();
  const bool is_newer_than_supported = current_version < version;
  if (is_newer_than_supported) {
    LOG(WARNING) << "Drop message_thread_db " << tag("version", version)
                 << tag("current_db_version", current_version);
  }
  return db.exec("DROP TABLE IF EXISTS threads");
}
// Synchronous SQLite-backed storage of message threads, keyed by (dialog_id, thread_id).
//
// All statements are prepared once in init(). Each query method binds its arguments, steps the
// statement and resets it on scope exit so the same prepared statement can be reused by the next call.
class MessageThreadDbImpl final : public MessageThreadDbSyncInterface {
 public:
  // Takes ownership of the connection; preparing the statements must succeed.
  explicit MessageThreadDbImpl(SqliteDb db) : db_(std::move(db)) {
    init().ensure();
  }

  // Prepares every statement used by this class, returning the first error reported by get_statement.
  Status init() {
    TRY_RESULT_ASSIGN(add_thread_stmt_,
                      db_.get_statement("INSERT OR REPLACE INTO threads VALUES(?1, ?2, ?3, ?4)"));
    TRY_RESULT_ASSIGN(delete_thread_stmt_,
                      db_.get_statement("DELETE FROM threads WHERE dialog_id = ?1 AND thread_id = ?2"));
    TRY_RESULT_ASSIGN(delete_all_dialog_threads_stmt_,
                      db_.get_statement("DELETE FROM threads WHERE dialog_id = ?1"));
    TRY_RESULT_ASSIGN(get_thread_stmt_,
                      db_.get_statement("SELECT data FROM threads WHERE dialog_id = ?1 AND thread_id = ?2"));
    TRY_RESULT_ASSIGN(
        get_threads_stmt_,
        db_.get_statement("SELECT data, dialog_id, thread_id, thread_order FROM threads WHERE dialog_id = ?1 AND "
                          "thread_order < ?2 ORDER BY thread_order DESC LIMIT ?3"));
    return Status::OK();
  }

  // Stores the serialized thread, replacing any existing row with the same (dialog_id, thread_id).
  void add_message_thread(DialogId dialog_id, MessageId top_thread_message_id, int64 order, BufferSlice data) final {
    auto &stmt = add_thread_stmt_;
    SCOPE_EXIT {
      stmt.reset();
    };
    stmt.bind_int64(1, dialog_id.get()).ensure();
    stmt.bind_int64(2, top_thread_message_id.get()).ensure();
    stmt.bind_int64(3, order).ensure();
    stmt.bind_blob(4, data.as_slice()).ensure();
    stmt.step().ensure();
  }

  // Removes a single thread of the dialog.
  void delete_message_thread(DialogId dialog_id, MessageId top_thread_message_id) final {
    auto &stmt = delete_thread_stmt_;
    SCOPE_EXIT {
      stmt.reset();
    };
    stmt.bind_int64(1, dialog_id.get()).ensure();
    stmt.bind_int64(2, top_thread_message_id.get()).ensure();
    stmt.step().ensure();
  }

  // Removes every thread stored for the dialog.
  void delete_all_dialog_message_threads(DialogId dialog_id) final {
    auto &stmt = delete_all_dialog_threads_stmt_;
    SCOPE_EXIT {
      stmt.reset();
    };
    stmt.bind_int64(1, dialog_id.get()).ensure();
    stmt.step().ensure();
  }

  // Returns a copy of the stored thread data, or an empty BufferSlice when the thread isn't stored.
  BufferSlice get_message_thread(DialogId dialog_id, MessageId top_thread_message_id) final {
    auto &stmt = get_thread_stmt_;
    SCOPE_EXIT {
      stmt.reset();
    };
    stmt.bind_int64(1, dialog_id.get()).ensure();
    stmt.bind_int64(2, top_thread_message_id.get()).ensure();
    stmt.step().ensure();
    // The blob is copied out before SCOPE_EXIT resets the statement and invalidates the view
    return stmt.has_row() ? BufferSlice(stmt.view_blob(0)) : BufferSlice();
  }

  // Returns up to `limit` threads of the dialog whose thread_order is strictly below `offset_order`,
  // in descending thread_order. `next_order` is the order of the last returned thread, or
  // `offset_order` itself if nothing was found, so it can be passed as the next page's offset.
  MessageThreadDbMessageThreads get_message_threads(DialogId dialog_id, int64 offset_order, int32 limit) final {
    auto &stmt = get_threads_stmt_;
    SCOPE_EXIT {
      stmt.reset();
    };
    stmt.bind_int64(1, dialog_id.get()).ensure();
    stmt.bind_int64(2, offset_order).ensure();
    stmt.bind_int32(3, limit).ensure();

    MessageThreadDbMessageThreads page;
    page.next_order = offset_order;
    for (stmt.step().ensure(); stmt.has_row(); stmt.step().ensure()) {
      BufferSlice thread_data(stmt.view_blob(0));
      page.next_order = stmt.view_int64(3);
      LOG(INFO) << "Load thread of " << MessageId(stmt.view_int64(2)) << " in " << DialogId(stmt.view_int64(1))
                << " with order " << page.next_order;
      page.message_threads.emplace_back(std::move(thread_data));
    }
    return page;
  }

  Status begin_write_transaction() final {
    return db_.begin_write_transaction();
  }

  Status commit_transaction() final {
    return db_.commit_transaction();
  }

 private:
  // db_ is declared first so it outlives (is destroyed after) the statements prepared on it
  SqliteDb db_;

  SqliteStatement add_thread_stmt_;
  SqliteStatement delete_thread_stmt_;
  SqliteStatement delete_all_dialog_threads_stmt_;
  SqliteStatement get_thread_stmt_;
  SqliteStatement get_threads_stmt_;
};
// Wraps a shared SQLite connection into a MessageThreadDbSyncSafeInterface. The MessageThreadDbImpl
// instances are created on demand by LazySchedulerLocalStorage, each one on its own clone of the
// shared connection.
std::shared_ptr<MessageThreadDbSyncSafeInterface> create_message_thread_db_sync(
    std::shared_ptr<SqliteConnectionSafe> sqlite_connection) {
  class MessageThreadDbSyncSafe final : public MessageThreadDbSyncSafeInterface {
   public:
    explicit MessageThreadDbSyncSafe(std::shared_ptr<SqliteConnectionSafe> sqlite_connection)
        : lsls_db_([connection = std::move(sqlite_connection)] {
          auto cloned_db = connection->get().clone();
          return make_unique<MessageThreadDbImpl>(std::move(cloned_db));
        }) {
    }

    MessageThreadDbSyncInterface &get() final {
      auto &db_ptr = lsls_db_.get();
      return *db_ptr;
    }

   private:
    LazySchedulerLocalStorage<unique_ptr<MessageThreadDbSyncInterface>> lsls_db_;
  };

  return std::make_shared<MessageThreadDbSyncSafe>(std::move(sqlite_connection));
}
// Asynchronous front end of the message thread database.
//
// Every request is forwarded to an Impl actor created on `scheduler_id`. Writes are queued and
// committed together in a single write transaction. A batch is committed either when more than
// MAX_PENDING_QUERIES_COUNT writes are pending, or when MAX_PENDING_QUERIES_DELAY has elapsed since
// the first write of the batch was queued. Reads, force_flush() and close() commit queued writes first.
class MessageThreadDbAsync final : public MessageThreadDbAsyncInterface {
 public:
  MessageThreadDbAsync(std::shared_ptr<MessageThreadDbSyncSafeInterface> sync_db, int32 scheduler_id)
      : impl_(create_actor_on_scheduler<Impl>("MessageThreadDbActor", scheduler_id, std::move(sync_db))) {
  }

  void add_message_thread(DialogId dialog_id, MessageId top_thread_message_id, int64 order, BufferSlice data,
                          Promise<Unit> promise) final {
    send_closure(impl_, &Impl::add_message_thread, dialog_id, top_thread_message_id, order, std::move(data),
                 std::move(promise));
  }

  void delete_message_thread(DialogId dialog_id, MessageId top_thread_message_id, Promise<Unit> promise) final {
    send_closure(impl_, &Impl::delete_message_thread, dialog_id, top_thread_message_id, std::move(promise));
  }

  void delete_all_dialog_message_threads(DialogId dialog_id, Promise<Unit> promise) final {
    send_closure(impl_, &Impl::delete_all_dialog_message_threads, dialog_id, std::move(promise));
  }

  void get_message_thread(DialogId dialog_id, MessageId top_thread_message_id, Promise<BufferSlice> promise) final {
    send_closure_later(impl_, &Impl::get_message_thread, dialog_id, top_thread_message_id, std::move(promise));
  }

  void get_message_threads(DialogId dialog_id, int64 offset_order, int32 limit,
                           Promise<MessageThreadDbMessageThreads> promise) final {
    send_closure_later(impl_, &Impl::get_message_threads, dialog_id, offset_order, limit, std::move(promise));
  }

  void close(Promise<Unit> promise) final {
    send_closure_later(impl_, &Impl::close, std::move(promise));
  }

  void force_flush() final {
    send_closure_later(impl_, &Impl::force_flush);
  }

 private:
  // Actor that owns the synchronous database access and performs every query.
  class Impl final : public Actor {
   public:
    explicit Impl(std::shared_ptr<MessageThreadDbSyncSafeInterface> sync_db_safe)
        : sync_db_safe_(std::move(sync_db_safe)) {
    }

    void add_message_thread(DialogId dialog_id, MessageId top_thread_message_id, int64 order, BufferSlice data,
                            Promise<Unit> promise) {
      add_write_query([this, dialog_id, top_thread_message_id, order, data = std::move(data),
                       promise = std::move(promise)](Unit) mutable {
        sync_db_->add_message_thread(dialog_id, top_thread_message_id, order, std::move(data));
        on_write_result(std::move(promise));
      });
    }

    void delete_message_thread(DialogId dialog_id, MessageId top_thread_message_id, Promise<Unit> promise) {
      add_write_query([this, dialog_id, top_thread_message_id, promise = std::move(promise)](Unit) mutable {
        sync_db_->delete_message_thread(dialog_id, top_thread_message_id);
        on_write_result(std::move(promise));
      });
    }

    void delete_all_dialog_message_threads(DialogId dialog_id, Promise<Unit> promise) {
      add_write_query([this, dialog_id, promise = std::move(promise)](Unit) mutable {
        sync_db_->delete_all_dialog_message_threads(dialog_id);
        on_write_result(std::move(promise));
      });
    }

    void on_write_result(Promise<Unit> &&promise) {
      // We are inside a transaction and don't know how to handle errors; the caller's promise is
      // resolved by do_flush() only after the transaction has been committed
      finished_writes_.push_back(std::move(promise));
    }

    void get_message_thread(DialogId dialog_id, MessageId top_thread_message_id, Promise<BufferSlice> promise) {
      add_read_query();
      promise.set_result(sync_db_->get_message_thread(dialog_id, top_thread_message_id));
    }

    void get_message_threads(DialogId dialog_id, int64 offset_order, int32 limit,
                             Promise<MessageThreadDbMessageThreads> promise) {
      add_read_query();
      promise.set_result(sync_db_->get_message_threads(dialog_id, offset_order, limit));
    }

    void close(Promise<> promise) {
      do_flush();
      sync_db_safe_.reset();
      sync_db_ = nullptr;
      promise.set_value(Unit());
      stop();
    }

    void force_flush() {
      do_flush();
      LOG(INFO) << "MessageThreadDb flushed";
    }

   private:
    std::shared_ptr<MessageThreadDbSyncSafeInterface> sync_db_safe_;
    MessageThreadDbSyncInterface *sync_db_ = nullptr;

    static constexpr size_t MAX_PENDING_QUERIES_COUNT{50};
    // Passed through Time::now_cached(), presumably seconds (i.e. 10 ms) — confirm against td::Time
    static constexpr double MAX_PENDING_QUERIES_DELAY{0.01};

    // NB: order is important, destructor of pending_writes_ will change finished_writes_.
    // Members are destroyed in reverse declaration order, so finished_writes_ is still alive then.
    vector<Promise<Unit>> finished_writes_;
    vector<Promise<Unit>> pending_writes_;  // TODO use Action
    // Deadline of the flush timeout armed for the current batch, or 0 if no batch is waiting
    double wakeup_at_ = 0;

    // Queues a write; it is executed inside the next batch transaction.
    template <class F>
    void add_write_query(F &&f) {
      pending_writes_.push_back(PromiseCreator::lambda(std::forward<F>(f)));
      if (pending_writes_.size() > MAX_PENDING_QUERIES_COUNT) {
        do_flush();  // also cancels the timeout and clears wakeup_at_
        return;
      }
      if (wakeup_at_ == 0) {
        // The first write of a new batch starts the delay window
        wakeup_at_ = Time::now_cached() + MAX_PENDING_QUERIES_DELAY;
      }
      set_timeout_at(wakeup_at_);
    }

    // Reads must observe every write queued before them
    void add_read_query() {
      do_flush();
    }

    // Executes all queued writes in one transaction, then resolves the callers' promises.
    void do_flush() {
      if (pending_writes_.empty()) {
        return;
      }
      sync_db_->begin_write_transaction().ensure();
      set_promises(pending_writes_);
      sync_db_->commit_transaction().ensure();
      set_promises(finished_writes_);
      cancel_timeout();
      // The deadline belonged to the batch that was just committed. Without clearing it, the next write
      // would re-arm the timeout with this already expired time and be flushed immediately, defeating
      // the batching delay.
      wakeup_at_ = 0;
    }

    void timeout_expired() final {
      do_flush();
    }

    void start_up() final {
      sync_db_ = &sync_db_safe_->get();
    }
  };

  ActorOwn<Impl> impl_;
};
// Creates the asynchronous message thread database, whose worker actor runs on `scheduler_id`
// and uses `sync_db` for the actual queries.
std::shared_ptr<MessageThreadDbAsyncInterface> create_message_thread_db_async(
    std::shared_ptr<MessageThreadDbSyncSafeInterface> sync_db, int32 scheduler_id) {
  std::shared_ptr<MessageThreadDbAsyncInterface> async_db =
      std::make_shared<MessageThreadDbAsync>(std::move(sync_db), scheduler_id);
  return async_db;
}
} // namespace td