diff --git a/tdnet/td/net/GetHostByNameActor.cpp b/tdnet/td/net/GetHostByNameActor.cpp index 9cd33a5b8..596bbfd8a 100644 --- a/tdnet/td/net/GetHostByNameActor.cpp +++ b/tdnet/td/net/GetHostByNameActor.cpp @@ -15,21 +15,28 @@ namespace td { namespace detail { class GoogleDnsResolver : public Actor { public: - GoogleDnsResolver(std::string host, int port, bool prefer_ipv6, td::Promise promise) { - const int timeout = 10; - const int ttl = 3; - wget_ = create_actor( - "Wget", create_result_handler(std::move(promise), port), - PSTRING() << "https://www.google.com/resolve?name=" << url_encode(host) << "&type=" << (prefer_ipv6 ? 28 : 1), - std::vector>({{"Host", "dns.google.com"}}), timeout, ttl, prefer_ipv6, - SslStream::VerifyPeer::Off); + GoogleDnsResolver(std::string host, GetHostByNameActor::Options options, td::Promise promise) + : host_(std::move(host)), options_(std::move(options)), promise_(std::move(promise)) { } private: + std::string host_; + GetHostByNameActor::Options options_; + Promise promise_; ActorOwn wget_; - Promise create_result_handler(Promise promise, int port) { - return PromiseCreator::lambda([promise = std::move(promise), port](Result r_http_query) mutable { + void start_up() override { + const int timeout = 10; + const int ttl = 3; + wget_ = create_actor("Wget", create_result_handler(std::move(promise_)), + PSTRING() << "https://www.google.com/resolve?name=" << url_encode(host_) + << "&type=" << (options_.prefer_ipv6 ? 28 : 1), + std::vector>({{"Host", "dns.google.com"}}), timeout, ttl, + options_.prefer_ipv6, SslStream::VerifyPeer::Off); + } + + Promise create_result_handler(Promise promise) { + return PromiseCreator::lambda([promise = std::move(promise)](Result r_http_query) mutable { promise.set_result([&]() -> Result { TRY_RESULT(http_query, std::move(r_http_query)); LOG(ERROR) << *http_query; @@ -47,52 +54,128 @@ class GoogleDnsResolver : public Actor { auto &answer_0 = answer.get_array()[0].get_object(); TRY_RESULT(ip_str, get_json_object_string_field(answer_0, "data", false)); IPAddress ip; - TRY_STATUS(ip.init_host_port(ip_str, port)); + TRY_STATUS(ip.init_host_port(ip_str, 0)); return ip; }()); }); } }; +class NativeDnsResolver : public Actor { + public: + NativeDnsResolver(std::string host, GetHostByNameActor::Options options, td::Promise promise) + : host_(std::move(host)), options_(std::move(options)), promise_(std::move(promise)) { + } + + private: + std::string host_; + GetHostByNameActor::Options options_; + Promise promise_; + + void start_up() override { + IPAddress ip; + auto begin_time = td::Time::now(); + auto status = ip.init_host_port(host_, 0, options_.prefer_ipv6); + auto end_time = td::Time::now(); + LOG(WARNING) << "Init host = " << host_ << " in " << end_time - begin_time << " seconds to " << ip; + if (status.is_error()) { + promise_.set_error(std::move(status)); + return; + } + promise_.set_value(std::move(ip)); + stop(); + } +}; +class DnsResolver : public Actor { + public: + DnsResolver(std::string host, GetHostByNameActor::Options options, td::Promise promise) + : host_(std::move(host)), options_(std::move(options)), promise_(std::move(promise)) { + } + + private: + std::string host_; + GetHostByNameActor::Options options_; + Promise promise_; + ActorOwn<> query_; + size_t pos_ = 0; + GetHostByNameActor::Options::Type types[2] = {GetHostByNameActor::Options::Google, + GetHostByNameActor::Options::Native}; + + void loop() override { + if (!query_.empty()) { + return; + } + if (pos_ == 2) { + promise_.set_error(Status::Error("Failed to resolve ip")); + return stop(); + } + options_.type = types[pos_]; + query_ = GetHostByNameActor::resolve(host_, options_, + PromiseCreator::lambda([actor_id = actor_id(this)](Result res) { + send_closure(actor_id, &DnsResolver::on_result, std::move(res)); + })); + } + + void on_result(Result res) { + query_.reset(); + if (res.is_ok() || pos_ == 2) { + promise_.set_result(std::move(res)); + return stop(); + } + loop(); + } +}; } // namespace detail -ActorOwn<> DnsOverHttps::resolve(std::string host, int port, bool prefer_ipv6, td::Promise promise) { - return ActorOwn<>(create_actor("GoogleDnsResolver", std::move(host), port, prefer_ipv6, - std::move(promise))); +ActorOwn<> GetHostByNameActor::resolve(std::string host, Options options, Promise promise) { + switch (options.type) { + case Options::Native: + return ActorOwn<>(create_actor_on_scheduler( + "NativeDnsResolver", options.scheduler_id, std::move(host), options, std::move(promise))); + case Options::Google: + return ActorOwn<>(create_actor_on_scheduler( + "GoogleDnsResolver", options.scheduler_id, std::move(host), options, std::move(promise))); + case Options::All: + return ActorOwn<>(create_actor_on_scheduler("DnsResolver", options.scheduler_id, + std::move(host), options, std::move(promise))); + } } GetHostByNameActor::GetHostByNameActor(int32 ok_timeout, int32 error_timeout) : ok_timeout_(ok_timeout), error_timeout_(error_timeout) { } -void GetHostByNameActor::run(std::string host, int port, bool prefer_ipv6, td::Promise promise) { - auto r_ip = load_ip(std::move(host), port, prefer_ipv6); - promise.set_result(std::move(r_ip)); +void GetHostByNameActor::on_result(std::string host, bool prefer_ipv6, Result res) { + auto &value = cache_[prefer_ipv6].emplace(host, Value{{}, 0}).first->second; + + auto promises = std::move(value.promises); + auto end_time = td::Time::now(); + if (res.is_ok()) { + value = Value{res.move_as_ok(), end_time + ok_timeout_}; + } else { + value = Value{res.move_as_error(), end_time + error_timeout_}; + } + for (auto &promise : promises) { + promise.second.set_result(value.get_ip_port(promise.first)); + } } -Result GetHostByNameActor::load_ip(string host, int port, bool prefer_ipv6) { +void GetHostByNameActor::run(string host, int port, bool prefer_ipv6, Promise promise) { auto &value = cache_[prefer_ipv6].emplace(host, Value{{}, 0}).first->second; auto begin_time = td::Time::now(); if (value.expire_at > begin_time) { - auto ip = value.ip.clone(); - if (ip.is_ok()) { - ip.ok_ref().set_port(port); - CHECK(ip.ok().get_port() == port); - } - return ip; + return promise.set_result(value.get_ip_port(port)); } - td::IPAddress ip; - auto status = ip.init_host_port(host, port, prefer_ipv6); - auto end_time = td::Time::now(); - LOG(WARNING) << "Init host = " << host << ", port = " << port << " in " << end_time - begin_time << " seconds to " - << ip; - - if (status.is_ok()) { - value = Value{ip, end_time + ok_timeout_}; - return ip; - } else { - value = Value{status.clone(), end_time + error_timeout_}; - return std::move(status); + value.promises.emplace_back(port, std::move(promise)); + if (value.query.empty()) { + Options options; + options.type = Options::Type::All; + options.prefer_ipv6 = prefer_ipv6; + value.query = + resolve(host, options, + PromiseCreator::lambda([actor_id = actor_id(this), host, prefer_ipv6](Result res) mutable { + send_closure(actor_id, &GetHostByNameActor::on_result, std::move(host), prefer_ipv6, std::move(res)); + })); } } diff --git a/tdnet/td/net/GetHostByNameActor.h b/tdnet/td/net/GetHostByNameActor.h index d96b077a3..5727f3783 100644 --- a/tdnet/td/net/GetHostByNameActor.h +++ b/tdnet/td/net/GetHostByNameActor.h @@ -15,25 +15,36 @@ #include namespace td { - -class DnsOverHttps { - public: - static TD_WARN_UNUSED_RESULT ActorOwn<> resolve(std::string host, int port, bool prefer_ipv6, - td::Promise promise); -}; - class GetHostByNameActor final : public td::Actor { public: explicit GetHostByNameActor(int32 ok_timeout = CACHE_TIME, int32 error_timeout = ERROR_CACHE_TIME); void run(std::string host, int port, bool prefer_ipv6, td::Promise promise); + struct Options { + enum Type { Native, Google, All } type{Native}; + bool prefer_ipv6{false}; + int scheduler_id{-1}; + }; + static TD_WARN_UNUSED_RESULT ActorOwn<> resolve(std::string host, Options options, Promise promise); + private: struct Value { Result ip; double expire_at; + ActorOwn<> query; + std::vector>> promises; + Value(Result ip, double expire_at) : ip(std::move(ip)), expire_at(expire_at) { } + + Result get_ip_port(int port) { + auto res = ip.clone(); + if (res.is_ok()) { + res.ok_ref().set_port(port); + } + return res; + } }; std::unordered_map cache_[2]; static constexpr int32 CACHE_TIME = 60 * 29; // 29 minutes @@ -42,7 +53,7 @@ class GetHostByNameActor final : public td::Actor { int32 ok_timeout_; int32 error_timeout_; - Result load_ip(string host, int port, bool prefer_ipv6) TD_WARN_UNUSED_RESULT; + void on_result(std::string host, bool prefer_ipv6, Result res); }; } // namespace td diff --git a/test/mtproto.cpp b/test/mtproto.cpp index 5ce05ee50..f81de34ff 100644 --- a/test/mtproto.cpp +++ b/test/mtproto.cpp @@ -43,7 +43,7 @@ TEST(Mtproto, DnsOverHttps) { { auto guard = sched.get_main_guard(); - auto run = [&](bool prefer_ipv6) { + auto run = [&](GetHostByNameActor::Options options) { auto promise = PromiseCreator::lambda([&, num = cnt](Result r_ip_address) { if (r_ip_address.is_ok()) { LOG(WARNING) << num << " " << r_ip_address.ok(); @@ -55,11 +55,15 @@ TEST(Mtproto, DnsOverHttps) { } }); cnt++; - DnsOverHttps::resolve("web.telegram.org", 443, prefer_ipv6, std::move(promise)).release(); + GetHostByNameActor::resolve("web.telegram.org", options, std::move(promise)).release(); }; - run(false); - run(true); + run(GetHostByNameActor::Options{GetHostByNameActor::Options::Native, true, -1}); + run(GetHostByNameActor::Options{GetHostByNameActor::Options::Google, true, -1}); + run(GetHostByNameActor::Options{GetHostByNameActor::Options::All, true, -1}); + run(GetHostByNameActor::Options{GetHostByNameActor::Options::Native, false, -1}); + run(GetHostByNameActor::Options{GetHostByNameActor::Options::Google, false, -1}); + run(GetHostByNameActor::Options{GetHostByNameActor::Options::All, false, -1}); } cnt--; sched.start();