diff --git a/td/telegram/CountryInfoManager.cpp b/td/telegram/CountryInfoManager.cpp index bbb500d58..a7842c556 100644 --- a/td/telegram/CountryInfoManager.cpp +++ b/td/telegram/CountryInfoManager.cpp @@ -451,7 +451,7 @@ const CountryInfoManager::CountryList *CountryInfoManager::get_country_list(cons it = countries_.find(language_code); CHECK(it != countries_.end()) auto *country = it->second.get(); - load_country_list(language_code, country->hash, Auto()); + load_country_list(language_code, country->hash, {}); return country; } return nullptr; @@ -460,7 +460,7 @@ const CountryInfoManager::CountryList *CountryInfoManager::get_country_list(cons auto *country = it->second.get(); CHECK(country != nullptr); if (country->next_reload_time < Time::now()) { - load_country_list(language_code, country->hash, Auto()); + load_country_list(language_code, country->hash, {}); } return country; diff --git a/tdactor/td/actor/PromiseFuture.h b/tdactor/td/actor/PromiseFuture.h index 29b21b42b..08b8ca485 100644 --- a/tdactor/td/actor/PromiseFuture.h +++ b/tdactor/td/actor/PromiseFuture.h @@ -20,7 +20,6 @@ #include namespace td { - template class PromiseInterface { public: @@ -43,6 +42,15 @@ class PromiseInterface { set_error(result.move_as_error()); } } + void operator()(T &&value) { + set_value(std::move(value)); + } + void operator()(Status &&error) { + set_error(std::move(error)); + } + void operator()(Result &&result) { + set_result(std::move(result)); + } virtual bool is_cancellable() const { return false; } @@ -56,10 +64,182 @@ class PromiseInterface { } }; +namespace detail { + +template +struct GetArg : public GetArg {}; + +template +class GetArg { + public: + using type = Arg; +}; +template +class GetArg { + public: + using type = Arg; +}; + +template +using get_arg_t = std::decay_t::type>; + +template +struct DropResult { + using type = T; +}; + +template +struct DropResult> { + using type = T; +}; + +template +using drop_result_t = typename DropResult::type; + +struct Ignore { + void operator()(Status &&error) { + error.ignore(); + } +}; +template +class LambdaPromise : public PromiseInterface { + enum class OnFail { None, Ok, Fail }; + + public: + void set_value(ValueT &&value) override { + CHECK(has_lambda_.get()); + do_ok(ok_, std::move(value)); + on_fail_ = OnFail::None; + } + void set_error(Status &&error) override { + CHECK(has_lambda_.get()); + do_error(std::move(error)); + } + LambdaPromise(const LambdaPromise &other) = delete; + LambdaPromise &operator=(const LambdaPromise &other) = delete; + LambdaPromise(LambdaPromise &&other) = default; + LambdaPromise &operator=(LambdaPromise &&other) = default; + ~LambdaPromise() override { + if (has_lambda_.get()) { + do_error(Status::Error("Lost promise")); + } + } + + template + LambdaPromise(FromOkT &&ok, FromFailT &&fail, bool use_ok_as_fail) + : ok_(std::forward(ok)) + , fail_(std::forward(fail)) + , on_fail_(use_ok_as_fail ? OnFail::Ok : OnFail::Fail), has_lambda_(true) { + } + template + LambdaPromise(FromOkT &&ok) : LambdaPromise(std::move(ok), Ignore(), true) { + } + + private: + FunctionOkT ok_; + FunctionFailT fail_; + OnFail on_fail_ = OnFail::None; + MovableValue has_lambda_{false}; + + void do_error(Status &&error) { + switch (on_fail_) { + case OnFail::None: + break; + case OnFail::Ok: + do_error(ok_, std::move(error)); + break; + case OnFail::Fail: + do_error(fail_, std::move(error)); + break; + } + on_fail_ = OnFail::None; + } + + template + std::enable_if_t>::value, void> do_error(F &&f, Status &&status) { + f(Result(std::move(status))); + } + template + std::enable_if_t>::value, void> do_error(F &&f, Y &&status) { + f(Auto()); + } + template + std::enable_if_t>::value, void> do_ok(F &&f, ValueT &&result) { + f(Result(std::move(result))); + } + template + std::enable_if_t>::value, void> do_ok(F &&f, ValueT &&result) { + f(std::move(result)); + } +}; +} + template class SafePromise; template +class Promise; + +constexpr std::false_type is_promise_interface(...) { + return {}; +} +template +constexpr std::true_type is_promise_interface(const PromiseInterface &promise) { + return {}; +} +template +constexpr std::true_type is_promise_interface(const Promise &promise) { + return {}; +} + +template +constexpr bool is_promise_interface() { + return decltype(is_promise_interface(std::declval()))::value; +} + +constexpr std::false_type is_promise_interface_ptr(...) { + return {}; +} +template +constexpr std::true_type is_promise_interface_ptr(const unique_ptr &promise) { + return {}; +} + +template +constexpr bool is_promise_interface_ptr() { + return decltype(is_promise_interface_ptr(std::declval()))::value; +} +template ::value, bool> has_t = false> +auto lambda_promise(F &&f) { + return detail::LambdaPromise>>, std::decay_t>(std::forward(f)); +} +template ::value, bool> has_t = true> +auto lambda_promise(F &&f) { + return detail::LambdaPromise>(std::forward(f)); +} + +template (), bool> from_promise_inerface = true> +auto &&promise_interface(F &&f) { + return std::forward(f); +} + +template (), bool> from_promise_inerface = false> +auto promise_interface(F &&f) { + return lambda_promise(std::forward(f)); +} + +template (), bool> from_promise_inerface = true> +auto promise_interface_ptr(F &&f) { + return std::forward(f); +} +template (), bool> from_promise_inerface = false> +auto promise_interface_ptr(F &&f) { + return td::make_unique(std::forward(f)))>>( + promise_interface(std::forward(f))); +} + + +template class Promise { public: void set_value(T &&value) { @@ -83,6 +263,14 @@ class Promise { promise_->set_result(std::move(result)); promise_.reset(); } + template + void operator()(S &&result) { + if (!promise_) { + return; + } + promise_->operator()(std::forward(result)); + promise_.reset(); + } void reset() { promise_.reset(); } @@ -117,8 +305,13 @@ class Promise { Promise() = default; explicit Promise(unique_ptr> promise) : promise_(std::move(promise)) { } + Promise(Auto) { + } Promise(SafePromise &&other); Promise &operator=(SafePromise &&other); + template + Promise(F &&f) : promise_(promise_interface_ptr(std::forward(f))) { + } explicit operator bool() { return static_cast(promise_); @@ -209,36 +402,6 @@ class EventPromise : public PromiseInterface { } }; -template -struct GetArg : public GetArg {}; - -template -class GetArg { - public: - using type = Arg; -}; -template -class GetArg { - public: - using type = Arg; -}; - -template -using get_arg_t = std::decay_t::type>; - -template -struct DropResult { - using type = T; -}; - -template -struct DropResult> { - using type = T; -}; - -template -using drop_result_t = typename DropResult::type; - template class CancellablePromise : public PromiseT { public: @@ -257,62 +420,6 @@ class CancellablePromise : public PromiseT { CancellationToken cancellation_token_; }; -template -class LambdaPromise : public PromiseInterface { - enum class OnFail { None, Ok, Fail }; - - public: - void set_value(ValueT &&value) override { - ok_(std::move(value)); - on_fail_ = OnFail::None; - } - void set_error(Status &&error) override { - do_error(std::move(error)); - } - LambdaPromise(const LambdaPromise &other) = delete; - LambdaPromise &operator=(const LambdaPromise &other) = delete; - LambdaPromise(LambdaPromise &&other) = delete; - LambdaPromise &operator=(LambdaPromise &&other) = delete; - ~LambdaPromise() override { - do_error(Status::Error("Lost promise")); - } - - template - LambdaPromise(FromOkT &&ok, FromFailT &&fail, bool use_ok_as_fail) - : ok_(std::forward(ok)) - , fail_(std::forward(fail)) - , on_fail_(use_ok_as_fail ? OnFail::Ok : OnFail::Fail) { - } - - private: - FunctionOkT ok_; - FunctionFailT fail_; - OnFail on_fail_ = OnFail::None; - - template > - std::enable_if_t::value> do_error_impl(FuncT &func, Status &&status) { - func(std::move(status)); - } - - template > - std::enable_if_t::value> do_error_impl(FuncT &func, Status &&status) { - func(Auto()); - } - - void do_error(Status &&error) { - switch (on_fail_) { - case OnFail::None: - break; - case OnFail::Ok: - do_error_impl(ok_, std::move(error)); - break; - case OnFail::Fail: - fail_(std::move(error)); - break; - } - on_fail_ = OnFail::None; - } -}; template class JoinPromise : public PromiseInterface { @@ -331,6 +438,30 @@ class JoinPromise : public PromiseInterface { }; } // namespace detail +class SendClosure { + public: + template + void operator()(ArgsT &&... args) const { + send_closure(std::forward(args)...); + } +}; + +//template +//template +//auto Promise::send_closure(ArgsT &&... args) { +// return [promise = std::move(*this), t = std::make_tuple(std::forward(args)...)](auto &&r_res) mutable { +// TRY_RESULT_PROMISE(promise, res, std::move(r_res)); +// td2::call_tuple(SendClosure(), std::tuple_cat(std::move(t), std::make_tuple(std::move(res), std::move(promise)))); +// }; +//} + +template +auto promise_send_closure(ArgsT &&... args) { + return [t = std::make_tuple(std::forward(args)...)](auto &&res) mutable { + call_tuple(SendClosure(), std::tuple_cat(std::move(t), std::make_tuple(std::move(res)))); + }; +} + /*** FutureActor and PromiseActor ***/ template class FutureActor; @@ -559,16 +690,12 @@ FutureActor send_promise(ActorId actor_id, ResultT (ActorBT::*func)( class PromiseCreator { public: - struct Ignore { - void operator()(Status &&error) { - error.ignore(); - } - }; + using Ignore = detail::Ignore; template >> static Promise lambda(OkT &&ok) { return Promise( - td::make_unique, Ignore>>(std::forward(ok), Ignore(), true)); + td::make_unique>>(std::forward(ok))); } template > @@ -580,8 +707,8 @@ class PromiseCreator { template >> static auto cancellable_lambda(CancellationToken cancellation_token, OkT &&ok) { return Promise( - td::make_unique, Ignore>>>( - std::move(cancellation_token), std::forward(ok), Ignore(), true)); + td::make_unique>>>( + std::move(cancellation_token), std::forward(ok))); } static Promise<> event(EventFull &&ok) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6fab03211..33b3fa426 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -33,16 +33,19 @@ set(TESTS_MAIN if (NOT CMAKE_CROSSCOMPILING OR EMSCRIPTEN) #Tests add_executable(test-tdutils EXCLUDE_FROM_ALL ${TESTS_MAIN} ${TDUTILS_TEST_SOURCE}) + add_executable(test-online EXCLUDE_FROM_ALL online.cpp) add_executable(run_all_tests ${TESTS_MAIN} ${TD_TEST_SOURCE}) if (CLANG AND NOT CYGWIN AND NOT EMSCRIPTEN AND NOT (CMAKE_HOST_SYSTEM_NAME MATCHES "OpenBSD")) target_compile_options(test-tdutils PUBLIC -fsanitize=undefined -fno-sanitize=vptr) target_compile_options(run_all_tests PUBLIC -fsanitize=undefined -fno-sanitize=vptr) + target_compile_options(test-online PUBLIC -fsanitize=undefined -fno-sanitize=vptr) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=undefined -fno-sanitize=vptr") endif() target_include_directories(run_all_tests PUBLIC $) target_include_directories(test-tdutils PUBLIC $) target_link_libraries(test-tdutils PRIVATE tdutils) target_link_libraries(run_all_tests PRIVATE tdcore tdclient) + target_link_libraries(test-online PRIVATE tdcore tdclient tdutils tdactor) if (CLANG) # add_executable(fuzz_url fuzz_url.cpp) diff --git a/test/online.cpp b/test/online.cpp new file mode 100644 index 000000000..308673e23 --- /dev/null +++ b/test/online.cpp @@ -0,0 +1,622 @@ + +#include +#include +#include +#include +#include +#include +#include "td/telegram/TdCallback.h" +#include "td/utils/port/signals.h" +#include "td/telegram/Log.h" +#include "td/utils/crypto.h" +#include "td/utils/misc.h" +#include "td/utils/Random.h" +#include "td/actor/actor.h" +#include "td/actor/PromiseFuture.h" +#include "td/actor/MultiPromise.h" +#include "td/telegram/td_api_json.h" + +namespace td { +template +static void check_td_error(T &result) { + LOG_CHECK(result->get_id() != td::td_api::error::ID) << to_string(result); +} + +class TestClient : public Actor { + public: + explicit TestClient(td::string name) : name_(std::move(name)) { + } + struct Update { + td::uint64 id; + td::tl_object_ptr object; + Update(td::uint64 id, td::tl_object_ptr object) : id(id), object(std::move(object)) { + } + }; + class Listener { + public: + Listener() = default; + Listener(const Listener &) = delete; + Listener &operator=(const Listener &) = delete; + Listener(Listener &&) = delete; + Listener &operator=(Listener &&) = delete; + virtual ~Listener() = default; + virtual void start_listen(TestClient *client) { + } + virtual void stop_listen() { + } + virtual void on_update(std::shared_ptr update) = 0; + }; + struct RemoveListener { + void operator()(Listener *listener) { + send_closure(self, &TestClient::remove_listener, listener); + } + ActorId self; + }; + using ListenerToken = std::unique_ptr; + void close(td::Promise<> close_promise) { + close_promise_ = std::move(close_promise); + td_client_.reset(); + } + + td::unique_ptr make_td_callback() { + class TdCallbackImpl : public td::TdCallback { + public: + explicit TdCallbackImpl(td::ActorId client) : client_(client) { + } + void on_result(td::uint64 id, td::tl_object_ptr result) override { + send_closure(client_, &TestClient::on_result, id, std::move(result)); + } + void on_error(td::uint64 id, td::tl_object_ptr error) override { + send_closure(client_, &TestClient::on_error, id, std::move(error)); + } + TdCallbackImpl(const TdCallbackImpl &) = delete; + TdCallbackImpl &operator=(const TdCallbackImpl &) = delete; + TdCallbackImpl(TdCallbackImpl &&) = delete; + TdCallbackImpl &operator=(TdCallbackImpl &&) = delete; + ~TdCallbackImpl() override { + send_closure(client_, &TestClient::on_closed); + } + + private: + td::ActorId client_; + }; + return td::make_unique(actor_id(this)); + } + + void add_listener(td::unique_ptr listener) { + auto *ptr = listener.get(); + listeners_.push_back(std::move(listener)); + ptr->start_listen(this); + } + void remove_listener(Listener *listener) { + pending_remove_.push_back(listener); + } + void do_pending_remove_listeners() { + for (auto listener : pending_remove_) { + do_remove_listener(listener); + } + pending_remove_.clear(); + } + void do_remove_listener(Listener *listener) { + for (size_t i = 0; i < listeners_.size(); i++) { + if (listeners_[i].get() == listener) { + listener->stop_listen(); + listeners_.erase(listeners_.begin() + i); + break; + } + } + } + + void on_result(td::uint64 id, td::tl_object_ptr result) { + on_update(std::make_shared(id, std::move(result))); + } + void on_error(td::uint64 id, td::tl_object_ptr error) { + on_update(std::make_shared(id, std::move(error))); + } + void on_update(std::shared_ptr update) { + for (auto &listener : listeners_) { + listener->on_update(update); + } + do_pending_remove_listeners(); + } + + void on_closed() { + stop(); + } + + void start_up() override { + auto old_context = set_context(std::make_shared()); + set_tag(name_); + LOG(INFO) << "START UP!"; + + td_client_ = td::create_actor("Td-proxy", make_td_callback()); + } + + td::ActorOwn td_client_; + + td::string name_; + + private: + td::vector> listeners_; + td::vector pending_remove_; + + td::Promise<> close_promise_; +}; + +class Task : public TestClient::Listener { + public: + void on_update(std::shared_ptr update) override { + auto it = sent_queries_.find(update->id); + if (it != sent_queries_.end()) { + it->second(std::move(update->object)); + sent_queries_.erase(it); + } + process_update(update); + } + void start_listen(TestClient *client) override { + client_ = client; + start_up(); + } + virtual void process_update(std::shared_ptr update) { + } + + template + void send_query(td::tl_object_ptr function, CallbackT callback) { + auto id = current_query_id_++; + + using ResultT = typename FunctionT::ReturnType; + sent_queries_[id] = + [callback = Promise(std::move(callback))](Result> r_obj) mutable { + TRY_RESULT_PROMISE(callback, obj, std::move(r_obj)); + if (obj->get_id() == td::td_api::error::ID) { + auto err = move_tl_object_as(std::move(obj)); + callback.set_error(Status::Error(err->code_, err->message_)); + return; + } + callback.set_value(move_tl_object_as(std::move(obj))); + }; + send_closure(client_->td_client_, &td::ClientActor::request, id, std::move(function)); + } + + protected: + std::map>> sent_queries_; + TestClient *client_ = nullptr; + td::uint64 current_query_id_ = 1; + + virtual void start_up() { + } + void stop() { + client_->remove_listener(this); + client_ = nullptr; + } + bool is_alive() const { + return client_ != nullptr; + } +}; + +class InitTask : public Task { + public: + struct Options { + string name; + int32 api_id; + string api_hash; + }; + InitTask(Options options, td::Promise<> promise) : options_(std::move(options)), promise_(std::move(promise)) { + } + + private: + Options options_; + td::Promise<> promise_; + bool start_flag_{false}; + + void start_up() override { + send_query(td::make_tl_object(), + [this](auto res) { this->process_authorization_state(res.move_as_ok()); }); + } + void process_authorization_state(td::tl_object_ptr authorization_state) { + start_flag_ = true; + td::tl_object_ptr function; + switch (authorization_state->get_id()) { + case td::td_api::authorizationStateWaitEncryptionKey::ID: + send(td::make_tl_object()); + break; + case td::td_api::authorizationStateReady::ID: + promise_.set_value({}); + stop(); + break; + case td::td_api::authorizationStateWaitTdlibParameters::ID: { + auto parameters = td::td_api::make_object(); + parameters->use_test_dc_ = true; + parameters->database_directory_ = options_.name + TD_DIR_SLASH; + parameters->use_message_database_ = true; + parameters->use_secret_chats_ = true; + parameters->api_id_ = options_.api_id; + parameters->api_hash_ = options_.api_hash; + parameters->system_language_code_ = "en"; + parameters->device_model_ = "Desktop"; + parameters->application_version_ = "tdclient-test"; + parameters->ignore_file_names_ = false; + parameters->enable_storage_optimizer_ = true; + send(td::td_api::make_object(std::move(parameters))); + break; + } + default: + LOG(ERROR) << "???"; + promise_.set_error( + Status::Error(PSLICE() << "Unexpected authorization state " << to_string(authorization_state))); + stop(); + break; + } + } + template + void send(T &&query) { + send_query(std::move(query), [this](auto res) { + if (is_alive()) { + res.ensure(); + } + }); + } + void process_update(std::shared_ptr update) override { + if (!start_flag_) { + return; + } + if (!update->object) { + return; + } + if (update->object->get_id() == td::td_api::updateAuthorizationState::ID) { + auto update_authorization_state = td::move_tl_object_as(update->object); + process_authorization_state(std::move(update_authorization_state->authorization_state_)); + } + } +}; + +class GetMe : public Task { + public: + struct Result { + int32 user_id; + int64 chat_id; + }; + GetMe(Promise promise) : promise_(std::move(promise)) { + } + void start_up() override { + send_query(td::make_tl_object(), [this](auto res) { with_user_id(res.move_as_ok()->id_); }); + } + + private: + Promise promise_; + Result result_; + + void with_user_id(int32 user_id) { + result_.user_id = user_id; + send_query(td::make_tl_object(user_id, false), [this](auto res) { with_chat_id(res.move_as_ok()->id_); }); + } + + void with_chat_id(int64 chat_id) { + result_.chat_id = chat_id; + promise_.set_value(std::move(result_)); + stop(); + } +}; + +class UploadFile : public Task { + public: + struct Result { + std::string content; + std::string remote_id; + }; + UploadFile(std::string dir, std::string content, int64 chat_id, Promise promise) : dir_(std::move(dir)), content_(std::move(content)), chat_id_(std::move(chat_id)), promise_(std::move(promise)) { + } + void start_up() override { + auto hash = hex_encode(sha256(content_)).substr(0, 10); + content_path_ = dir_ + TD_DIR_SLASH + hash + ".data"; + id_path_ = dir_ + TD_DIR_SLASH + hash + ".id"; + + auto r_id = read_file(id_path_); + if (r_id.is_ok() && r_id.ok().size() > 10) { + auto id = r_id.move_as_ok(); + LOG(ERROR) << "Got file from cache"; + Result res; + res.content = std::move(content_); + res.remote_id = id.as_slice().str(); + promise_.set_value(std::move(res)); + stop(); + return; + } + + write_file(content_path_, content_).ensure(); + + send_query(td::make_tl_object( + chat_id_, 0, 0, nullptr, nullptr, + td::make_tl_object( + td::make_tl_object(content_path_), nullptr, true, + td::make_tl_object("tag", td::Auto()))), + [this](auto res) { with_message(res.move_as_ok()); }); + } + + private: + std::string dir_; + std::string content_path_; + std::string id_path_; + std::string content_; + int64 chat_id_; + Promise promise_; + int64 file_id_{0}; + + void with_message(td::tl_object_ptr message) { + CHECK(message->content_->get_id() == td::td_api::messageDocument::ID); + auto messageDocument = td::move_tl_object_as(message->content_); + on_file(*messageDocument->document_->document_, true); + } + + void on_file(const td_api::file &file, bool force = false) { + if (force) { + file_id_ = file.id_; + } + if (file.id_ != file_id_) { + return; + } + if (file.remote_->is_uploading_completed_) { + Result res; + res.content = std::move(content_); + res.remote_id = file.remote_->id_; + + unlink(content_path_).ignore(); + atomic_write_file(id_path_, res.remote_id).ignore(); + + promise_.set_value(std::move(res)); + stop(); + } + } + + void process_update(std::shared_ptr update) override { + if (!update->object) { + return; + } + if (update->object->get_id() == td::td_api::updateFile::ID) { + auto updateFile = td::move_tl_object_as(update->object); + on_file(*updateFile->file_); + } + } +}; + +class TestDownloadFile : public Task { + public: + TestDownloadFile(std::string remote_id, std::string content, Promise promise) : remote_id_(std::move(remote_id)), content_(std::move(content)), promise_(std::move(promise)) { + } + void start_up() override { + send_query(td::make_tl_object( + remote_id_, nullptr + ), [this](auto res) { start_file(*res.ok()); }); + } + + private: + std::string remote_id_; + std::string content_; + Promise promise_; + struct Range { + size_t begin; + size_t end; + }; + int32 file_id_{0}; + std::vector ranges_; + + + void start_file(const td_api::file &file) { + LOG(ERROR) << "Start"; + file_id_ = file.id_; +// CHECK(!file.local_->is_downloading_active_); +// CHECK(!file.local_->is_downloading_completed_); +// CHECK(file.local_->download_offset_ == 0); + if (!file.local_->path_.empty()) { + unlink(file.local_->path_).ignore(); + } + + size_t size = file.size_; + Random::Xorshift128plus rnd(123); + + size_t begin = 0; + + while (begin + 128u < size) { + auto chunk_size = rnd.fast(128, 3096); + auto end = begin + chunk_size; + if (end > size) { + end = size; + } + + ranges_.push_back({begin, end}); + begin = end; + } + + random_shuffle(as_mutable_span(ranges_), rnd); + start_chunk(); + } + + void got_chunk(const td_api::file &file) { + LOG(ERROR) << "Got chunk"; + auto range = ranges_.back(); + std::string got_chunk(range.end - range.begin, '\0'); + FileFd::open(file.local_->path_, FileFd::Flags::Read).move_as_ok().pread(got_chunk, range.begin).ensure(); + CHECK(got_chunk == as_slice(content_).substr(range.begin, range.end - range.begin)); + ranges_.pop_back(); + if (ranges_.empty()) { + promise_.set_value(Unit{}); + return stop(); + } + start_chunk(); + } + + void start_chunk() { + + send_query(td::make_tl_object( + file_id_, 1, int(ranges_.back().begin), int(ranges_.back().end - ranges_.back().begin), true + ), [this](auto res) { got_chunk(*res.ok()); }); + + } +}; + +std::string gen_readable_file(size_t block_size, size_t block_count) { + std::string content; + for (size_t block_id = 0; block_id < block_count; block_id++) { + std::string block; + for (size_t line = 0; block.size() < block_size; line++) { + block += PSTRING() << "\nblock=" << block_id << ", line=" << line; + } + block.resize(block_size); + content += block; + } + return content; +} + +class TestTd : public Actor { + public: + struct Options { + std::string alice_dir = "alice"; + std::string bob_dir = "bob"; + int32 api_id{0}; + string api_hash; + }; + + TestTd(Options options) : options_(std::move(options)) { + } + + private: + Options options_; + ActorOwn alice_; + GetMe::Result alice_id_; + std::string alice_cache_dir_; + ActorOwn bob_; + + void start_up() override { + alice_ = create_actor("Alice", "Alice"); + bob_ = create_actor("Bob", "Bob"); + + MultiPromiseActorSafe mp("init"); + mp.add_promise(promise_send_closure(actor_id(this), &TestTd::check_init)); + + InitTask::Options options; + options.api_id = options_.api_id; + options.api_hash = options_.api_hash; + + options.name = options_.alice_dir; + td::send_closure(alice_, &TestClient::add_listener, td::make_unique(options, mp.get_promise())); + options.name = options_.bob_dir; + td::send_closure(bob_, &TestClient::add_listener, td::make_unique(options, mp.get_promise())); + } + + void check_init(Result res) { + LOG_IF(FATAL, res.is_error()) << res.error(); + alice_cache_dir_ = options_.alice_dir + TD_DIR_SLASH + "cache"; + mkdir(alice_cache_dir_).ignore(); + + td::send_closure(alice_, &TestClient::add_listener, + td::make_unique(promise_send_closure(actor_id(this), &TestTd::with_alice_id))); + + //close(); + } + + void with_alice_id(Result alice_id) { + alice_id_ = alice_id.move_as_ok(); + LOG(ERROR) << "Alice user_id=" << alice_id_.user_id << ", chat_id=" << alice_id_.chat_id; + auto content = gen_readable_file(65536, 20); + send_closure(alice_, &TestClient::add_listener, + td::make_unique(alice_cache_dir_, std::move(content), alice_id_.chat_id, promise_send_closure(actor_id(this), &TestTd::with_file))); + } + void with_file(Result r_result) { + auto result = r_result.move_as_ok(); + send_closure(alice_, &TestClient::add_listener, + td::make_unique(result.remote_id, std::move(result.content), promise_send_closure(actor_id(this), &TestTd::after_test_download_file))); + } + void after_test_download_file(Result) { + close(); + } + + + void close() { + MultiPromiseActorSafe mp("close"); + mp.add_promise(promise_send_closure(actor_id(this), &TestTd::check_close)); + td::send_closure(alice_, &TestClient::close, mp.get_promise()); + td::send_closure(bob_, &TestClient::close, mp.get_promise()); + } + + void check_close(Result res) { + Scheduler::instance()->finish(); + stop(); + } +}; + +static void fail_signal(int sig) { + signal_safe_write_signal_number(sig); + while (true) { + // spin forever to allow debugger to attach + } +} + +static void on_fatal_error(const char *error) { + std::cerr << "Fatal error: " << error << std::endl; +} +int main(int argc, char **argv) { + ignore_signal(SignalType::HangUp).ensure(); + ignore_signal(SignalType::Pipe).ensure(); + set_signal_handler(SignalType::Error, fail_signal).ensure(); + set_signal_handler(SignalType::Abort, fail_signal).ensure(); + Log::set_fatal_error_callback(on_fatal_error); + init_openssl_threads(); + + TestTd::Options test_options; + + test_options.api_id = [](auto x) -> int32 { + if (x) { + return to_integer(Slice(x)); + } + return 0; + }(std::getenv("TD_API_ID")); + test_options.api_hash = [](auto x) -> std::string { + if (x) { + return x; + } + return std::string(); + }(std::getenv("TD_API_HASH")); + + int new_verbosity_level = VERBOSITY_NAME(INFO); + + OptionParser options; + options.set_description("TDLib experimental tester"); + options.add_option('v', "verbosity", "Set verbosity level", [&](Slice level) { + int new_verbosity = 1; + while (begins_with(level, "v")) { + new_verbosity++; + level.remove_prefix(1); + } + if (!level.empty()) { + new_verbosity += to_integer(level) - (new_verbosity == 1); + } + new_verbosity_level = VERBOSITY_NAME(FATAL) + new_verbosity; + }); + options.add_check([&] { + if (test_options.api_id == 0 || test_options.api_hash.empty()) { + return Status::Error("You must provide valid api-id and api-hash obtained at https://my.telegram.org"); + } + return Status::OK(); + }); + auto r_non_options = options.run(argc, argv, 0); + if (r_non_options.is_error()) { + LOG(PLAIN) << argv[0] << ": " << r_non_options.error().message(); + LOG(PLAIN) << options; + return 1; + } + SET_VERBOSITY_LEVEL(new_verbosity_level); + + td::ConcurrentScheduler sched; + sched.init(4); + sched.create_actor_unsafe(0, "TestTd", std::move(test_options)).release(); + sched.start(); + while (sched.run_main(10)) { + } + sched.finish(); + return 0; +} +} // namespace td + +int main(int argc, char **argv) { + return td::main(argc, argv); +}