diff --git a/td/telegram/Td.cpp b/td/telegram/Td.cpp index 7d20c8d2f..7ba2337b9 100644 --- a/td/telegram/Td.cpp +++ b/td/telegram/Td.cpp @@ -110,7 +110,6 @@ #include "td/telegram/net/NetStatsManager.h" #include "td/telegram/net/NetType.h" #include "td/telegram/net/Proxy.h" -#include "td/telegram/net/PublicRsaKeySharedMain.h" #include "td/telegram/net/TempAuthKeyWatchdog.h" #include "td/telegram/NotificationGroupId.h" #include "td/telegram/NotificationId.h" @@ -178,13 +177,6 @@ #include "td/db/binlog/BinlogEvent.h" -#include "td/mtproto/DhCallback.h" -#include "td/mtproto/Handshake.h" -#include "td/mtproto/HandshakeActor.h" -#include "td/mtproto/RawConnection.h" -#include "td/mtproto/RSA.h" -#include "td/mtproto/TransportType.h" - #include "td/actor/actor.h" #include "td/utils/algorithm.h" @@ -194,8 +186,6 @@ #include "td/utils/MimeType.h" #include "td/utils/misc.h" #include "td/utils/PathView.h" -#include "td/utils/port/IPAddress.h" -#include "td/utils/port/SocketFd.h" #include "td/utils/port/uname.h" #include "td/utils/Random.h" #include "td/utils/Slice.h" @@ -493,114 +483,6 @@ class TestNetworkQuery final : public Td::ResultHandler { } }; -class TestProxyRequest final : public RequestOnceActor { - Proxy proxy_; - int16 dc_id_; - double timeout_; - ActorOwn<> child_; - Promise<> promise_; - - auto get_transport() { - return mtproto::TransportType{mtproto::TransportType::ObfuscatedTcp, dc_id_, proxy_.secret()}; - } - - void do_run(Promise &&promise) final { - set_timeout_in(timeout_); - - promise_ = std::move(promise); - IPAddress ip_address; - auto status = ip_address.init_host_port(proxy_.server(), proxy_.port()); - if (status.is_error()) { - return promise_.set_error(Status::Error(400, status.public_message())); - } - auto r_socket_fd = SocketFd::open(ip_address); - if (r_socket_fd.is_error()) { - return promise_.set_error(Status::Error(400, r_socket_fd.error().public_message())); - } - - auto dc_options = ConnectionCreator::get_default_dc_options(false); - IPAddress mtproto_ip_address; - for (auto &dc_option : dc_options.dc_options) { - if (dc_option.get_dc_id().get_raw_id() == dc_id_) { - mtproto_ip_address = dc_option.get_ip_address(); - break; - } - } - - auto connection_promise = - PromiseCreator::lambda([actor_id = actor_id(this)](Result r_data) mutable { - send_closure(actor_id, &TestProxyRequest::on_connection_data, std::move(r_data)); - }); - - child_ = ConnectionCreator::prepare_connection(ip_address, r_socket_fd.move_as_ok(), proxy_, mtproto_ip_address, - get_transport(), "Test", "TestPingDC2", nullptr, {}, false, - std::move(connection_promise)); - } - - void on_connection_data(Result r_data) { - if (r_data.is_error()) { - return promise_.set_error(r_data.move_as_error()); - } - class HandshakeContext final : public mtproto::AuthKeyHandshakeContext { - public: - mtproto::DhCallback *get_dh_callback() final { - return nullptr; - } - mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final { - return public_rsa_key_.get(); - } - - private: - std::shared_ptr public_rsa_key_ = PublicRsaKeySharedMain::create(false); - }; - auto handshake = make_unique(dc_id_, 3600); - auto data = r_data.move_as_ok(); - auto raw_connection = - mtproto::RawConnection::create(data.ip_address, std::move(data.buffered_socket_fd), get_transport(), nullptr); - child_ = create_actor( - "HandshakeActor", std::move(handshake), std::move(raw_connection), make_unique(), 10.0, - PromiseCreator::lambda([actor_id = actor_id(this)](Result> raw_connection) { - send_closure(actor_id, &TestProxyRequest::on_handshake_connection, std::move(raw_connection)); - }), - PromiseCreator::lambda( - [actor_id = actor_id(this)](Result> handshake) mutable { - send_closure(actor_id, &TestProxyRequest::on_handshake, std::move(handshake)); - })); - } - void on_handshake_connection(Result> r_raw_connection) { - if (r_raw_connection.is_error()) { - return promise_.set_error(Status::Error(400, r_raw_connection.move_as_error().public_message())); - } - } - void on_handshake(Result> r_handshake) { - if (!promise_) { - return; - } - if (r_handshake.is_error()) { - return promise_.set_error(Status::Error(400, r_handshake.move_as_error().public_message())); - } - - auto handshake = r_handshake.move_as_ok(); - if (!handshake->is_ready_for_finish()) { - promise_.set_error(Status::Error(400, "Handshake is not ready")); - } - promise_.set_value(Unit()); - } - - void timeout_expired() final { - send_error(Status::Error(400, "Timeout expired")); - stop(); - } - - public: - TestProxyRequest(ActorShared td, uint64 request_id, Proxy proxy, int32 dc_id, double timeout) - : RequestOnceActor(std::move(td), request_id) - , proxy_(std::move(proxy)) - , dc_id_(static_cast(dc_id)) - , timeout_(timeout) { - } -}; - class GetMeRequest final : public RequestActor<> { UserId user_id_; @@ -10174,7 +10056,9 @@ void Td::on_request(uint64 id, td_api::testProxy &request) { if (r_proxy.is_error()) { return send_closure(actor_id(this), &Td::send_error, id, r_proxy.move_as_error()); } - CREATE_REQUEST(TestProxyRequest, r_proxy.move_as_ok(), request.dc_id_, request.timeout_); + CREATE_OK_REQUEST_PROMISE(); + send_closure(G()->connection_creator(), &ConnectionCreator::test_proxy, r_proxy.move_as_ok(), request.dc_id_, + request.timeout_, std::move(promise)); } void Td::on_request(uint64 id, const td_api::testGetDifference &request) { diff --git a/td/telegram/net/ConnectionCreator.cpp b/td/telegram/net/ConnectionCreator.cpp index 5d0473cba..68018bc66 100644 --- a/td/telegram/net/ConnectionCreator.cpp +++ b/td/telegram/net/ConnectionCreator.cpp @@ -14,13 +14,16 @@ #include "td/telegram/net/MtprotoHeader.h" #include "td/telegram/net/NetQueryDispatcher.h" #include "td/telegram/net/NetType.h" +#include "td/telegram/net/PublicRsaKeySharedMain.h" #include "td/telegram/StateManager.h" #include "td/telegram/Td.h" #include "td/telegram/TdDb.h" +#include "td/mtproto/DhCallback.h" +#include "td/mtproto/HandshakeActor.h" #include "td/mtproto/Ping.h" #include "td/mtproto/ProxySecret.h" -#include "td/mtproto/RawConnection.h" +#include "td/mtproto/RSA.h" #include "td/mtproto/TlsInit.h" #include "td/net/GetHostByNameActor.h" @@ -28,12 +31,13 @@ #include "td/net/Socks5.h" #include "td/net/TransparentProxy.h" +#include "td/actor/SleepActor.h" + #include "td/utils/algorithm.h" #include "td/utils/base64.h" #include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/misc.h" -#include "td/utils/port/IPAddress.h" #include "td/utils/Random.h" #include "td/utils/ScopeGuard.h" #include "td/utils/SliceBuilder.h" @@ -386,6 +390,138 @@ void ConnectionCreator::ping_proxy_buffered_socket_fd(IPAddress ip_address, Buff create_reference(token))}; } +void ConnectionCreator::test_proxy(Proxy &&proxy, int32 dc_id, double timeout, Promise &&promise) { + auto start_time = Time::now(); + + IPAddress ip_address; + auto status = ip_address.init_host_port(proxy.server(), proxy.port()); + if (status.is_error()) { + return promise.set_error(Status::Error(400, status.public_message())); + } + auto r_socket_fd = SocketFd::open(ip_address); + if (r_socket_fd.is_error()) { + return promise.set_error(Status::Error(400, r_socket_fd.error().public_message())); + } + + auto dc_options = get_default_dc_options(false); + IPAddress mtproto_ip_address; + for (auto &dc_option : dc_options.dc_options) { + if (dc_option.get_dc_id().get_raw_id() == dc_id) { + mtproto_ip_address = dc_option.get_ip_address(); + break; + } + } + if (!mtproto_ip_address.is_valid()) { + return promise.set_error(Status::Error(400, "Invalid datacenter identifier specified")); + } + + auto request_id = ++test_proxy_request_id_; + auto request = make_unique(); + request->proxy_ = std::move(proxy); + request->dc_id_ = static_cast(dc_id); + request->promise_ = std::move(promise); + + auto connection_promise = + PromiseCreator::lambda([actor_id = actor_id(this), request_id](Result r_data) { + send_closure(actor_id, &ConnectionCreator::on_test_proxy_connection_data, request_id, std::move(r_data)); + }); + request->child_ = prepare_connection(ip_address, r_socket_fd.move_as_ok(), request->proxy_, mtproto_ip_address, + request->get_transport(), "Test", "TestPingDC2", nullptr, {}, false, + std::move(connection_promise)); + + test_proxy_requests_.emplace(request_id, std::move(request)); + + create_actor("TestProxyTimeoutActor", timeout + start_time - Time::now(), + PromiseCreator::lambda([actor_id = actor_id(this), request_id](Result result) { + send_closure(actor_id, &ConnectionCreator::on_test_proxy_timeout, request_id); + })) + .release(); +} + +void ConnectionCreator::on_test_proxy_connection_data(uint64 request_id, Result r_data) { + auto it = test_proxy_requests_.find(request_id); + if (it == test_proxy_requests_.end()) { + return; + } + auto *request = it->second.get(); + if (r_data.is_error()) { + auto promise = std::move(request->promise_); + test_proxy_requests_.erase(it); + return promise.set_error(r_data.move_as_error()); + } + + class HandshakeContext final : public mtproto::AuthKeyHandshakeContext { + public: + mtproto::DhCallback *get_dh_callback() final { + return nullptr; + } + mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final { + return public_rsa_key_.get(); + } + + private: + std::shared_ptr public_rsa_key_ = PublicRsaKeySharedMain::create(false); + }; + auto handshake = make_unique(request->dc_id_, 3600); + auto data = r_data.move_as_ok(); + auto raw_connection = mtproto::RawConnection::create(data.ip_address, std::move(data.buffered_socket_fd), + request->get_transport(), nullptr); + request->child_ = create_actor( + "HandshakeActor", std::move(handshake), std::move(raw_connection), make_unique(), 10.0, + PromiseCreator::lambda( + [actor_id = actor_id(this), request_id](Result> raw_connection) { + send_closure(actor_id, &ConnectionCreator::on_test_proxy_handshake_connection, request_id, + std::move(raw_connection)); + }), + PromiseCreator::lambda( + [actor_id = actor_id(this), request_id](Result> handshake) { + send_closure(actor_id, &ConnectionCreator::on_test_proxy_handshake, request_id, std::move(handshake)); + })); +} + +void ConnectionCreator::on_test_proxy_handshake_connection( + uint64 request_id, Result> r_raw_connection) { + if (r_raw_connection.is_error()) { + auto it = test_proxy_requests_.find(request_id); + if (it == test_proxy_requests_.end()) { + return; + } + auto promise = std::move(it->second->promise_); + test_proxy_requests_.erase(it); + return promise.set_error(Status::Error(400, r_raw_connection.move_as_error().public_message())); + } +} + +void ConnectionCreator::on_test_proxy_handshake(uint64 request_id, + Result> r_handshake) { + auto it = test_proxy_requests_.find(request_id); + if (it == test_proxy_requests_.end()) { + return; + } + auto promise = std::move(it->second->promise_); + test_proxy_requests_.erase(it); + + if (r_handshake.is_error()) { + return promise.set_error(Status::Error(400, r_handshake.move_as_error().public_message())); + } + auto handshake = r_handshake.move_as_ok(); + if (!handshake->is_ready_for_finish()) { + return promise.set_error(Status::Error(400, "Handshake is not ready")); + } + promise.set_value(Unit()); +} + +void ConnectionCreator::on_test_proxy_timeout(uint64 request_id) { + auto it = test_proxy_requests_.find(request_id); + if (it == test_proxy_requests_.end()) { + return; + } + auto promise = std::move(it->second->promise_); + test_proxy_requests_.erase(it); + + promise.set_error(Status::Error(400, "Timeout expired")); +} + void ConnectionCreator::set_active_proxy_id(int32 proxy_id, bool from_binlog) { active_proxy_id_ = proxy_id; if (proxy_id == 0) { diff --git a/td/telegram/net/ConnectionCreator.h b/td/telegram/net/ConnectionCreator.h index d43bd92fd..37b7447fa 100644 --- a/td/telegram/net/ConnectionCreator.h +++ b/td/telegram/net/ConnectionCreator.h @@ -15,6 +15,7 @@ #include "td/mtproto/AuthData.h" #include "td/mtproto/ConnectionManager.h" +#include "td/mtproto/Handshake.h" #include "td/mtproto/RawConnection.h" #include "td/mtproto/TransportType.h" @@ -79,22 +80,7 @@ class ConnectionCreator final : public NetQueryCallback { void get_proxy_link(int32 proxy_id, Promise promise); void ping_proxy(int32 proxy_id, Promise promise); - struct ConnectionData { - IPAddress ip_address; - BufferedFd buffered_socket_fd; - mtproto::ConnectionManager::ConnectionToken connection_token; - unique_ptr stats_callback; - }; - - static DcOptions get_default_dc_options(bool is_test); - - static ActorOwn<> prepare_connection(IPAddress ip_address, SocketFd socket_fd, const Proxy &proxy, - const IPAddress &mtproto_ip_address, - const mtproto::TransportType &transport_type, Slice actor_name_prefix, - Slice debug_str, - unique_ptr stats_callback, - ActorShared<> parent, bool use_connection_token, - Promise promise); + void test_proxy(Proxy &&proxy, int32 dc_id, double timeout, Promise &&promise); private: ActorShared<> parent_; @@ -185,6 +171,26 @@ class ConnectionCreator final : public NetQueryCallback { }; std::map ping_main_dc_requests_; + struct ConnectionData { + IPAddress ip_address; + BufferedFd buffered_socket_fd; + mtproto::ConnectionManager::ConnectionToken connection_token; + unique_ptr stats_callback; + }; + + struct TestProxyRequest { + Proxy proxy_; + int16 dc_id_; + ActorOwn<> child_; + Promise promise_; + + mtproto::TransportType get_transport() const { + return mtproto::TransportType{mtproto::TransportType::ObfuscatedTcp, dc_id_, proxy_.secret()}; + } + }; + uint64 test_proxy_request_id_ = 0; + FlatHashMap> test_proxy_requests_; + uint64 next_token() { return ++current_token_; } @@ -237,12 +243,21 @@ class ConnectionCreator final : public NetQueryCallback { IPAddress mtproto_ip_address; bool check_mode{false}; }; + Result find_connection(const Proxy &proxy, const IPAddress &proxy_ip_address, DcId dc_id, + bool allow_media_only, FindConnectionExtra &extra); + + static DcOptions get_default_dc_options(bool is_test); static Result get_transport_type(const Proxy &proxy, const DcOptionsSet::ConnectionInfo &info); - Result find_connection(const Proxy &proxy, const IPAddress &proxy_ip_address, DcId dc_id, - bool allow_media_only, FindConnectionExtra &extra); + static ActorOwn<> prepare_connection(IPAddress ip_address, SocketFd socket_fd, const Proxy &proxy, + const IPAddress &mtproto_ip_address, + const mtproto::TransportType &transport_type, Slice actor_name_prefix, + Slice debug_str, + unique_ptr stats_callback, + ActorShared<> parent, bool use_connection_token, + Promise promise); ActorId get_dns_resolver(); @@ -252,6 +267,15 @@ class ConnectionCreator final : public NetQueryCallback { mtproto::TransportType transport_type, string debug_str, Promise promise); void on_ping_main_dc_result(uint64 token, Result result); + + void on_test_proxy_connection_data(uint64 request_id, Result r_data); + + void on_test_proxy_handshake_connection(uint64 request_id, + Result> r_raw_connection); + + void on_test_proxy_handshake(uint64 request_id, Result> r_handshake); + + void on_test_proxy_timeout(uint64 request_id); }; } // namespace td