Validate and drop invalid main authorization key

GitOrigin-RevId: 5f5a0baf4fc55b629b6e0534c475f6236cc72506
This commit is contained in:
Arseny Smirnov 2020-01-07 15:42:04 +03:00
parent ad3a1a35c5
commit 890855a4f0
7 changed files with 157 additions and 17 deletions

View File

@ -60,6 +60,9 @@ class AuthData {
void set_main_auth_key(AuthKey auth_key) { void set_main_auth_key(AuthKey auth_key) {
main_auth_key_ = std::move(auth_key); main_auth_key_ = std::move(auth_key);
} }
void break_main_auth_key() {
main_auth_key_.break_key();
}
const AuthKey &get_main_auth_key() const { const AuthKey &get_main_auth_key() const {
// CHECK(has_main_auth_key()); // CHECK(has_main_auth_key());
return main_auth_key_; return main_auth_key_;

View File

@ -17,6 +17,10 @@ class AuthKey {
AuthKey() = default; AuthKey() = default;
AuthKey(uint64 auth_key_id, string &&auth_key) : auth_key_id_(auth_key_id), auth_key_(auth_key) { AuthKey(uint64 auth_key_id, string &&auth_key) : auth_key_id_(auth_key_id), auth_key_(auth_key) {
} }
void break_key() {
auth_key_id_++;
auth_key_[0]++;
}
bool empty() const { bool empty() const {
return auth_key_.empty(); return auth_key_.empty();
@ -47,21 +51,33 @@ class AuthKey {
double expires_at() const { double expires_at() const {
return expires_at_; return expires_at_;
} }
double created_at() const {
return created_at_;
}
void set_expires_at(double expires_at) { void set_expires_at(double expires_at) {
expires_at_ = expires_at; expires_at_ = expires_at;
// expires_at_ = Time::now() + 60 * 60 + 10 * 60; // expires_at_ = Time::now() + 60 * 60 + 10 * 60;
} }
void set_created_at(double created_at) {
created_at_ = created_at;
}
void clear() { void clear() {
auth_key_.clear(); auth_key_.clear();
} }
enum : int32 { AUTH_FLAG = 1, WAS_AUTH_FLAG = 2 }; enum : int32 { AUTH_FLAG = 1, WAS_AUTH_FLAG = 2, HAS_CREATED_AT = 4 };
template <class StorerT> template <class StorerT>
void store(StorerT &storer) const { void store(StorerT &storer) const {
storer.store_binary(auth_key_id_); storer.store_binary(auth_key_id_);
storer.store_binary(static_cast<int32>((auth_flag_ ? AUTH_FLAG : 0) | (was_auth_flag_ ? WAS_AUTH_FLAG : 0))); bool has_created_at = created_at_ != 0;
storer.store_binary(static_cast<int32>((auth_flag_ ? AUTH_FLAG : 0) | (was_auth_flag_ ? WAS_AUTH_FLAG : 0) |
(has_created_at ? HAS_CREATED_AT : 0)));
storer.store_string(auth_key_); storer.store_string(auth_key_);
if (has_created_at) {
storer.store_binary(created_at_);
}
} }
template <class ParserT> template <class ParserT>
@ -71,17 +87,21 @@ class AuthKey {
auth_flag_ = (flags & AUTH_FLAG) != 0; auth_flag_ = (flags & AUTH_FLAG) != 0;
was_auth_flag_ = (flags & WAS_AUTH_FLAG) != 0 || auth_flag_; was_auth_flag_ = (flags & WAS_AUTH_FLAG) != 0 || auth_flag_;
auth_key_ = parser.template fetch_string<string>(); auth_key_ = parser.template fetch_string<string>();
if ((flags & HAS_CREATED_AT) != 0) {
created_at_ = parser.fetch_double();
}
// just in case // just in case
need_header_ = true; need_header_ = true;
} }
private: private:
uint64 auth_key_id_ = 0; uint64 auth_key_id_{0};
string auth_key_; string auth_key_;
bool auth_flag_ = false; bool auth_flag_{false};
bool was_auth_flag_ = false; bool was_auth_flag_{false};
bool need_header_ = true; bool need_header_{true};
double expires_at_ = 0; double expires_at_{0};
double created_at_{0};
}; };
} // namespace mtproto } // namespace mtproto

View File

@ -213,6 +213,7 @@ Status AuthKeyHandshake::on_server_dh_params(Slice message, Callback *connection
if (mode_ == Mode::Temp) { if (mode_ == Mode::Temp) {
auth_key.set_expires_at(expires_at_); auth_key.set_expires_at(expires_at_);
} }
auth_key.set_created_at(dh_inner_data.server_time_);
server_salt = as<int64>(new_nonce.raw) ^ as<int64>(server_nonce.raw); server_salt = as<int64>(new_nonce.raw) ^ as<int64>(server_nonce.raw);

View File

@ -123,6 +123,12 @@ class Global : public ActorContext {
return *shared_config_; return *shared_config_;
} }
bool is_server_time_reliable() const {
return server_time_difference_was_updated_;
}
double from_server_time(double date) const {
return date - get_server_time_difference();
}
double to_server_time(double now) const { double to_server_time(double now) const {
return now + get_server_time_difference(); return now + get_server_time_difference();
} }

View File

@ -106,7 +106,9 @@ class AuthDataSharedImpl : public AuthDataShared {
} }
void log_auth_key(const mtproto::AuthKey &auth_key) { void log_auth_key(const mtproto::AuthKey &auth_key) {
LOG(WARNING) << dc_id_ << " " << tag("auth_key_id", auth_key.id()) << tag("state", get_auth_key_state(auth_key)); LOG(WARNING) << dc_id_ << " " << tag("auth_key_id", auth_key.id()) << tag("state", get_auth_key_state(auth_key))
<< tag("created_at", auth_key.created_at());
;
} }
}; };

View File

@ -129,6 +129,7 @@ Session::Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared>
shared_auth_data_ = std::move(shared_auth_data); shared_auth_data_ = std::move(shared_auth_data);
auth_data_.set_use_pfs(use_pfs); auth_data_.set_use_pfs(use_pfs);
auth_data_.set_main_auth_key(shared_auth_data_->get_auth_key()); auth_data_.set_main_auth_key(shared_auth_data_->get_auth_key());
//auth_data_.break_main_auth_key();
auth_data_.set_server_time_difference(shared_auth_data_->get_server_time_difference()); auth_data_.set_server_time_difference(shared_auth_data_->get_server_time_difference());
auth_data_.set_future_salts(shared_auth_data_->get_future_salts(), Time::now()); auth_data_.set_future_salts(shared_auth_data_->get_future_salts(), Time::now());
if (use_pfs && !tmp_auth_key.empty()) { if (use_pfs && !tmp_auth_key.empty()) {
@ -140,6 +141,7 @@ Session::Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared>
Random::secure_bytes(reinterpret_cast<uint8 *>(&session_id), sizeof(session_id)); Random::secure_bytes(reinterpret_cast<uint8 *>(&session_id), sizeof(session_id));
} while (session_id == 0); } while (session_id == 0);
auth_data_.set_session_id(session_id); auth_data_.set_session_id(session_id);
use_pfs_ = use_pfs;
LOG(WARNING) << "Generate new session_id " << session_id << " for " << (use_pfs ? "temp " : "") LOG(WARNING) << "Generate new session_id " << session_id << " for " << (use_pfs ? "temp " : "")
<< (is_cdn ? "CDN " : "") << "auth key " << auth_data_.get_auth_key().id() << " for " << (is_cdn ? "CDN " : "") << "auth key " << auth_data_.get_auth_key().id() << " for "
<< (is_main_ ? "main " : "") << "DC" << dc_id; << (is_main_ ? "main " : "") << "DC" << dc_id;
@ -238,19 +240,36 @@ void Session::send(NetQueryPtr &&query) {
loop(); loop();
} }
void Session::on_result(NetQueryPtr query) { void Session::on_bind_result(NetQueryPtr query) {
CHECK(UniqueId::extract_type(query->id()) == UniqueId::BindKey);
if (last_bind_id_ != query->id()) {
query->clear();
return;
}
LOG(INFO) << "ANSWER TO BindKey" << query; LOG(INFO) << "ANSWER TO BindKey" << query;
Status status; Status status;
tmp_auth_key_id_ = 0; tmp_auth_key_id_ = 0;
last_bind_id_ = 0; last_bind_id_ = 0;
if (query->is_error()) { if (query->is_error()) {
status = std::move(query->error()); status = std::move(query->error());
if (status.code() == 400 && status.message() == "ENCRYPTED_MESSAGE_INVALID") {
bool has_immunity =
!G()->is_server_time_reliable() || G()->server_time() - auth_data_.get_main_auth_key().created_at() < 60;
LOG(ERROR) << G()->is_server_time_reliable() << " "
<< G()->server_time() - auth_data_.get_auth_key().created_at();
if (!use_pfs_) {
if (has_immunity) {
LOG(WARNING) << "Do not drop main key, because it was created too recently";
} else {
LOG(WARNING) << "Drop main key because check with temporary key failed";
auth_data_.drop_main_auth_key();
on_auth_key_updated();
}
} else {
if (has_immunity) {
LOG(WARNING) << "Do not check validate main key, because it was created too recently";
} else {
need_check_main_key_ = true;
auth_data_.set_use_pfs(false);
LOG(WARNING) << "Got ENCRYPTED_MESSAGE_INVALID error, validate main key";
}
}
}
} else { } else {
auto r_flag = fetch_result<telegram_api::auth_bindTempAuthKey>(query->ok()); auto r_flag = fetch_result<telegram_api::auth_bindTempAuthKey>(query->ok());
if (r_flag.is_error()) { if (r_flag.is_error()) {
@ -268,11 +287,53 @@ void Session::on_result(NetQueryPtr query) {
on_tmp_auth_key_updated(); on_tmp_auth_key_updated();
} else { } else {
LOG(ERROR) << "BindKey failed: " << status; LOG(ERROR) << "BindKey failed: " << status;
connection_close(&main_connection_);
connection_close(&long_poll_connection_);
} }
query->clear(); query->clear();
yield(); yield();
} }
void Session::on_check_key_result(NetQueryPtr query) {
LOG(INFO) << "ANSWER TO GetNearestDc" << query;
Status status;
auth_key_id_ = 0;
last_check_id_ = 0;
if (query->is_error()) {
status = std::move(query->error());
} else {
auto r_flag = fetch_result<telegram_api::help_getNearestDc>(query->ok());
if (r_flag.is_error()) {
status = r_flag.move_as_error();
}
}
if (status.is_ok()) {
LOG(INFO) << "Check main key ok";
need_check_main_key_ = false;
auth_data_.set_use_pfs(true);
} else {
LOG(ERROR) << "Check main key failed: " << status;
connection_close(&main_connection_);
connection_close(&long_poll_connection_);
}
query->clear();
yield();
}
void Session::on_result(NetQueryPtr query) {
CHECK(UniqueId::extract_type(query->id()) == UniqueId::BindKey);
if (last_bind_id_ == query->id()) {
return on_bind_result(std::move(query));
}
if (last_check_id_ == query->id()) {
return on_check_key_result(std::move(query));
}
query->clear();
return;
}
void Session::return_query(NetQueryPtr &&query) { void Session::return_query(NetQueryPtr &&query) {
last_activity_timestamp_ = Time::now(); last_activity_timestamp_ = Time::now();
@ -429,6 +490,16 @@ void Session::on_closed(Status status) {
} else if (need_destroy_) { } else if (need_destroy_) {
auth_data_.drop_main_auth_key(); auth_data_.drop_main_auth_key();
on_auth_key_updated(); on_auth_key_updated();
} else {
if (!use_pfs_) {
// Logout if has error and or 1 minute is passed from start, or 1 minute has passed
// since auth_key creation
auth_data_.set_use_pfs(true);
} else if (need_check_main_key_) {
LOG(WARNING) << "Invalidate main key";
auth_data_.drop_main_auth_key();
on_auth_key_updated();
}
} }
} }
@ -1019,14 +1090,38 @@ void Session::connection_close(ConnectionInfo *info) {
info->connection->force_close(static_cast<mtproto::SessionConnection::Callback *>(this)); info->connection->force_close(static_cast<mtproto::SessionConnection::Callback *>(this));
CHECK(info->state == ConnectionInfo::State::Empty); CHECK(info->state == ConnectionInfo::State::Empty);
} }
bool Session::need_send_check_main_key() const {
return need_check_main_key_ && auth_data_.get_main_auth_key().id() != auth_key_id_;
}
bool Session::connection_send_check_main_key(ConnectionInfo *info) {
if (!need_check_main_key_) {
return false;
}
uint64 key_id = auth_data_.get_main_auth_key().id();
if (key_id == auth_key_id_) {
return false;
}
CHECK(info->state != ConnectionInfo::State::Empty);
LOG(INFO) << "Check main key";
auth_key_id_ = key_id;
last_check_id_ = UniqueId::next(UniqueId::BindKey);
NetQueryPtr query = G()->net_query_creator().create(last_check_id_, create_storer(telegram_api::help_getNearestDc()));
query->dispatch_ttl = 0;
query->set_callback(actor_shared(this));
connection_send_query(info, std::move(query));
return true;
}
bool Session::need_send_bind_key() const { bool Session::need_send_bind_key() const {
return auth_data_.use_pfs() && !auth_data_.get_bind_flag() && auth_data_.get_tmp_auth_key().id() != tmp_auth_key_id_; return auth_data_.use_pfs() && !auth_data_.get_bind_flag() && auth_data_.get_tmp_auth_key().id() != tmp_auth_key_id_;
} }
bool Session::need_send_query() const { bool Session::need_send_query() const {
return !close_flag_ && (!auth_data_.use_pfs() || auth_data_.get_bind_flag()) && !pending_queries_.empty() && return !close_flag_ && !need_check_main_key_ && (!auth_data_.use_pfs() || auth_data_.get_bind_flag()) &&
!can_destroy_auth_key(); !pending_queries_.empty() && !can_destroy_auth_key();
} }
bool Session::connection_send_bind_key(ConnectionInfo *info) { bool Session::connection_send_bind_key(ConnectionInfo *info) {
CHECK(info->state != ConnectionInfo::State::Empty); CHECK(info->state != ConnectionInfo::State::Empty);
uint64 key_id = auth_data_.get_tmp_auth_key().id(); uint64 key_id = auth_data_.get_tmp_auth_key().id();
@ -1215,6 +1310,10 @@ void Session::loop() {
connection_send_bind_key(&main_connection_); connection_send_bind_key(&main_connection_);
need_flush = true; need_flush = true;
} }
if (need_send_check_main_key()) {
connection_send_check_main_key(&main_connection_);
need_flush = true;
}
} }
if (need_flush) { if (need_flush) {
connection_flush(&main_connection_); connection_flush(&main_connection_);

View File

@ -113,7 +113,9 @@ class Session final
bool online_flag_ = false; bool online_flag_ = false;
bool connection_online_flag_ = false; bool connection_online_flag_ = false;
uint64 tmp_auth_key_id_ = 0; uint64 tmp_auth_key_id_ = 0;
uint64 auth_key_id_ = 0;
uint64 last_bind_id_ = 0; uint64 last_bind_id_ = 0;
uint64 last_check_id_ = 0;
double last_activity_timestamp_ = 0; double last_activity_timestamp_ = 0;
size_t dropped_size_ = 0; size_t dropped_size_ = 0;
@ -148,6 +150,8 @@ class Session final
std::shared_ptr<Callback> callback_; std::shared_ptr<Callback> callback_;
mtproto::AuthData auth_data_; mtproto::AuthData auth_data_;
bool use_pfs_{false};
bool need_check_main_key_{false};
TempAuthKeyWatchdog::RegisteredAuthKey registered_temp_auth_key_; TempAuthKeyWatchdog::RegisteredAuthKey registered_temp_auth_key_;
std::shared_ptr<AuthDataShared> shared_auth_data_; std::shared_ptr<AuthDataShared> shared_auth_data_;
bool close_flag_ = false; bool close_flag_ = false;
@ -229,9 +233,14 @@ class Session final
bool need_send_query() const; bool need_send_query() const;
bool can_destroy_auth_key() const; bool can_destroy_auth_key() const;
bool connection_send_bind_key(ConnectionInfo *info); bool connection_send_bind_key(ConnectionInfo *info);
bool need_send_check_main_key() const;
bool connection_send_check_main_key(ConnectionInfo *info);
void on_result(NetQueryPtr query) override; void on_result(NetQueryPtr query) override;
void on_bind_result(NetQueryPtr query);
void on_check_key_result(NetQueryPtr query);
void start_up() override; void start_up() override;
void loop() override; void loop() override;
void hangup() override; void hangup() override;