Mtproto secret support

GitOrigin-RevId: 3efeb3f309b76074d6581d68e9a9b20df79e82ae
This commit is contained in:
Arseny Smirnov 2018-04-30 20:01:18 +03:00
parent c5fecc1359
commit efc5cbb8ad
10 changed files with 58 additions and 29 deletions

View File

@ -34,7 +34,7 @@ class Transport : public IStreamTransport {
size_t max_prepend_size() const override; size_t max_prepend_size() const override;
TransportType get_type() const override { TransportType get_type() const override {
return TransportType::Http; return {TransportType::Http, 0, ""};
} }
private: private:

View File

@ -14,9 +14,9 @@
namespace td { namespace td {
namespace mtproto { namespace mtproto {
std::unique_ptr<IStreamTransport> create_transport(TransportType type) { std::unique_ptr<IStreamTransport> create_transport(TransportType type) {
switch (type) { switch (type.type) {
case TransportType::ObfuscatedTcp: case TransportType::ObfuscatedTcp:
return std::make_unique<tcp::ObfuscatedTransport>(); return std::make_unique<tcp::ObfuscatedTransport>(type.dc_id, std::move(type.secret));
case TransportType::Tcp: case TransportType::Tcp:
return std::make_unique<tcp::OldTransport>(); return std::make_unique<tcp::OldTransport>();
case TransportType::Http: case TransportType::Http:

View File

@ -12,7 +12,11 @@
namespace td { namespace td {
namespace mtproto { namespace mtproto {
enum class TransportType { Tcp, ObfuscatedTcp, Http }; struct TransportType {
enum { Tcp, ObfuscatedTcp, Http } type;
int16 dc_id;
std::string secret;
};
class IStreamTransport { class IStreamTransport {
public: public:
IStreamTransport() = default; IStreamTransport() = default;

View File

@ -154,14 +154,29 @@ void ObfuscatedTransport::init(ChainBufferReader *input, ChainBufferWriter *outp
// TODO: It is actually IntermediateTransport::init_output_stream, so it will work only with // TODO: It is actually IntermediateTransport::init_output_stream, so it will work only with
// TransportImpl==IntermediateTransport // TransportImpl==IntermediateTransport
as<uint32>(header_slice.begin() + 56) = 0xeeeeeeee; as<uint32>(header_slice.begin() + 56) = 0xeeeeeeee;
if (dc_id_ != 0) {
as<int16>(header_slice.begin() + 60) = dc_id_;
}
string rheader = header; string rheader = header;
std::reverse(rheader.begin(), rheader.end()); std::reverse(rheader.begin(), rheader.end());
aes_ctr_byte_flow_.init(as<UInt256>(rheader.data() + 8), as<UInt128>(rheader.data() + 8 + 32)); auto key = as<UInt256>(rheader.data() + 8);
auto fix_key = [&](UInt256 &key) {
if (secret_.size() == 16) {
Sha256State state;
sha256_init(&state);
sha256_update(as_slice(key), &state);
sha256_update(secret_, &state);
sha256_final(&state, as_slice(key));
}
};
fix_key(key);
aes_ctr_byte_flow_.init(key, as<UInt128>(rheader.data() + 8 + 32));
aes_ctr_byte_flow_.set_input(input_); aes_ctr_byte_flow_.set_input(input_);
aes_ctr_byte_flow_ >> byte_flow_sink_; aes_ctr_byte_flow_ >> byte_flow_sink_;
output_key_ = as<UInt256>(header.data() + 8); output_key_ = as<UInt256>(header.data() + 8);
fix_key(output_key_);
output_state_.init(output_key_, as<UInt128>(header.data() + 8 + 32)); output_state_.init(output_key_, as<UInt128>(header.data() + 8 + 32));
output_->append(header_slice.substr(0, 56)); output_->append(header_slice.substr(0, 56));
output_state_.encrypt(header_slice, header_slice); output_state_.encrypt(header_slice, header_slice);

View File

@ -95,7 +95,7 @@ class OldTransport : public IStreamTransport {
return 4; return 4;
} }
TransportType get_type() const override { TransportType get_type() const override {
return TransportType::Tcp; return TransportType{TransportType::Tcp, 0, ""};
} }
private: private:
@ -106,7 +106,8 @@ class OldTransport : public IStreamTransport {
class ObfuscatedTransport : public IStreamTransport { class ObfuscatedTransport : public IStreamTransport {
public: public:
ObfuscatedTransport() = default; explicit ObfuscatedTransport(int16 dc_id, std::string secret) : dc_id_(dc_id), secret_(std::move(secret)) {
}
Result<size_t> read_next(BufferSlice *message, uint32 *quick_ack) override TD_WARN_UNUSED_RESULT { Result<size_t> read_next(BufferSlice *message, uint32 *quick_ack) override TD_WARN_UNUSED_RESULT {
aes_ctr_byte_flow_.wakeup(); aes_ctr_byte_flow_.wakeup();
return impl_.read_from_stream(byte_flow_sink_.get_output(), message, quick_ack); return impl_.read_from_stream(byte_flow_sink_.get_output(), message, quick_ack);
@ -138,11 +139,13 @@ class ObfuscatedTransport : public IStreamTransport {
} }
TransportType get_type() const override { TransportType get_type() const override {
return TransportType::ObfuscatedTcp; return TransportType{TransportType::ObfuscatedTcp, dc_id_, secret_};
} }
private: private:
TransportImpl impl_; TransportImpl impl_;
int16 dc_id_;
std::string secret_;
AesCtrByteFlow aes_ctr_byte_flow_; AesCtrByteFlow aes_ctr_byte_flow_;
ByteFlowSink byte_flow_sink_; ByteFlowSink byte_flow_sink_;
ChainBufferReader *input_; ChainBufferReader *input_;

View File

@ -290,8 +290,8 @@ void ConnectionCreator::request_raw_connection_by_ip(IPAddress ip_address,
if (r_socket_fd.is_error()) { if (r_socket_fd.is_error()) {
return promise.set_error(r_socket_fd.move_as_error()); return promise.set_error(r_socket_fd.move_as_error());
} }
auto raw_connection = std::make_unique<mtproto::RawConnection>(r_socket_fd.move_as_ok(), auto raw_connection = std::make_unique<mtproto::RawConnection>(
mtproto::TransportType::ObfuscatedTcp, nullptr); r_socket_fd.move_as_ok(), mtproto::TransportType{mtproto::TransportType::ObfuscatedTcp, 0, ""}, nullptr);
raw_connection->extra_ = network_generation_; raw_connection->extra_ = network_generation_;
promise.set_value(std::move(raw_connection)); promise.set_value(std::move(raw_connection));
} }
@ -370,7 +370,7 @@ void ConnectionCreator::client_loop(ClientInfo &client) {
// Create new RawConnection // Create new RawConnection
DcOptionsSet::Stat *stat{nullptr}; DcOptionsSet::Stat *stat{nullptr};
bool use_http{false}; mtproto::TransportType transport_type;
string debug_str; string debug_str;
IPAddress mtproto_ip; IPAddress mtproto_ip;
@ -379,7 +379,15 @@ void ConnectionCreator::client_loop(ClientInfo &client) {
auto r_socket_fd = [&, dc_id = client.dc_id, allow_media_only = client.allow_media_only]() -> Result<SocketFd> { auto r_socket_fd = [&, dc_id = client.dc_id, allow_media_only = client.allow_media_only]() -> Result<SocketFd> {
TRY_RESULT(info, dc_options_set_.find_connection(dc_id, allow_media_only, use_socks5)); TRY_RESULT(info, dc_options_set_.find_connection(dc_id, allow_media_only, use_socks5));
stat = info.stat; stat = info.stat;
use_http = info.use_http; if (info.use_http) {
transport_type = {mtproto::TransportType::Http, 0, ""};
} else {
int16 raw_dc_id = narrow_cast<int16>(dc_id.get_raw_id());
if (info.option->is_media_only()) {
raw_dc_id = -raw_dc_id;
}
transport_type = {mtproto::TransportType::ObfuscatedTcp, raw_dc_id, info.option->get_secret().str()};
}
check_mode |= info.should_check; check_mode |= info.should_check;
if (use_socks5) { if (use_socks5) {
@ -420,10 +428,10 @@ void ConnectionCreator::client_loop(ClientInfo &client) {
} }
auto promise = PromiseCreator::lambda( auto promise = PromiseCreator::lambda(
[actor_id = actor_id(this), check_mode, use_http, hash = client.hash, debug_str, [actor_id = actor_id(this), check_mode, transport_type, hash = client.hash, debug_str,
network_generation = network_generation_](Result<ConnectionData> r_connection_data) mutable { network_generation = network_generation_](Result<ConnectionData> r_connection_data) mutable {
send_closure(std::move(actor_id), &ConnectionCreator::client_create_raw_connection, send_closure(std::move(actor_id), &ConnectionCreator::client_create_raw_connection,
std::move(r_connection_data), check_mode, use_http, hash, debug_str, network_generation); std::move(r_connection_data), check_mode, transport_type, hash, debug_str, network_generation);
}); });
auto stats_callback = std::make_unique<detail::StatsCallback>( auto stats_callback = std::make_unique<detail::StatsCallback>(
@ -476,8 +484,8 @@ void ConnectionCreator::client_loop(ClientInfo &client) {
} }
void ConnectionCreator::client_create_raw_connection(Result<ConnectionData> r_connection_data, bool check_mode, void ConnectionCreator::client_create_raw_connection(Result<ConnectionData> r_connection_data, bool check_mode,
bool use_http, size_t hash, string debug_str, mtproto::TransportType transport_type, size_t hash,
uint32 network_generation) { string debug_str, uint32 network_generation) {
auto promise = PromiseCreator::lambda([actor_id = actor_id(this), hash, check_mode, auto promise = PromiseCreator::lambda([actor_id = actor_id(this), hash, check_mode,
debug_str](Result<std::unique_ptr<mtproto::RawConnection>> result) mutable { debug_str](Result<std::unique_ptr<mtproto::RawConnection>> result) mutable {
VLOG(connections) << "Ready " << debug_str << " " << tag("checked", check_mode) << tag("ok", result.is_ok()); VLOG(connections) << "Ready " << debug_str << " " << tag("checked", check_mode) << tag("ok", result.is_ok());
@ -490,9 +498,7 @@ void ConnectionCreator::client_create_raw_connection(Result<ConnectionData> r_co
auto connection_data = r_connection_data.move_as_ok(); auto connection_data = r_connection_data.move_as_ok();
auto raw_connection = std::make_unique<mtproto::RawConnection>( auto raw_connection = std::make_unique<mtproto::RawConnection>(
std::move(connection_data.socket_fd), std::move(connection_data.socket_fd), std::move(transport_type), std::move(connection_data.stats_callback));
use_http ? mtproto::TransportType::Http : mtproto::TransportType::ObfuscatedTcp,
std::move(connection_data.stats_callback));
raw_connection->set_connection_token(std::move(connection_data.connection_token)); raw_connection->set_connection_token(std::move(connection_data.connection_token));
raw_connection->extra_ = network_generation; raw_connection->extra_ = network_generation;

View File

@ -12,6 +12,8 @@
#include "td/telegram/net/DcOptionsSet.h" #include "td/telegram/net/DcOptionsSet.h"
#include "td/telegram/StateManager.h" #include "td/telegram/StateManager.h"
#include "td/mtproto/IStreamTransport.h"
#include "td/actor/actor.h" #include "td/actor/actor.h"
#include "td/actor/PromiseFuture.h" #include "td/actor/PromiseFuture.h"
#include "td/actor/SignalSlot.h" #include "td/actor/SignalSlot.h"
@ -241,8 +243,9 @@ class ConnectionCreator : public Actor {
StateManager::ConnectionToken connection_token; StateManager::ConnectionToken connection_token;
std::unique_ptr<detail::StatsCallback> stats_callback; std::unique_ptr<detail::StatsCallback> stats_callback;
}; };
void client_create_raw_connection(Result<ConnectionData> r_connection_data, bool check_mode, bool use_http, void client_create_raw_connection(Result<ConnectionData> r_connection_data, bool check_mode,
size_t hash, string debug_str, uint32 network_generation); mtproto::TransportType transport_type, size_t hash, string debug_str,
uint32 network_generation);
void client_add_connection(size_t hash, Result<std::unique_ptr<mtproto::RawConnection>> r_raw_connection, void client_add_connection(size_t hash, Result<std::unique_ptr<mtproto::RawConnection>> r_raw_connection,
bool check_flag); bool check_flag);
void client_set_timeout_at(ClientInfo &client, double wakeup_at); void client_set_timeout_at(ClientInfo &client, double wakeup_at);

View File

@ -60,10 +60,6 @@ Result<DcOptionsSet::ConnectionInfo> DcOptionsSet::find_connection(DcId dc_id, b
LOG(DEBUG) << "Skip media only option"; LOG(DEBUG) << "Skip media only option";
continue; continue;
} }
if (!option.get_secret().empty()) {
LOG(DEBUG) << "Skip options with secret";
continue; // TODO secret support
}
ConnectionInfo info; ConnectionInfo info;
info.option = &option; info.option = &option;

View File

@ -894,7 +894,8 @@ void Session::connection_open_finish(ConnectionInfo *info,
return; return;
} }
Mode expected_mode = raw_connection->get_transport_type() == mtproto::TransportType::Http ? Mode::Http : Mode::Tcp; Mode expected_mode =
raw_connection->get_transport_type().type == mtproto::TransportType::Http ? Mode::Http : Mode::Tcp;
if (mode_ != expected_mode) { if (mode_ != expected_mode) {
LOG(INFO) << "Change mode " << mode_ << "--->" << expected_mode; LOG(INFO) << "Change mode " << mode_ << "--->" << expected_mode;
mode_ = expected_mode; mode_ = expected_mode;

View File

@ -107,7 +107,7 @@ class TestPingActor : public Actor {
void start_up() override { void start_up() override {
ping_connection_ = std::make_unique<mtproto::PingConnection>(std::make_unique<mtproto::RawConnection>( ping_connection_ = std::make_unique<mtproto::PingConnection>(std::make_unique<mtproto::RawConnection>(
SocketFd::open(ip_address_).move_as_ok(), mtproto::TransportType::Tcp, nullptr)); SocketFd::open(ip_address_).move_as_ok(), mtproto::TransportType{mtproto::TransportType::Tcp, 0, ""}, nullptr));
ping_connection_->get_pollable().set_observer(this); ping_connection_->get_pollable().set_observer(this);
subscribe(ping_connection_->get_pollable()); subscribe(ping_connection_->get_pollable());
@ -209,8 +209,9 @@ class HandshakeTestActor : public Actor {
} }
void loop() override { void loop() override {
if (!wait_for_raw_connection_ && !raw_connection_) { if (!wait_for_raw_connection_ && !raw_connection_) {
raw_connection_ = std::make_unique<mtproto::RawConnection>(SocketFd::open(get_default_ip_address()).move_as_ok(), raw_connection_ =
mtproto::TransportType::Tcp, nullptr); std::make_unique<mtproto::RawConnection>(SocketFd::open(get_default_ip_address()).move_as_ok(),
mtproto::TransportType{mtproto::TransportType::Tcp, 0, ""}, nullptr);
} }
if (!wait_for_handshake_ && !handshake_) { if (!wait_for_handshake_ && !handshake_) {
handshake_ = std::make_unique<AuthKeyHandshake>(0); handshake_ = std::make_unique<AuthKeyHandshake>(0);