Make all AuthKeyHandshake fields private.

GitOrigin-RevId: 73bc8e67b5c86a4f30cabde16f84395b47d7f79e
This commit is contained in:
levlam 2020-01-08 20:30:07 +03:00
parent 89d7374272
commit 1b1bd481e3
4 changed files with 48 additions and 29 deletions

View File

@ -182,7 +182,7 @@ Status AuthKeyHandshake::on_server_dh_params(Slice message, Callback *connection
return Status::Error("Server nonce mismatch"); return Status::Error("Server nonce mismatch");
} }
server_time_diff = dh_inner_data.server_time_ - Time::now(); server_time_diff_ = dh_inner_data.server_time_ - Time::now();
DhHandshake handshake; DhHandshake handshake;
handshake.set_config(dh_inner_data.g_, dh_inner_data.dh_prime_); handshake.set_config(dh_inner_data.g_, dh_inner_data.dh_prime_);
@ -209,13 +209,13 @@ Status AuthKeyHandshake::on_server_dh_params(Slice message, Callback *connection
mtproto_api::set_client_DH_params set_client_dh_params(nonce, server_nonce, encrypted_data); mtproto_api::set_client_DH_params set_client_dh_params(nonce, server_nonce, encrypted_data);
send(connection, create_storer(set_client_dh_params)); send(connection, create_storer(set_client_dh_params));
auth_key = AuthKey(auth_key_params.first, std::move(auth_key_params.second)); auth_key_ = AuthKey(auth_key_params.first, std::move(auth_key_params.second));
if (mode_ == Mode::Temp) { if (mode_ == Mode::Temp) {
auth_key.set_expires_at(expires_at_); auth_key_.set_expires_at(expires_at_);
} }
auth_key.set_created_at(dh_inner_data.server_time_); auth_key_.set_created_at(dh_inner_data.server_time_);
server_salt = as<int64>(new_nonce.raw) ^ as<int64>(server_nonce.raw); server_salt_ = as<int64>(new_nonce.raw) ^ as<int64>(server_nonce.raw);
state_ = DHGenResponse; state_ = DHGenResponse;
return Status::OK(); return Status::OK();
@ -289,7 +289,7 @@ Status AuthKeyHandshake::on_start(Callback *connection) {
return Status::OK(); return Status::OK();
} }
Status AuthKeyHandshake::on_message(Slice message, Callback *connection, Context *context) { Status AuthKeyHandshake::on_message(Slice message, Callback *connection, AuthKeyHandshakeContext *context) {
Status status = [&] { Status status = [&] {
switch (state_) { switch (state_) {
case ResPQ: case ResPQ:

View File

@ -27,6 +27,8 @@ class AuthKeyHandshakeContext {
}; };
class AuthKeyHandshake { class AuthKeyHandshake {
enum class Mode { Unknown, Main, Temp };
public: public:
class Callback { class Callback {
public: public:
@ -36,20 +38,6 @@ class AuthKeyHandshake {
virtual ~Callback() = default; virtual ~Callback() = default;
virtual void send_no_crypto(const Storer &storer) = 0; virtual void send_no_crypto(const Storer &storer) = 0;
}; };
using Context = AuthKeyHandshakeContext;
enum class Mode { Unknown, Main, Temp };
AuthKey auth_key;
double server_time_diff = 0;
uint64 server_salt = 0;
bool is_ready_for_start() const;
Status start_main(Callback *connection) TD_WARN_UNUSED_RESULT;
Status start_tmp(Callback *connection, int32 expires_in) TD_WARN_UNUSED_RESULT;
bool is_ready_for_message(const UInt128 &message_nonce) const;
bool is_ready_for_finish() const;
void on_finish();
AuthKeyHandshake(int32 dc_id, int32 expires_in) { AuthKeyHandshake(int32 dc_id, int32 expires_in) {
dc_id_ = dc_id; dc_id_ = dc_id;
@ -60,22 +48,49 @@ class AuthKeyHandshake {
expires_in_ = expires_in; expires_in_ = expires_in;
} }
} }
bool is_ready_for_start() const;
Status start_main(Callback *connection) TD_WARN_UNUSED_RESULT;
Status start_tmp(Callback *connection, int32 expires_in) TD_WARN_UNUSED_RESULT;
bool is_ready_for_message(const UInt128 &message_nonce) const;
bool is_ready_for_finish() const;
void on_finish();
void init_main() { void init_main() {
clear(); clear();
mode_ = Mode::Main; mode_ = Mode::Main;
} }
void init_temp(int32 expires_in) { void init_temp(int32 expires_in) {
clear(); clear();
mode_ = Mode::Temp; mode_ = Mode::Temp;
expires_in_ = expires_in; expires_in_ = expires_in;
} }
void resume(Callback *connection); void resume(Callback *connection);
Status on_message(Slice message, Callback *connection, Context *context) TD_WARN_UNUSED_RESULT;
Status on_message(Slice message, Callback *connection, AuthKeyHandshakeContext *context) TD_WARN_UNUSED_RESULT;
bool is_ready() const { bool is_ready() const {
return is_ready_for_finish(); return is_ready_for_finish();
} }
void clear(); void clear();
AuthKey release_auth_key() {
return std::move(auth_key_);
}
double get_server_time_diff() const {
return server_time_diff_;
}
uint64 get_server_salt() const {
return server_salt_;
}
private: private:
using State = enum { Start, ResPQ, ServerDHParams, DHGenResponse, Finish }; using State = enum { Start, ResPQ, ServerDHParams, DHGenResponse, Finish };
State state_ = Start; State state_ = Start;
@ -84,6 +99,10 @@ class AuthKeyHandshake {
int32 expires_in_ = 0; int32 expires_in_ = 0;
double expires_at_ = 0; double expires_at_ = 0;
AuthKey auth_key_;
double server_time_diff_ = 0;
uint64 server_salt_ = 0;
UInt128 nonce; UInt128 nonce;
UInt128 server_nonce; UInt128 server_nonce;
UInt256 new_nonce; UInt256 new_nonce;

View File

@ -1166,13 +1166,13 @@ void Session::on_handshake_ready(Result<unique_ptr<mtproto::AuthKeyHandshake>> r
info.handshake_ = std::move(handshake); info.handshake_ = std::move(handshake);
} else { } else {
if (is_main) { if (is_main) {
auth_data_.set_main_auth_key(std::move(handshake->auth_key)); auth_data_.set_main_auth_key(handshake->release_auth_key());
on_auth_key_updated(); on_auth_key_updated();
} else { } else {
auth_data_.set_tmp_auth_key(handshake->release_auth_key());
if (is_main_) { if (is_main_) {
registered_temp_auth_key_ = TempAuthKeyWatchdog::register_auth_key_id(handshake->auth_key.id()); registered_temp_auth_key_ = TempAuthKeyWatchdog::register_auth_key_id(auth_data_.get_tmp_auth_key().id());
} }
auth_data_.set_tmp_auth_key(std::move(handshake->auth_key));
on_tmp_auth_key_updated(); on_tmp_auth_key_updated();
} }
LOG(WARNING) << "Update auth key in session_id " << auth_data_.get_session_id() << " to " LOG(WARNING) << "Update auth key in session_id " << auth_data_.get_session_id() << " to "
@ -1182,10 +1182,10 @@ void Session::on_handshake_ready(Result<unique_ptr<mtproto::AuthKeyHandshake>> r
// Salt of temporary key is different salt. Do not rewrite it // Salt of temporary key is different salt. Do not rewrite it
if (auth_data_.use_pfs() ^ is_main) { if (auth_data_.use_pfs() ^ is_main) {
auth_data_.set_server_salt(handshake->server_salt, Time::now_cached()); auth_data_.set_server_salt(handshake->get_server_salt(), Time::now_cached());
on_server_salt_updated(); on_server_salt_updated();
} }
if (auth_data_.update_server_time_difference(handshake->server_time_diff)) { if (auth_data_.update_server_time_difference(handshake->get_server_time_diff())) {
on_server_time_difference_updated(); on_server_time_difference_updated();
} }
LOG(INFO) << "Got " << (is_main ? "main" : "tmp") << " auth key"; LOG(INFO) << "Got " << (is_main ? "main" : "tmp") << " auth key";

View File

@ -564,9 +564,9 @@ class FastPingTestActor : public Actor {
unique_ptr<mtproto::AuthData> auth_data; unique_ptr<mtproto::AuthData> auth_data;
if (iteration_ % 2 == 0) { if (iteration_ % 2 == 0) {
auth_data = make_unique<mtproto::AuthData>(); auth_data = make_unique<mtproto::AuthData>();
auth_data->set_tmp_auth_key(handshake_->auth_key); auth_data->set_tmp_auth_key(handshake_->release_auth_key());
auth_data->set_server_time_difference(handshake_->server_time_diff); auth_data->set_server_time_difference(handshake_->get_server_time_diff());
auth_data->set_server_salt(handshake_->server_salt, Time::now()); auth_data->set_server_salt(handshake_->get_server_salt(), Time::now());
auth_data->set_future_salts({mtproto::ServerSalt{0u, 1e20, 1e30}}, Time::now()); auth_data->set_future_salts({mtproto::ServerSalt{0u, 1e20, 1e30}}, Time::now());
auth_data->set_use_pfs(true); auth_data->set_use_pfs(true);
uint64 session_id = 0; uint64 session_id = 0;