diff --git a/td/mtproto/HttpTransport.h b/td/mtproto/HttpTransport.h index 8e3ed41a7..c1874a051 100644 --- a/td/mtproto/HttpTransport.h +++ b/td/mtproto/HttpTransport.h @@ -34,7 +34,7 @@ class Transport : public IStreamTransport { size_t max_prepend_size() const override; TransportType get_type() const override { - return TransportType::Http; + return {TransportType::Http, 0, ""}; } private: diff --git a/td/mtproto/IStreamTransport.cpp b/td/mtproto/IStreamTransport.cpp index 8440108c7..1184d1668 100644 --- a/td/mtproto/IStreamTransport.cpp +++ b/td/mtproto/IStreamTransport.cpp @@ -14,9 +14,9 @@ namespace td { namespace mtproto { std::unique_ptr create_transport(TransportType type) { - switch (type) { + switch (type.type) { case TransportType::ObfuscatedTcp: - return std::make_unique(); + return std::make_unique(type.dc_id, std::move(type.secret)); case TransportType::Tcp: return std::make_unique(); case TransportType::Http: diff --git a/td/mtproto/IStreamTransport.h b/td/mtproto/IStreamTransport.h index 6796d457f..615019553 100644 --- a/td/mtproto/IStreamTransport.h +++ b/td/mtproto/IStreamTransport.h @@ -12,7 +12,11 @@ namespace td { namespace mtproto { -enum class TransportType { Tcp, ObfuscatedTcp, Http }; +struct TransportType { + enum { Tcp, ObfuscatedTcp, Http } type; + int16 dc_id; + std::string secret; +}; class IStreamTransport { public: IStreamTransport() = default; diff --git a/td/mtproto/TcpTransport.cpp b/td/mtproto/TcpTransport.cpp index e7613acab..d50b21668 100644 --- a/td/mtproto/TcpTransport.cpp +++ b/td/mtproto/TcpTransport.cpp @@ -154,14 +154,29 @@ void ObfuscatedTransport::init(ChainBufferReader *input, ChainBufferWriter *outp // TODO: It is actually IntermediateTransport::init_output_stream, so it will work only with // TransportImpl==IntermediateTransport as(header_slice.begin() + 56) = 0xeeeeeeee; + if (dc_id_ != 0) { + as(header_slice.begin() + 60) = dc_id_; + } string rheader = header; std::reverse(rheader.begin(), rheader.end()); - aes_ctr_byte_flow_.init(as(rheader.data() + 8), as(rheader.data() + 8 + 32)); + auto key = as(rheader.data() + 8); + auto fix_key = [&](UInt256 &key) { + if (secret_.size() == 16) { + Sha256State state; + sha256_init(&state); + sha256_update(as_slice(key), &state); + sha256_update(secret_, &state); + sha256_final(&state, as_slice(key)); + } + }; + fix_key(key); + aes_ctr_byte_flow_.init(key, as(rheader.data() + 8 + 32)); aes_ctr_byte_flow_.set_input(input_); aes_ctr_byte_flow_ >> byte_flow_sink_; output_key_ = as(header.data() + 8); + fix_key(output_key_); output_state_.init(output_key_, as(header.data() + 8 + 32)); output_->append(header_slice.substr(0, 56)); output_state_.encrypt(header_slice, header_slice); diff --git a/td/mtproto/TcpTransport.h b/td/mtproto/TcpTransport.h index d53048478..2641422a9 100644 --- a/td/mtproto/TcpTransport.h +++ b/td/mtproto/TcpTransport.h @@ -95,7 +95,7 @@ class OldTransport : public IStreamTransport { return 4; } TransportType get_type() const override { - return TransportType::Tcp; + return TransportType{TransportType::Tcp, 0, ""}; } private: @@ -106,7 +106,8 @@ class OldTransport : public IStreamTransport { class ObfuscatedTransport : public IStreamTransport { public: - ObfuscatedTransport() = default; + explicit ObfuscatedTransport(int16 dc_id, std::string secret) : dc_id_(dc_id), secret_(std::move(secret)) { + } Result read_next(BufferSlice *message, uint32 *quick_ack) override TD_WARN_UNUSED_RESULT { aes_ctr_byte_flow_.wakeup(); return impl_.read_from_stream(byte_flow_sink_.get_output(), message, quick_ack); @@ -138,11 +139,13 @@ class ObfuscatedTransport : public IStreamTransport { } TransportType get_type() const override { - return TransportType::ObfuscatedTcp; + return TransportType{TransportType::ObfuscatedTcp, dc_id_, secret_}; } private: TransportImpl impl_; + int16 dc_id_; + std::string secret_; AesCtrByteFlow aes_ctr_byte_flow_; ByteFlowSink byte_flow_sink_; ChainBufferReader *input_; diff --git a/td/telegram/net/ConnectionCreator.cpp b/td/telegram/net/ConnectionCreator.cpp index f17272e65..7b420af6c 100644 --- a/td/telegram/net/ConnectionCreator.cpp +++ b/td/telegram/net/ConnectionCreator.cpp @@ -290,8 +290,8 @@ void ConnectionCreator::request_raw_connection_by_ip(IPAddress ip_address, if (r_socket_fd.is_error()) { return promise.set_error(r_socket_fd.move_as_error()); } - auto raw_connection = std::make_unique(r_socket_fd.move_as_ok(), - mtproto::TransportType::ObfuscatedTcp, nullptr); + auto raw_connection = std::make_unique( + r_socket_fd.move_as_ok(), mtproto::TransportType{mtproto::TransportType::ObfuscatedTcp, 0, ""}, nullptr); raw_connection->extra_ = network_generation_; promise.set_value(std::move(raw_connection)); } @@ -370,7 +370,7 @@ void ConnectionCreator::client_loop(ClientInfo &client) { // Create new RawConnection DcOptionsSet::Stat *stat{nullptr}; - bool use_http{false}; + mtproto::TransportType transport_type; string debug_str; IPAddress mtproto_ip; @@ -379,7 +379,15 @@ void ConnectionCreator::client_loop(ClientInfo &client) { auto r_socket_fd = [&, dc_id = client.dc_id, allow_media_only = client.allow_media_only]() -> Result { TRY_RESULT(info, dc_options_set_.find_connection(dc_id, allow_media_only, use_socks5)); stat = info.stat; - use_http = info.use_http; + if (info.use_http) { + transport_type = {mtproto::TransportType::Http, 0, ""}; + } else { + int16 raw_dc_id = narrow_cast(dc_id.get_raw_id()); + if (info.option->is_media_only()) { + raw_dc_id = -raw_dc_id; + } + transport_type = {mtproto::TransportType::ObfuscatedTcp, raw_dc_id, info.option->get_secret().str()}; + } check_mode |= info.should_check; if (use_socks5) { @@ -420,10 +428,10 @@ void ConnectionCreator::client_loop(ClientInfo &client) { } auto promise = PromiseCreator::lambda( - [actor_id = actor_id(this), check_mode, use_http, hash = client.hash, debug_str, + [actor_id = actor_id(this), check_mode, transport_type, hash = client.hash, debug_str, network_generation = network_generation_](Result r_connection_data) mutable { send_closure(std::move(actor_id), &ConnectionCreator::client_create_raw_connection, - std::move(r_connection_data), check_mode, use_http, hash, debug_str, network_generation); + std::move(r_connection_data), check_mode, transport_type, hash, debug_str, network_generation); }); auto stats_callback = std::make_unique( @@ -476,8 +484,8 @@ void ConnectionCreator::client_loop(ClientInfo &client) { } void ConnectionCreator::client_create_raw_connection(Result r_connection_data, bool check_mode, - bool use_http, size_t hash, string debug_str, - uint32 network_generation) { + mtproto::TransportType transport_type, size_t hash, + string debug_str, uint32 network_generation) { auto promise = PromiseCreator::lambda([actor_id = actor_id(this), hash, check_mode, debug_str](Result> result) mutable { VLOG(connections) << "Ready " << debug_str << " " << tag("checked", check_mode) << tag("ok", result.is_ok()); @@ -490,9 +498,7 @@ void ConnectionCreator::client_create_raw_connection(Result r_co auto connection_data = r_connection_data.move_as_ok(); auto raw_connection = std::make_unique( - std::move(connection_data.socket_fd), - use_http ? mtproto::TransportType::Http : mtproto::TransportType::ObfuscatedTcp, - std::move(connection_data.stats_callback)); + std::move(connection_data.socket_fd), std::move(transport_type), std::move(connection_data.stats_callback)); raw_connection->set_connection_token(std::move(connection_data.connection_token)); raw_connection->extra_ = network_generation; diff --git a/td/telegram/net/ConnectionCreator.h b/td/telegram/net/ConnectionCreator.h index 7265cdaba..7d9096ce4 100644 --- a/td/telegram/net/ConnectionCreator.h +++ b/td/telegram/net/ConnectionCreator.h @@ -12,6 +12,8 @@ #include "td/telegram/net/DcOptionsSet.h" #include "td/telegram/StateManager.h" +#include "td/mtproto/IStreamTransport.h" + #include "td/actor/actor.h" #include "td/actor/PromiseFuture.h" #include "td/actor/SignalSlot.h" @@ -241,8 +243,9 @@ class ConnectionCreator : public Actor { StateManager::ConnectionToken connection_token; std::unique_ptr stats_callback; }; - void client_create_raw_connection(Result r_connection_data, bool check_mode, bool use_http, - size_t hash, string debug_str, uint32 network_generation); + void client_create_raw_connection(Result r_connection_data, bool check_mode, + mtproto::TransportType transport_type, size_t hash, string debug_str, + uint32 network_generation); void client_add_connection(size_t hash, Result> r_raw_connection, bool check_flag); void client_set_timeout_at(ClientInfo &client, double wakeup_at); diff --git a/td/telegram/net/DcOptionsSet.cpp b/td/telegram/net/DcOptionsSet.cpp index bdba462d8..4aefd3b7f 100644 --- a/td/telegram/net/DcOptionsSet.cpp +++ b/td/telegram/net/DcOptionsSet.cpp @@ -60,10 +60,6 @@ Result DcOptionsSet::find_connection(DcId dc_id, b LOG(DEBUG) << "Skip media only option"; continue; } - if (!option.get_secret().empty()) { - LOG(DEBUG) << "Skip options with secret"; - continue; // TODO secret support - } ConnectionInfo info; info.option = &option; diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index b48595502..e45bd4a0c 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -894,7 +894,8 @@ void Session::connection_open_finish(ConnectionInfo *info, return; } - Mode expected_mode = raw_connection->get_transport_type() == mtproto::TransportType::Http ? Mode::Http : Mode::Tcp; + Mode expected_mode = + raw_connection->get_transport_type().type == mtproto::TransportType::Http ? Mode::Http : Mode::Tcp; if (mode_ != expected_mode) { LOG(INFO) << "Change mode " << mode_ << "--->" << expected_mode; mode_ = expected_mode; diff --git a/test/mtproto.cpp b/test/mtproto.cpp index de81cd6c4..aaec870f6 100644 --- a/test/mtproto.cpp +++ b/test/mtproto.cpp @@ -107,7 +107,7 @@ class TestPingActor : public Actor { void start_up() override { ping_connection_ = std::make_unique(std::make_unique( - SocketFd::open(ip_address_).move_as_ok(), mtproto::TransportType::Tcp, nullptr)); + SocketFd::open(ip_address_).move_as_ok(), mtproto::TransportType{mtproto::TransportType::Tcp, 0, ""}, nullptr)); ping_connection_->get_pollable().set_observer(this); subscribe(ping_connection_->get_pollable()); @@ -209,8 +209,9 @@ class HandshakeTestActor : public Actor { } void loop() override { if (!wait_for_raw_connection_ && !raw_connection_) { - raw_connection_ = std::make_unique(SocketFd::open(get_default_ip_address()).move_as_ok(), - mtproto::TransportType::Tcp, nullptr); + raw_connection_ = + std::make_unique(SocketFd::open(get_default_ip_address()).move_as_ok(), + mtproto::TransportType{mtproto::TransportType::Tcp, 0, ""}, nullptr); } if (!wait_for_handshake_ && !handshake_) { handshake_ = std::make_unique(0);