AuthManager: persistent State::WaitPassword

GitOrigin-RevId: dd863c2484b16da1700e280b13adc01facb7c8bb
This commit is contained in:
Arseny Smirnov 2018-03-15 13:06:26 +03:00
parent 6f33dec5bb
commit 710f71701a
3 changed files with 112 additions and 25 deletions

View File

@ -369,7 +369,8 @@ tl_object_ptr<td_api::AuthorizationState> AuthManager::get_authorization_state_o
case State::WaitPhoneNumber: case State::WaitPhoneNumber:
return make_tl_object<td_api::authorizationStateWaitPhoneNumber>(); return make_tl_object<td_api::authorizationStateWaitPhoneNumber>();
case State::WaitPassword: case State::WaitPassword:
return make_tl_object<td_api::authorizationStateWaitPassword>(hint_, has_recovery_, email_address_pattern_); return make_tl_object<td_api::authorizationStateWaitPassword>(
wait_password_state_.hint_, wait_password_state_.has_recovery_, wait_password_state_.email_address_pattern_);
case State::LoggingOut: case State::LoggingOut:
return make_tl_object<td_api::authorizationStateLoggingOut>(); return make_tl_object<td_api::authorizationStateLoggingOut>();
case State::Closing: case State::Closing:
@ -506,7 +507,7 @@ void AuthManager::check_password(uint64 query_id, string password) {
return on_query_error(query_id, Status::Error(8, "checkAuthenticationPassword unexpected")); return on_query_error(query_id, Status::Error(8, "checkAuthenticationPassword unexpected"));
} }
BufferSlice buf(32); BufferSlice buf(32);
password = current_salt_ + password + current_salt_; password = wait_password_state_.current_salt_ + password + wait_password_state_.current_salt_;
sha256(password, buf.as_slice()); sha256(password, buf.as_slice());
on_new_query(query_id); on_new_query(query_id);
@ -638,16 +639,17 @@ void AuthManager::on_get_password_result(NetQueryPtr &result) {
return on_query_error(r_password.move_as_error()); return on_query_error(r_password.move_as_error());
} }
auto password = r_password.move_as_ok(); auto password = r_password.move_as_ok();
wait_password_state_ = WaitPasswordState();
if (password->get_id() == telegram_api::account_noPassword::ID) { if (password->get_id() == telegram_api::account_noPassword::ID) {
auto no_password = move_tl_object_as<telegram_api::account_noPassword>(password); auto no_password = move_tl_object_as<telegram_api::account_noPassword>(password);
new_salt_ = no_password->new_salt_.as_slice().str(); wait_password_state_.new_salt_ = no_password->new_salt_.as_slice().str();
} else { } else {
CHECK(password->get_id() == telegram_api::account_password::ID); CHECK(password->get_id() == telegram_api::account_password::ID);
auto password_info = move_tl_object_as<telegram_api::account_password>(password); auto password_info = move_tl_object_as<telegram_api::account_password>(password);
current_salt_ = password_info->current_salt_.as_slice().str(); wait_password_state_.current_salt_ = password_info->current_salt_.as_slice().str();
new_salt_ = password_info->new_salt_.as_slice().str(); wait_password_state_.new_salt_ = password_info->new_salt_.as_slice().str();
hint_ = password_info->hint_; wait_password_state_.hint_ = password_info->hint_;
has_recovery_ = password_info->has_recovery_; wait_password_state_.has_recovery_ = password_info->has_recovery_;
} }
update_state(State::WaitPassword); update_state(State::WaitPassword);
on_query_ok(); on_query_ok();
@ -660,7 +662,7 @@ void AuthManager::on_request_password_recovery_result(NetQueryPtr &result) {
} }
auto email_address_pattern = r_email_address_pattern.move_as_ok(); auto email_address_pattern = r_email_address_pattern.move_as_ok();
CHECK(email_address_pattern->get_id() == telegram_api::auth_passwordRecovery::ID); CHECK(email_address_pattern->get_id() == telegram_api::auth_passwordRecovery::ID);
email_address_pattern_ = email_address_pattern->email_pattern_; wait_password_state_.email_address_pattern_ = email_address_pattern->email_pattern_;
update_state(State::WaitPassword, true); update_state(State::WaitPassword, true);
on_query_ok(); on_query_ok();
} }
@ -847,7 +849,6 @@ bool AuthManager::load_state() {
LOG(INFO) << "Ignore auth_state: " << status; LOG(INFO) << "Ignore auth_state: " << status;
return false; return false;
} }
CHECK(db_state.state_ == State::WaitCode);
if (db_state.api_id_ != api_id_ || db_state.api_hash_ != api_hash_) { if (db_state.api_id_ != api_id_ || db_state.api_hash_ != api_hash_) {
LOG(INFO) << "Ignore auth_state: api_id or api_hash changed"; LOG(INFO) << "Ignore auth_state: api_id or api_hash changed";
return false; return false;
@ -860,22 +861,35 @@ bool AuthManager::load_state() {
LOG(INFO) << "Ignore auth_state: expired " << db_state.state_timestamp_.in(); LOG(INFO) << "Ignore auth_state: expired " << db_state.state_timestamp_.in();
return false; return false;
} }
LOG(INFO) << "Load auth_state from db";
send_code_helper_ = db_state.send_code_helper_; LOG(INFO) << "Load auth_state from db: " << tag("state", static_cast<int32>(db_state.state_));
update_state(State::WaitCode, false, false); if (db_state.state_ == State::WaitCode) {
send_code_helper_ = std::move(db_state.send_code_helper_);
} else if (db_state.state_ == State::WaitPassword) {
wait_password_state_ = std::move(db_state.wait_password_state_);
} else {
UNREACHABLE();
}
update_state(db_state.state_, false, false);
return true; return true;
} }
void AuthManager::save_state() { void AuthManager::save_state() {
if (state_ != State::WaitCode) { if (state_ != State::WaitCode && state_ != State::WaitPassword) {
if (state_ != State::Closing) { if (state_ != State::Closing) {
G()->td_db()->get_binlog_pmc()->erase("auth_state"); G()->td_db()->get_binlog_pmc()->erase("auth_state");
} }
return; return;
} }
DbState db_state{state_, api_id_, api_hash_, send_code_helper_, Timestamp::now()}; DbState db_state;
if (state_ == State::WaitCode) {
db_state = DbState::wait_code(api_id_, api_hash_, send_code_helper_);
} else if (state_ == State::WaitPassword) {
db_state = DbState::wait_password(api_id_, api_hash_, wait_password_state_);
} else {
UNREACHABLE();
}
G()->td_db()->get_binlog_pmc()->set("auth_state", log_event_store(db_state).as_slice().str()); G()->td_db()->get_binlog_pmc()->set("auth_state", log_event_store(db_state).as_slice().str());
} }

View File

@ -141,7 +141,15 @@ class AuthManager : public NetActor {
private: private:
static constexpr size_t MAX_NAME_LENGTH = 255; // server side limit static constexpr size_t MAX_NAME_LENGTH = 255; // server side limit
enum class State : int32 { None, WaitPhoneNumber, WaitCode, WaitPassword, Ok, LoggingOut, Closing } state_ = State::None; enum class State : int32 {
None,
WaitPhoneNumber,
WaitCode,
WaitPassword,
Ok,
LoggingOut,
Closing
} state_ = State::None;
enum class NetQueryType { enum class NetQueryType {
None, None,
SignIn, SignIn,
@ -157,13 +165,51 @@ class AuthManager : public NetActor {
DeleteAccount DeleteAccount
}; };
struct WaitPasswordState {
string current_salt_;
string new_salt_;
string hint_;
bool has_recovery_;
string email_address_pattern_;
template <class T>
void store(T &storer) const;
template <class T>
void parse(T &parser);
};
struct DbState { struct DbState {
State state_; State state_;
int32 api_id_; int32 api_id_;
string api_hash_; string api_hash_;
SendCodeHelper send_code_helper_;
Timestamp state_timestamp_; Timestamp state_timestamp_;
// WaitCode
SendCodeHelper send_code_helper_;
//WaitPassword
WaitPasswordState wait_password_state_;
static DbState wait_code(int32 api_id, string api_hash, SendCodeHelper send_code_helper) {
DbState state;
state.state_ = State::WaitCode;
state.api_id_ = api_id;
state.api_hash_ = api_hash;
state.send_code_helper_ = std::move(send_code_helper);
state.state_timestamp_ = Timestamp::now();
return state;
}
static DbState wait_password(int32 api_id, string api_hash, WaitPasswordState wait_password_state) {
DbState state;
state.state_ = State::WaitPassword;
state.api_id_ = api_id;
state.api_hash_ = api_hash;
state.wait_password_state_ = std::move(wait_password_state);
state.state_timestamp_ = Timestamp::now();
return state;
}
template <class T> template <class T>
void store(T &storer) const; void store(T &storer) const;
template <class T> template <class T>
@ -187,12 +233,7 @@ class AuthManager : public NetActor {
string bot_token_; string bot_token_;
uint64 query_id_ = 0; uint64 query_id_ = 0;
// State::WaitPassword WaitPasswordState wait_password_state_;
string current_salt_;
string new_salt_;
string hint_;
bool has_recovery_;
string email_address_pattern_;
bool was_check_bot_token_ = false; bool was_check_bot_token_ = false;
bool is_bot_ = false; bool is_bot_ = false;

View File

@ -45,16 +45,41 @@ void SendCodeHelper::parse(T &parser) {
parse(next_code_info_, parser); parse(next_code_info_, parser);
parse(next_code_timestamp_, parser); parse(next_code_timestamp_, parser);
} }
template <class T>
void AuthManager::WaitPasswordState::store(T &storer) const {
using td::store;
store(current_salt_, storer);
store(new_salt_, storer);
store(hint_, storer);
store(has_recovery_, storer);
store(email_address_pattern_, storer);
}
template <class T>
void AuthManager::WaitPasswordState::parse(T &parser) {
using td::parse;
parse(current_salt_, parser);
parse(new_salt_, parser);
parse(hint_, parser);
parse(has_recovery_, parser);
parse(email_address_pattern_, parser);
}
template <class T> template <class T>
void AuthManager::DbState::store(T &storer) const { void AuthManager::DbState::store(T &storer) const {
using td::store; using td::store;
CHECK(state_ == State::WaitCode);
store(state_, storer); store(state_, storer);
store(api_id_, storer); store(api_id_, storer);
store(api_hash_, storer); store(api_hash_, storer);
store(send_code_helper_, storer);
store(state_timestamp_, storer); store(state_timestamp_, storer);
if (state_ == State::WaitCode) {
store(send_code_helper_, storer);
} else if (state_ == State::WaitPassword) {
store(wait_password_state_, storer);
} else {
UNREACHABLE();
}
} }
template <class T> template <class T>
void AuthManager::DbState::parse(T &parser) { void AuthManager::DbState::parse(T &parser) {
@ -62,7 +87,14 @@ void AuthManager::DbState::parse(T &parser) {
parse(state_, parser); parse(state_, parser);
parse(api_id_, parser); parse(api_id_, parser);
parse(api_hash_, parser); parse(api_hash_, parser);
parse(send_code_helper_, parser);
parse(state_timestamp_, parser); parse(state_timestamp_, parser);
if (state_ == State::WaitCode) {
parse(send_code_helper_, parser);
} else if (state_ == State::WaitPassword) {
parse(wait_password_state_, parser);
} else {
parser.set_error(PSTRING() << "Unexpected " << tag("state", static_cast<int32>(state_)));
}
} }
} // namespace td } // namespace td