Move testProxy implementation to ConnectionCreator.

This commit is contained in:
levlam 2024-07-31 16:56:49 +03:00
parent c8b5ecce6c
commit a0487d4acd
3 changed files with 183 additions and 139 deletions

View File

@ -110,7 +110,6 @@
#include "td/telegram/net/NetStatsManager.h"
#include "td/telegram/net/NetType.h"
#include "td/telegram/net/Proxy.h"
#include "td/telegram/net/PublicRsaKeySharedMain.h"
#include "td/telegram/net/TempAuthKeyWatchdog.h"
#include "td/telegram/NotificationGroupId.h"
#include "td/telegram/NotificationId.h"
@ -178,13 +177,6 @@
#include "td/db/binlog/BinlogEvent.h"
#include "td/mtproto/DhCallback.h"
#include "td/mtproto/Handshake.h"
#include "td/mtproto/HandshakeActor.h"
#include "td/mtproto/RawConnection.h"
#include "td/mtproto/RSA.h"
#include "td/mtproto/TransportType.h"
#include "td/actor/actor.h"
#include "td/utils/algorithm.h"
@ -194,8 +186,6 @@
#include "td/utils/MimeType.h"
#include "td/utils/misc.h"
#include "td/utils/PathView.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/port/uname.h"
#include "td/utils/Random.h"
#include "td/utils/Slice.h"
@ -493,114 +483,6 @@ class TestNetworkQuery final : public Td::ResultHandler {
}
};
class TestProxyRequest final : public RequestOnceActor {
Proxy proxy_;
int16 dc_id_;
double timeout_;
ActorOwn<> child_;
Promise<> promise_;
auto get_transport() {
return mtproto::TransportType{mtproto::TransportType::ObfuscatedTcp, dc_id_, proxy_.secret()};
}
void do_run(Promise<Unit> &&promise) final {
set_timeout_in(timeout_);
promise_ = std::move(promise);
IPAddress ip_address;
auto status = ip_address.init_host_port(proxy_.server(), proxy_.port());
if (status.is_error()) {
return promise_.set_error(Status::Error(400, status.public_message()));
}
auto r_socket_fd = SocketFd::open(ip_address);
if (r_socket_fd.is_error()) {
return promise_.set_error(Status::Error(400, r_socket_fd.error().public_message()));
}
auto dc_options = ConnectionCreator::get_default_dc_options(false);
IPAddress mtproto_ip_address;
for (auto &dc_option : dc_options.dc_options) {
if (dc_option.get_dc_id().get_raw_id() == dc_id_) {
mtproto_ip_address = dc_option.get_ip_address();
break;
}
}
auto connection_promise =
PromiseCreator::lambda([actor_id = actor_id(this)](Result<ConnectionCreator::ConnectionData> r_data) mutable {
send_closure(actor_id, &TestProxyRequest::on_connection_data, std::move(r_data));
});
child_ = ConnectionCreator::prepare_connection(ip_address, r_socket_fd.move_as_ok(), proxy_, mtproto_ip_address,
get_transport(), "Test", "TestPingDC2", nullptr, {}, false,
std::move(connection_promise));
}
void on_connection_data(Result<ConnectionCreator::ConnectionData> r_data) {
if (r_data.is_error()) {
return promise_.set_error(r_data.move_as_error());
}
class HandshakeContext final : public mtproto::AuthKeyHandshakeContext {
public:
mtproto::DhCallback *get_dh_callback() final {
return nullptr;
}
mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final {
return public_rsa_key_.get();
}
private:
std::shared_ptr<mtproto::PublicRsaKeyInterface> public_rsa_key_ = PublicRsaKeySharedMain::create(false);
};
auto handshake = make_unique<mtproto::AuthKeyHandshake>(dc_id_, 3600);
auto data = r_data.move_as_ok();
auto raw_connection =
mtproto::RawConnection::create(data.ip_address, std::move(data.buffered_socket_fd), get_transport(), nullptr);
child_ = create_actor<mtproto::HandshakeActor>(
"HandshakeActor", std::move(handshake), std::move(raw_connection), make_unique<HandshakeContext>(), 10.0,
PromiseCreator::lambda([actor_id = actor_id(this)](Result<unique_ptr<mtproto::RawConnection>> raw_connection) {
send_closure(actor_id, &TestProxyRequest::on_handshake_connection, std::move(raw_connection));
}),
PromiseCreator::lambda(
[actor_id = actor_id(this)](Result<unique_ptr<mtproto::AuthKeyHandshake>> handshake) mutable {
send_closure(actor_id, &TestProxyRequest::on_handshake, std::move(handshake));
}));
}
void on_handshake_connection(Result<unique_ptr<mtproto::RawConnection>> r_raw_connection) {
if (r_raw_connection.is_error()) {
return promise_.set_error(Status::Error(400, r_raw_connection.move_as_error().public_message()));
}
}
void on_handshake(Result<unique_ptr<mtproto::AuthKeyHandshake>> r_handshake) {
if (!promise_) {
return;
}
if (r_handshake.is_error()) {
return promise_.set_error(Status::Error(400, r_handshake.move_as_error().public_message()));
}
auto handshake = r_handshake.move_as_ok();
if (!handshake->is_ready_for_finish()) {
promise_.set_error(Status::Error(400, "Handshake is not ready"));
}
promise_.set_value(Unit());
}
void timeout_expired() final {
send_error(Status::Error(400, "Timeout expired"));
stop();
}
public:
TestProxyRequest(ActorShared<Td> td, uint64 request_id, Proxy proxy, int32 dc_id, double timeout)
: RequestOnceActor(std::move(td), request_id)
, proxy_(std::move(proxy))
, dc_id_(static_cast<int16>(dc_id))
, timeout_(timeout) {
}
};
class GetMeRequest final : public RequestActor<> {
UserId user_id_;
@ -10174,7 +10056,9 @@ void Td::on_request(uint64 id, td_api::testProxy &request) {
if (r_proxy.is_error()) {
return send_closure(actor_id(this), &Td::send_error, id, r_proxy.move_as_error());
}
CREATE_REQUEST(TestProxyRequest, r_proxy.move_as_ok(), request.dc_id_, request.timeout_);
CREATE_OK_REQUEST_PROMISE();
send_closure(G()->connection_creator(), &ConnectionCreator::test_proxy, r_proxy.move_as_ok(), request.dc_id_,
request.timeout_, std::move(promise));
}
void Td::on_request(uint64 id, const td_api::testGetDifference &request) {

View File

@ -14,13 +14,16 @@
#include "td/telegram/net/MtprotoHeader.h"
#include "td/telegram/net/NetQueryDispatcher.h"
#include "td/telegram/net/NetType.h"
#include "td/telegram/net/PublicRsaKeySharedMain.h"
#include "td/telegram/StateManager.h"
#include "td/telegram/Td.h"
#include "td/telegram/TdDb.h"
#include "td/mtproto/DhCallback.h"
#include "td/mtproto/HandshakeActor.h"
#include "td/mtproto/Ping.h"
#include "td/mtproto/ProxySecret.h"
#include "td/mtproto/RawConnection.h"
#include "td/mtproto/RSA.h"
#include "td/mtproto/TlsInit.h"
#include "td/net/GetHostByNameActor.h"
@ -28,12 +31,13 @@
#include "td/net/Socks5.h"
#include "td/net/TransparentProxy.h"
#include "td/actor/SleepActor.h"
#include "td/utils/algorithm.h"
#include "td/utils/base64.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/Random.h"
#include "td/utils/ScopeGuard.h"
#include "td/utils/SliceBuilder.h"
@ -386,6 +390,138 @@ void ConnectionCreator::ping_proxy_buffered_socket_fd(IPAddress ip_address, Buff
create_reference(token))};
}
void ConnectionCreator::test_proxy(Proxy &&proxy, int32 dc_id, double timeout, Promise<Unit> &&promise) {
auto start_time = Time::now();
IPAddress ip_address;
auto status = ip_address.init_host_port(proxy.server(), proxy.port());
if (status.is_error()) {
return promise.set_error(Status::Error(400, status.public_message()));
}
auto r_socket_fd = SocketFd::open(ip_address);
if (r_socket_fd.is_error()) {
return promise.set_error(Status::Error(400, r_socket_fd.error().public_message()));
}
auto dc_options = get_default_dc_options(false);
IPAddress mtproto_ip_address;
for (auto &dc_option : dc_options.dc_options) {
if (dc_option.get_dc_id().get_raw_id() == dc_id) {
mtproto_ip_address = dc_option.get_ip_address();
break;
}
}
if (!mtproto_ip_address.is_valid()) {
return promise.set_error(Status::Error(400, "Invalid datacenter identifier specified"));
}
auto request_id = ++test_proxy_request_id_;
auto request = make_unique<TestProxyRequest>();
request->proxy_ = std::move(proxy);
request->dc_id_ = static_cast<int16>(dc_id);
request->promise_ = std::move(promise);
auto connection_promise =
PromiseCreator::lambda([actor_id = actor_id(this), request_id](Result<ConnectionData> r_data) {
send_closure(actor_id, &ConnectionCreator::on_test_proxy_connection_data, request_id, std::move(r_data));
});
request->child_ = prepare_connection(ip_address, r_socket_fd.move_as_ok(), request->proxy_, mtproto_ip_address,
request->get_transport(), "Test", "TestPingDC2", nullptr, {}, false,
std::move(connection_promise));
test_proxy_requests_.emplace(request_id, std::move(request));
create_actor<SleepActor>("TestProxyTimeoutActor", timeout + start_time - Time::now(),
PromiseCreator::lambda([actor_id = actor_id(this), request_id](Result<Unit> result) {
send_closure(actor_id, &ConnectionCreator::on_test_proxy_timeout, request_id);
}))
.release();
}
void ConnectionCreator::on_test_proxy_connection_data(uint64 request_id, Result<ConnectionData> r_data) {
auto it = test_proxy_requests_.find(request_id);
if (it == test_proxy_requests_.end()) {
return;
}
auto *request = it->second.get();
if (r_data.is_error()) {
auto promise = std::move(request->promise_);
test_proxy_requests_.erase(it);
return promise.set_error(r_data.move_as_error());
}
class HandshakeContext final : public mtproto::AuthKeyHandshakeContext {
public:
mtproto::DhCallback *get_dh_callback() final {
return nullptr;
}
mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final {
return public_rsa_key_.get();
}
private:
std::shared_ptr<mtproto::PublicRsaKeyInterface> public_rsa_key_ = PublicRsaKeySharedMain::create(false);
};
auto handshake = make_unique<mtproto::AuthKeyHandshake>(request->dc_id_, 3600);
auto data = r_data.move_as_ok();
auto raw_connection = mtproto::RawConnection::create(data.ip_address, std::move(data.buffered_socket_fd),
request->get_transport(), nullptr);
request->child_ = create_actor<mtproto::HandshakeActor>(
"HandshakeActor", std::move(handshake), std::move(raw_connection), make_unique<HandshakeContext>(), 10.0,
PromiseCreator::lambda(
[actor_id = actor_id(this), request_id](Result<unique_ptr<mtproto::RawConnection>> raw_connection) {
send_closure(actor_id, &ConnectionCreator::on_test_proxy_handshake_connection, request_id,
std::move(raw_connection));
}),
PromiseCreator::lambda(
[actor_id = actor_id(this), request_id](Result<unique_ptr<mtproto::AuthKeyHandshake>> handshake) {
send_closure(actor_id, &ConnectionCreator::on_test_proxy_handshake, request_id, std::move(handshake));
}));
}
void ConnectionCreator::on_test_proxy_handshake_connection(
uint64 request_id, Result<unique_ptr<mtproto::RawConnection>> r_raw_connection) {
if (r_raw_connection.is_error()) {
auto it = test_proxy_requests_.find(request_id);
if (it == test_proxy_requests_.end()) {
return;
}
auto promise = std::move(it->second->promise_);
test_proxy_requests_.erase(it);
return promise.set_error(Status::Error(400, r_raw_connection.move_as_error().public_message()));
}
}
void ConnectionCreator::on_test_proxy_handshake(uint64 request_id,
Result<unique_ptr<mtproto::AuthKeyHandshake>> r_handshake) {
auto it = test_proxy_requests_.find(request_id);
if (it == test_proxy_requests_.end()) {
return;
}
auto promise = std::move(it->second->promise_);
test_proxy_requests_.erase(it);
if (r_handshake.is_error()) {
return promise.set_error(Status::Error(400, r_handshake.move_as_error().public_message()));
}
auto handshake = r_handshake.move_as_ok();
if (!handshake->is_ready_for_finish()) {
return promise.set_error(Status::Error(400, "Handshake is not ready"));
}
promise.set_value(Unit());
}
void ConnectionCreator::on_test_proxy_timeout(uint64 request_id) {
auto it = test_proxy_requests_.find(request_id);
if (it == test_proxy_requests_.end()) {
return;
}
auto promise = std::move(it->second->promise_);
test_proxy_requests_.erase(it);
promise.set_error(Status::Error(400, "Timeout expired"));
}
void ConnectionCreator::set_active_proxy_id(int32 proxy_id, bool from_binlog) {
active_proxy_id_ = proxy_id;
if (proxy_id == 0) {

View File

@ -15,6 +15,7 @@
#include "td/mtproto/AuthData.h"
#include "td/mtproto/ConnectionManager.h"
#include "td/mtproto/Handshake.h"
#include "td/mtproto/RawConnection.h"
#include "td/mtproto/TransportType.h"
@ -79,22 +80,7 @@ class ConnectionCreator final : public NetQueryCallback {
void get_proxy_link(int32 proxy_id, Promise<string> promise);
void ping_proxy(int32 proxy_id, Promise<double> promise);
struct ConnectionData {
IPAddress ip_address;
BufferedFd<SocketFd> buffered_socket_fd;
mtproto::ConnectionManager::ConnectionToken connection_token;
unique_ptr<mtproto::RawConnection::StatsCallback> stats_callback;
};
static DcOptions get_default_dc_options(bool is_test);
static ActorOwn<> prepare_connection(IPAddress ip_address, SocketFd socket_fd, const Proxy &proxy,
const IPAddress &mtproto_ip_address,
const mtproto::TransportType &transport_type, Slice actor_name_prefix,
Slice debug_str,
unique_ptr<mtproto::RawConnection::StatsCallback> stats_callback,
ActorShared<> parent, bool use_connection_token,
Promise<ConnectionData> promise);
void test_proxy(Proxy &&proxy, int32 dc_id, double timeout, Promise<Unit> &&promise);
private:
ActorShared<> parent_;
@ -185,6 +171,26 @@ class ConnectionCreator final : public NetQueryCallback {
};
std::map<uint64, PingMainDcRequest> ping_main_dc_requests_;
struct ConnectionData {
IPAddress ip_address;
BufferedFd<SocketFd> buffered_socket_fd;
mtproto::ConnectionManager::ConnectionToken connection_token;
unique_ptr<mtproto::RawConnection::StatsCallback> stats_callback;
};
struct TestProxyRequest {
Proxy proxy_;
int16 dc_id_;
ActorOwn<> child_;
Promise<Unit> promise_;
mtproto::TransportType get_transport() const {
return mtproto::TransportType{mtproto::TransportType::ObfuscatedTcp, dc_id_, proxy_.secret()};
}
};
uint64 test_proxy_request_id_ = 0;
FlatHashMap<uint64, unique_ptr<TestProxyRequest>> test_proxy_requests_;
uint64 next_token() {
return ++current_token_;
}
@ -237,12 +243,21 @@ class ConnectionCreator final : public NetQueryCallback {
IPAddress mtproto_ip_address;
bool check_mode{false};
};
Result<SocketFd> find_connection(const Proxy &proxy, const IPAddress &proxy_ip_address, DcId dc_id,
bool allow_media_only, FindConnectionExtra &extra);
static DcOptions get_default_dc_options(bool is_test);
static Result<mtproto::TransportType> get_transport_type(const Proxy &proxy,
const DcOptionsSet::ConnectionInfo &info);
Result<SocketFd> find_connection(const Proxy &proxy, const IPAddress &proxy_ip_address, DcId dc_id,
bool allow_media_only, FindConnectionExtra &extra);
static ActorOwn<> prepare_connection(IPAddress ip_address, SocketFd socket_fd, const Proxy &proxy,
const IPAddress &mtproto_ip_address,
const mtproto::TransportType &transport_type, Slice actor_name_prefix,
Slice debug_str,
unique_ptr<mtproto::RawConnection::StatsCallback> stats_callback,
ActorShared<> parent, bool use_connection_token,
Promise<ConnectionData> promise);
ActorId<GetHostByNameActor> get_dns_resolver();
@ -252,6 +267,15 @@ class ConnectionCreator final : public NetQueryCallback {
mtproto::TransportType transport_type, string debug_str, Promise<double> promise);
void on_ping_main_dc_result(uint64 token, Result<double> result);
void on_test_proxy_connection_data(uint64 request_id, Result<ConnectionData> r_data);
void on_test_proxy_handshake_connection(uint64 request_id,
Result<unique_ptr<mtproto::RawConnection>> r_raw_connection);
void on_test_proxy_handshake(uint64 request_id, Result<unique_ptr<mtproto::AuthKeyHandshake>> r_handshake);
void on_test_proxy_timeout(uint64 request_id);
};
} // namespace td