Destroy mtproto keys on logout

GitOrigin-RevId: 8ac2bafd2d1897dc0942a33d8406ea8c2e5dfaa7
This commit is contained in:
Arseny Smirnov 2017-12-29 23:34:39 +03:00
parent abbb1a42a7
commit 27770ce060
20 changed files with 249 additions and 67 deletions

View File

@ -66,6 +66,10 @@ msg_new_detailed_info#809db6df answer_msg_id:long bytes:int status:int = MsgDeta
rsa_public_key n:string e:string = RSAPublicKey; rsa_public_key n:string e:string = RSAPublicKey;
destroy_auth_key_ok#f660e1d4 = DestroyAuthKeyRes;
destroy_auth_key_none#0a9f2259 = DestroyAuthKeyRes;
destroy_auth_key_fail#ea109b13 = DestroyAuthKeyRes;
---functions--- ---functions---
req_pq_multi#be7e8ef1 nonce:int128 = ResPQ; req_pq_multi#be7e8ef1 nonce:int128 = ResPQ;
@ -82,6 +86,8 @@ destroy_session#e7512126 session_id:long = DestroySessionRes;
http_wait#9299359f max_delay:int wait_after:int max_wait:int = HttpWait; http_wait#9299359f max_delay:int wait_after:int max_wait:int = HttpWait;
destroy_auth_key#d1435160 = DestroyAuthKeyRes;
//test.useGzipPacked = GzipPacked; //test.useGzipPacked = GzipPacked;
//test.useServerDhInnerData = Server_DH_inner_data; //test.useServerDhInnerData = Server_DH_inner_data;
//test.useNewSessionCreated = NewSession; //test.useNewSessionCreated = NewSession;

Binary file not shown.

View File

@ -136,7 +136,7 @@ class AuthData {
void set_auth_flag(bool auth_flag) { void set_auth_flag(bool auth_flag) {
main_auth_key_.set_auth_flag(auth_flag); main_auth_key_.set_auth_flag(auth_flag);
if (!auth_flag) { if (!auth_flag) {
tmp_auth_key_.set_auth_flag(auth_flag); drop_tmp_auth_key();
} }
} }

View File

@ -33,11 +33,7 @@ class AuthKey {
return was_auth_flag_; return was_auth_flag_;
} }
void set_auth_flag(bool new_auth_flag) { void set_auth_flag(bool new_auth_flag) {
if (new_auth_flag == false) { was_auth_flag_ |= new_auth_flag;
clear();
} else {
was_auth_flag_ = true;
}
auth_flag_ = new_auth_flag; auth_flag_ = new_auth_flag;
} }

View File

@ -17,6 +17,12 @@
#include "td/utils/Time.h" #include "td/utils/Time.h"
namespace td { namespace td {
namespace mtproto_api {
class msg_container {
public:
static const int32 ID = 0x73f1f8dc;
};
} // namespace mtproto_api
namespace mtproto { namespace mtproto {
template <class Object, class ObjectStorer> template <class Object, class ObjectStorer>
@ -65,6 +71,7 @@ using GetFutureSaltsImpl = ObjectImpl<mtproto_api::get_future_salts, TLStorer<mt
using ResendImpl = ObjectImpl<mtproto_api::msg_resend_req, TLObjectStorer<mtproto_api::msg_resend_req>>; using ResendImpl = ObjectImpl<mtproto_api::msg_resend_req, TLObjectStorer<mtproto_api::msg_resend_req>>;
using CancelImpl = ObjectImpl<mtproto_api::rpc_drop_answer, TLStorer<mtproto_api::rpc_drop_answer>>; using CancelImpl = ObjectImpl<mtproto_api::rpc_drop_answer, TLStorer<mtproto_api::rpc_drop_answer>>;
using GetInfoImpl = ObjectImpl<mtproto_api::msgs_state_req, TLObjectStorer<mtproto_api::msgs_state_req>>; using GetInfoImpl = ObjectImpl<mtproto_api::msgs_state_req, TLObjectStorer<mtproto_api::msgs_state_req>>;
using DestroyAuthKeyImpl = ObjectImpl<mtproto_api::destroy_auth_key, TLStorer<mtproto_api::destroy_auth_key>>;
class CancelVectorImpl { class CancelVectorImpl {
public: public:
@ -182,8 +189,8 @@ class CryptoImpl {
public: public:
CryptoImpl(const vector<Query> &to_send, Slice header, vector<int64> &&to_ack, int64 ping_id, int ping_timeout, CryptoImpl(const vector<Query> &to_send, Slice header, vector<int64> &&to_ack, int64 ping_id, int ping_timeout,
int max_delay, int max_after, int max_wait, int future_salt_n, vector<int64> get_info, int max_delay, int max_after, int max_wait, int future_salt_n, vector<int64> get_info,
vector<int64> resend, vector<int64> cancel, AuthData *auth_data, uint64 *container_id, uint64 *get_info_id, vector<int64> resend, vector<int64> cancel, bool destroy_key, AuthData *auth_data, uint64 *container_id,
uint64 *resend_id, uint64 *ping_message_id, uint64 *parent_message_id) uint64 *get_info_id, uint64 *resend_id, uint64 *ping_message_id, uint64 *parent_message_id)
: query_storer_(to_send, header) : query_storer_(to_send, header)
, ack_empty_(to_ack.empty()) , ack_empty_(to_ack.empty())
, ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data) , ack_storer_(!ack_empty_, mtproto_api::msgs_ack(std::move(to_ack)), auth_data)
@ -197,16 +204,18 @@ class CryptoImpl {
, cancel_not_empty_(!cancel.empty()) , cancel_not_empty_(!cancel.empty())
, cancel_cnt_(static_cast<int32>(cancel.size())) , cancel_cnt_(static_cast<int32>(cancel.size()))
, cancel_storer_(cancel_not_empty_, std::move(cancel), auth_data, true) , cancel_storer_(cancel_not_empty_, std::move(cancel), auth_data, true)
, destroy_key_storer_(destroy_key, mtproto_api::destroy_auth_key(), auth_data, true)
, tmp_storer_(query_storer_, ack_storer_) , tmp_storer_(query_storer_, ack_storer_)
, tmp2_storer_(tmp_storer_, http_wait_storer_) , tmp2_storer_(tmp_storer_, http_wait_storer_)
, tmp3_storer_(tmp2_storer_, get_future_salts_storer_) , tmp3_storer_(tmp2_storer_, get_future_salts_storer_)
, tmp4_storer_(tmp3_storer_, get_info_storer_) , tmp4_storer_(tmp3_storer_, get_info_storer_)
, tmp5_storer_(tmp4_storer_, resend_storer_) , tmp5_storer_(tmp4_storer_, resend_storer_)
, tmp6_storer_(tmp5_storer_, cancel_storer_) , tmp6_storer_(tmp5_storer_, cancel_storer_)
, concat_storer_(tmp6_storer_, ping_storer_) , tmp7_storer_(tmp6_storer_, destroy_key_storer_)
, concat_storer_(tmp7_storer_, ping_storer_)
, cnt_(static_cast<int32>(to_send.size()) + ack_storer_.not_empty() + ping_storer_.not_empty() + , cnt_(static_cast<int32>(to_send.size()) + ack_storer_.not_empty() + ping_storer_.not_empty() +
http_wait_storer_.not_empty() + get_future_salts_storer_.not_empty() + get_info_storer_.not_empty() + http_wait_storer_.not_empty() + get_future_salts_storer_.not_empty() + get_info_storer_.not_empty() +
resend_storer_.not_empty() + cancel_cnt_) resend_storer_.not_empty() + cancel_cnt_ + destroy_key_storer_.not_empty())
, container_storer_(cnt_, concat_storer_) { , container_storer_(cnt_, concat_storer_) {
CHECK(cnt_ != 0); CHECK(cnt_ != 0);
if (get_info_storer_.not_empty() && get_info_id) { if (get_info_storer_.not_empty() && get_info_id) {
@ -252,6 +261,9 @@ class CryptoImpl {
} else if (cancel_storer_.not_empty()) { } else if (cancel_storer_.not_empty()) {
type_ = OnlyCancel; type_ = OnlyCancel;
*parent_message_id = cancel_storer_.get_message_id(); *parent_message_id = cancel_storer_.get_message_id();
} else if (destroy_key_storer_.not_empty()) {
type_ = OnlyDestroyKey;
*parent_message_id = destroy_key_storer_.get_message_id();
} else { } else {
UNREACHABLE(); UNREACHABLE();
} }
@ -284,6 +296,9 @@ class CryptoImpl {
case OnlyGetInfo: case OnlyGetInfo:
return storer.store_storer(get_info_storer_); return storer.store_storer(get_info_storer_);
case OnlyDestroyKey:
return storer.store_storer(destroy_key_storer_);
default: default:
storer.store_binary(message_id_); storer.store_binary(message_id_);
storer.store_binary(seq_no_); storer.store_binary(seq_no_);
@ -306,12 +321,14 @@ class CryptoImpl {
bool cancel_not_empty_; bool cancel_not_empty_;
int32 cancel_cnt_; int32 cancel_cnt_;
PacketStorer<CancelVectorImpl> cancel_storer_; PacketStorer<CancelVectorImpl> cancel_storer_;
PacketStorer<DestroyAuthKeyImpl> destroy_key_storer_;
ConcatStorer tmp_storer_; ConcatStorer tmp_storer_;
ConcatStorer tmp2_storer_; ConcatStorer tmp2_storer_;
ConcatStorer tmp3_storer_; ConcatStorer tmp3_storer_;
ConcatStorer tmp4_storer_; ConcatStorer tmp4_storer_;
ConcatStorer tmp5_storer_; ConcatStorer tmp5_storer_;
ConcatStorer tmp6_storer_; ConcatStorer tmp6_storer_;
ConcatStorer tmp7_storer_;
ConcatStorer concat_storer_; ConcatStorer concat_storer_;
int32 cnt_; int32 cnt_;
PacketStorer<ContainerImpl> container_storer_; PacketStorer<ContainerImpl> container_storer_;
@ -324,6 +341,7 @@ class CryptoImpl {
OnlyResend, OnlyResend,
OnlyCancel, OnlyCancel,
OnlyGetInfo, OnlyGetInfo,
OnlyDestroyKey,
Mixed Mixed
}; };
Type type_; Type type_;

View File

@ -270,14 +270,29 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) {
LOG(ERROR) << "Unsupported: " << to_string(packet); LOG(ERROR) << "Unsupported: " << to_string(packet);
return Status::OK(); return Status::OK();
} }
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) {
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) {
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) {
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_destroy_auth_key(const mtproto_api::DestroyAuthKeyRes &destroy_auth_key) {
CHECK(need_destroy_auth_key_);
LOG(INFO) << to_string(destroy_auth_key);
return callback_->on_destroy_auth_key();
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::rpc_error &rpc_error) { Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::rpc_error &rpc_error) {
return on_packet(info, 0, rpc_error); return on_packet(info, 0, rpc_error);
} }
Status SessionConnection::on_packet(const MsgInfo &info, uint64 req_msg_id, const mtproto_api::rpc_error &rpc_error) { Status SessionConnection::on_packet(const MsgInfo &info, uint64 req_msg_id, const mtproto_api::rpc_error &rpc_error) {
VLOG(mtproto) << "ERROR [code:" << rpc_error.error_code_ << "] [msg:" << rpc_error.error_message_.str().c_str() VLOG(mtproto) << "ERROR [code:" << rpc_error.error_code_ << "] [msg:" << rpc_error.error_message_.str().c_str() << "]"
<< "]"; << " " << tag("req_msg_id", req_msg_id);
if (req_msg_id != 0) { if (req_msg_id != 0) {
callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, as_buffer_slice(rpc_error.error_message_)); callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, as_buffer_slice(rpc_error.error_message_));
} else { } else {
@ -524,6 +539,8 @@ Status SessionConnection::on_main_packet(const PacketInfo &info, Slice packet) {
void SessionConnection::on_message_failed(uint64 id, Status status) { void SessionConnection::on_message_failed(uint64 id, Status status) {
callback_->on_message_failed(id, std::move(status)); callback_->on_message_failed(id, std::move(status));
sent_destroy_auth_key_ = false;
if (id == last_ping_message_id_ || id == last_ping_container_id_) { if (id == last_ping_message_id_ || id == last_ping_container_id_) {
// restart ping immediately // restart ping immediately
last_ping_at_ = 0; last_ping_at_ = 0;
@ -613,6 +630,10 @@ bool SessionConnection::must_flush_packet() {
relax_timeout_at(&flush_packet_at_, get_future_salts_at); relax_timeout_at(&flush_packet_at_, get_future_salts_at);
} }
if (has_salt && need_destroy_auth_key_ && !sent_destroy_auth_key_) {
return true;
}
return false; return false;
} }
@ -741,6 +762,11 @@ void SessionConnection::cancel_answer(int64 message_id) {
to_cancel_answer_.push_back(message_id); to_cancel_answer_.push_back(message_id);
} }
void SessionConnection::destroy_key() {
LOG(INFO) << "need_destroy_key = true";
need_destroy_auth_key_ = true;
}
std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expire_at) { std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expire_at) {
int64 temp_key = auth_data_->get_tmp_auth_key().id(); int64 temp_key = auth_data_->get_tmp_auth_key().id();
@ -839,17 +865,21 @@ void SessionConnection::flush_packet() {
to_send_.erase(to_send_.begin(), to_send_.begin() + send_till); to_send_.erase(to_send_.begin(), to_send_.begin() + send_till);
} }
bool destroy_auth_key = need_destroy_auth_key_ && !sent_destroy_auth_key_;
if (queries.empty() && to_ack_.empty() && ping_id == 0 && max_delay < 0 && future_salt_n == 0 && if (queries.empty() && to_ack_.empty() && ping_id == 0 && max_delay < 0 && future_salt_n == 0 &&
to_resend_answer_.empty() && to_cancel_answer_.empty() && to_get_state_info_.empty()) { to_resend_answer_.empty() && to_cancel_answer_.empty() && to_get_state_info_.empty() && !destroy_auth_key) {
force_send_at_ = 0; force_send_at_ = 0;
return; return;
} }
sent_destroy_auth_key_ |= destroy_auth_key;
VLOG(mtproto) << "Sent packet: " << tag("query_count", queries.size()) << tag("ack_cnt", to_ack_.size()) VLOG(mtproto) << "Sent packet: " << tag("query_count", queries.size()) << tag("ack_cnt", to_ack_.size())
<< tag("ping", ping_id != 0) << tag("http_wait", max_delay >= 0) << tag("ping", ping_id != 0) << tag("http_wait", max_delay >= 0)
<< tag("future_salt", future_salt_n > 0) << tag("get_info", to_get_state_info_.size()) << tag("future_salt", future_salt_n > 0) << tag("get_info", to_get_state_info_.size())
<< tag("resend", to_resend_answer_.size()) << tag("cancel", to_cancel_answer_.size()) << tag("resend", to_resend_answer_.size()) << tag("cancel", to_cancel_answer_.size())
<< tag("auth_id", auth_data_->get_auth_key().id()); << tag("destroy_key", destroy_auth_key) << tag("auth_id", auth_data_->get_auth_key().id());
auto cut_tail = [](auto &v, size_t size, Slice name) { auto cut_tail = [](auto &v, size_t size, Slice name) {
if (size >= v.size()) { if (size >= v.size()) {
@ -878,8 +908,8 @@ void SessionConnection::flush_packet() {
uint64 parent_message_id = 0; uint64 parent_message_id = 0;
auto storer = PacketStorer<CryptoImpl>( auto storer = PacketStorer<CryptoImpl>(
queries, auth_data_->get_header(), std::move(to_ack), ping_id, ping_disconnect_delay() + 2, max_delay, queries, auth_data_->get_header(), std::move(to_ack), ping_id, ping_disconnect_delay() + 2, max_delay,
max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, auth_data_, max_after, max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, destroy_auth_key,
&container_id, &get_state_info_id, &resend_answer_id, &ping_message_id, &parent_message_id); auth_data_, &container_id, &get_state_info_id, &resend_answer_id, &ping_message_id, &parent_message_id);
auto quick_ack_token = use_quick_ack ? parent_message_id : 0; auto quick_ack_token = use_quick_ack ? parent_message_id : 0;
send_crypto(storer, quick_ack_token); send_crypto(storer, quick_ack_token);

View File

@ -25,11 +25,6 @@
namespace td { namespace td {
namespace mtproto_api { namespace mtproto_api {
class msg_container {
public:
static const int32 ID = 0x73f1f8dc;
};
class rpc_error; class rpc_error;
class new_session_created; class new_session_created;
class bad_msg_notification; class bad_msg_notification;
@ -42,6 +37,10 @@ class msgs_state_info;
class msgs_all_info; class msgs_all_info;
class msg_detailed_info; class msg_detailed_info;
class msg_new_detailed_info; class msg_new_detailed_info;
class DestroyAuthKeyRes;
class destroy_auth_key_ok;
class destroy_auth_key_fail;
class destroy_auth_key_none;
} // namespace mtproto_api } // namespace mtproto_api
namespace mtproto { namespace mtproto {
@ -78,6 +77,7 @@ class SessionConnection
void get_state_info(int64 message_id); void get_state_info(int64 message_id);
void resend_answer(int64 message_id); void resend_answer(int64 message_id);
void cancel_answer(int64 message_id); void cancel_answer(int64 message_id);
void destroy_key();
void set_online(bool online_flag); void set_online(bool online_flag);
@ -109,6 +109,8 @@ class SessionConnection
virtual void on_message_result_error(uint64 id, int code, BufferSlice descr) = 0; virtual void on_message_result_error(uint64 id, int code, BufferSlice descr) = 0;
virtual void on_message_failed(uint64 id, Status status) = 0; virtual void on_message_failed(uint64 id, Status status) = 0;
virtual void on_message_info(uint64 id, int32 state, uint64 answer_id, int32 answer_size) = 0; virtual void on_message_info(uint64 id, int32 state, uint64 answer_id, int32 answer_size) = 0;
virtual Status on_destroy_auth_key() = 0;
}; };
double flush(SessionConnection::Callback *callback); double flush(SessionConnection::Callback *callback);
@ -168,6 +170,9 @@ class SessionConnection
uint64 last_ping_message_id_ = 0; uint64 last_ping_message_id_ = 0;
uint64 last_ping_container_id_ = 0; uint64 last_ping_container_id_ = 0;
bool need_destroy_auth_key_{false};
bool sent_destroy_auth_key_{false};
double wakeup_at_ = 0; double wakeup_at_ = 0;
double flush_packet_at_ = 0; double flush_packet_at_ = 0;
@ -222,6 +227,12 @@ class SessionConnection
Status on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) TD_WARN_UNUSED_RESULT; Status on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) TD_WARN_UNUSED_RESULT;
Status on_packet(const MsgInfo &info, Status on_packet(const MsgInfo &info,
const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) TD_WARN_UNUSED_RESULT; const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) TD_WARN_UNUSED_RESULT;
Status on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) TD_WARN_UNUSED_RESULT;
Status on_packet(const MsgInfo &info,
const mtproto_api::destroy_auth_key_none &destroy_auth_key) TD_WARN_UNUSED_RESULT;
Status on_packet(const MsgInfo &info,
const mtproto_api::destroy_auth_key_fail &destroy_auth_key) TD_WARN_UNUSED_RESULT;
Status on_destroy_auth_key(const mtproto_api::DestroyAuthKeyRes &destroy_auth_key);
Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT; Status on_slice_packet(const MsgInfo &info, Slice packet) TD_WARN_UNUSED_RESULT;
Status on_main_packet(const PacketInfo &info, Slice packet) TD_WARN_UNUSED_RESULT; Status on_main_packet(const PacketInfo &info, Slice packet) TD_WARN_UNUSED_RESULT;

View File

@ -431,6 +431,8 @@ AuthManager::AuthManager(int32 api_id, const string &api_hash, ActorShared<> par
} }
} else if (auth_str == "logout") { } else if (auth_str == "logout") {
update_state(State::LoggingOut); update_state(State::LoggingOut);
} else if (auth_str == "destroy") {
update_state(State::DestroyingKeys);
} else { } else {
if (!load_state()) { if (!load_state()) {
update_state(State::WaitPhoneNumber); update_state(State::WaitPhoneNumber);
@ -441,6 +443,8 @@ AuthManager::AuthManager(int32 api_id, const string &api_hash, ActorShared<> par
void AuthManager::start_up() { void AuthManager::start_up() {
if (state_ == State::LoggingOut) { if (state_ == State::LoggingOut) {
start_net_query(NetQueryType::LogOut, G()->net_query_creator().create(create_storer(telegram_api::auth_logOut()))); start_net_query(NetQueryType::LogOut, G()->net_query_creator().create(create_storer(telegram_api::auth_logOut())));
} else if (state_ == State::DestroyingKeys) {
destroy_auth_keys();
} }
} }
void AuthManager::tear_down() { void AuthManager::tear_down() {
@ -475,6 +479,7 @@ tl_object_ptr<td_api::AuthorizationState> AuthManager::get_authorization_state_o
return make_tl_object<td_api::authorizationStateWaitPassword>( return make_tl_object<td_api::authorizationStateWaitPassword>(
wait_password_state_.hint_, wait_password_state_.has_recovery_, wait_password_state_.email_address_pattern_); wait_password_state_.hint_, wait_password_state_.has_recovery_, wait_password_state_.email_address_pattern_);
case State::LoggingOut: case State::LoggingOut:
case State::DestroyingKeys:
return make_tl_object<td_api::authorizationStateLoggingOut>(); return make_tl_object<td_api::authorizationStateLoggingOut>();
case State::Closing: case State::Closing:
return make_tl_object<td_api::authorizationStateClosing>(); return make_tl_object<td_api::authorizationStateClosing>();
@ -655,7 +660,7 @@ void AuthManager::logout(uint64 query_id) {
if (state_ == State::Closing) { if (state_ == State::Closing) {
return on_query_error(query_id, Status::Error(8, "Already logged out")); return on_query_error(query_id, Status::Error(8, "Already logged out"));
} }
if (state_ == State::LoggingOut) { if (state_ == State::LoggingOut || state_ == State::DestroyingKeys) {
return on_query_error(query_id, Status::Error(8, "Already logging out")); return on_query_error(query_id, Status::Error(8, "Already logging out"));
} }
on_new_query(query_id); on_new_query(query_id);
@ -663,7 +668,6 @@ void AuthManager::logout(uint64 query_id) {
update_state(State::LoggingOut); update_state(State::LoggingOut);
// TODO: could skip full logout if still no authorization // TODO: could skip full logout if still no authorization
// TODO: send auth.cancelCode if state_ == State::WaitCode // TODO: send auth.cancelCode if state_ == State::WaitCode
send_closure_later(G()->td(), &Td::destroy);
on_query_ok(); on_query_ok();
} else { } else {
LOG(INFO) << "Logging out"; LOG(INFO) << "Logging out";
@ -844,11 +848,29 @@ void AuthManager::on_log_out_result(NetQueryPtr &result) {
} }
LOG_IF(ERROR, status.is_error()) << "auth.logOut failed: " << status; LOG_IF(ERROR, status.is_error()) << "auth.logOut failed: " << status;
// state_ will stay logout, so no queries will work. // state_ will stay logout, so no queries will work.
send_closure_later(G()->td(), &Td::destroy); destroy_auth_keys();
if (query_id_ != 0) { if (query_id_ != 0) {
on_query_ok(); on_query_ok();
} }
} }
void AuthManager::on_authorization_lost() {
destroy_auth_keys();
}
void AuthManager::destroy_auth_keys() {
if (state_ == State::Closing) {
return;
}
update_state(State::DestroyingKeys);
auto promise = PromiseCreator::lambda(
[](Unit) {
G()->net_query_dispatcher().destroy_auth_keys(PromiseCreator::lambda(
[](Unit) { send_closure_later(G()->td(), &Td::destroy); }, PromiseCreator::Ignore()));
},
PromiseCreator::Ignore());
G()->td_db()->get_binlog_pmc()->set("auth", "destroy");
G()->td_db()->get_binlog_pmc()->force_sync(std::move(promise));
}
void AuthManager::on_delete_account_result(NetQueryPtr &result) { void AuthManager::on_delete_account_result(NetQueryPtr &result) {
Status status; Status status;
@ -871,8 +893,7 @@ void AuthManager::on_delete_account_result(NetQueryPtr &result) {
on_query_error(std::move(status)); on_query_error(std::move(status));
} }
} else { } else {
update_state(State::LoggingOut); destroy_auth_keys();
send_closure_later(G()->td(), &Td::destroy);
if (query_id_ != 0) { if (query_id_ != 0) {
on_query_ok(); on_query_ok();
} }

View File

@ -166,6 +166,7 @@ class AuthManager : public NetActor {
void logout(uint64 query_id); void logout(uint64 query_id);
void delete_account(uint64 query_id, const string &reason); void delete_account(uint64 query_id, const string &reason);
void on_authorization_lost();
void on_closing(); void on_closing();
// can return nullptr if state isn't initialized yet // can return nullptr if state isn't initialized yet
@ -181,6 +182,7 @@ class AuthManager : public NetActor {
WaitPassword, WaitPassword,
Ok, Ok,
LoggingOut, LoggingOut,
DestroyingKeys,
Closing Closing
} state_ = State::None; } state_ = State::None;
enum class NetQueryType : int32 { enum class NetQueryType : int32 {
@ -291,6 +293,8 @@ class AuthManager : public NetActor {
void on_query_ok(); void on_query_ok();
void start_net_query(NetQueryType net_query_type, NetQueryPtr net_query); void start_net_query(NetQueryType net_query_type, NetQueryPtr net_query);
void destroy_auth_keys();
void on_send_code_result(NetQueryPtr &result); void on_send_code_result(NetQueryPtr &result);
void on_get_password_result(NetQueryPtr &result); void on_get_password_result(NetQueryPtr &result);
void on_request_password_recovery_result(NetQueryPtr &result); void on_request_password_recovery_result(NetQueryPtr &result);

View File

@ -311,9 +311,10 @@ ActorOwn<> get_full_config(DcId dc_id, IPAddress ip_address, Promise<FullConfig>
if (G()->is_test_dc()) { if (G()->is_test_dc()) {
int_dc_id += 10000; int_dc_id += 10000;
} }
session_ = create_actor<Session>("ConfigSession", std::move(session_callback), std::move(auth_data), int_dc_id, session_ =
false /*is_main*/, true /*use_pfs*/, false /*is_cdn*/, mtproto::AuthKey(), create_actor<Session>("ConfigSession", std::move(session_callback), std::move(auth_data), int_dc_id,
std::vector<mtproto::ServerSalt>()); false /*is_main*/, true /*use_pfs*/, false /*is_cdn*/, false /*need_destroy_auth_key*/,
mtproto::AuthKey(), std::vector<mtproto::ServerSalt>());
auto query = G()->net_query_creator().create(create_storer(telegram_api::help_getConfig()), DcId::empty(), auto query = G()->net_query_creator().create(create_storer(telegram_api::help_getConfig()), DcId::empty(),
NetQuery::Type::Common, NetQuery::AuthFlag::Off, NetQuery::Type::Common, NetQuery::AuthFlag::Off,
NetQuery::GzipFlag::On, 60 * 60 * 24); NetQuery::GzipFlag::On, 60 * 60 * 24);

View File

@ -192,11 +192,34 @@ void DcAuthManager::dc_loop(DcInfo &dc) {
} }
} }
void DcAuthManager::destroy(Promise<> promise) {
destroy_promise_ = std::move(promise);
loop();
}
void DcAuthManager::destroy_loop() {
if (!destroy_promise_) {
return;
}
bool is_ready{true};
for (auto &dc : dcs_) {
is_ready &= dc.auth_state == AuthState::Empty;
}
if (is_ready) {
LOG(INFO) << "Destroy auth keys loop is ready, all keys are destroyed";
destroy_promise_.set_value(Unit());
} else {
LOG(ERROR) << "NOT READY";
}
}
void DcAuthManager::loop() { void DcAuthManager::loop() {
if (close_flag_) { if (close_flag_) {
VLOG(dc) << "Skip loop because close_flag"; VLOG(dc) << "Skip loop because close_flag";
return; return;
} }
destroy_loop();
if (!main_dc_id_.is_exact()) { if (!main_dc_id_.is_exact()) {
VLOG(dc) << "Skip loop because main_dc_id is unknown"; VLOG(dc) << "Skip loop because main_dc_id is unknown";
return; return;
@ -205,6 +228,7 @@ void DcAuthManager::loop() {
if (!main_dc || main_dc->auth_state != AuthState::OK) { if (!main_dc || main_dc->auth_state != AuthState::OK) {
if (was_auth_) { if (was_auth_) {
G()->shared_config().set_option_boolean("auth", false); G()->shared_config().set_option_boolean("auth", false);
destroy_loop();
} }
VLOG(dc) << "Skip loop because auth state of main dc " << main_dc_id_.get_raw_id() << " is " VLOG(dc) << "Skip loop because auth state of main dc " << main_dc_id_.get_raw_id() << " is "
<< (main_dc != nullptr ? (PSTRING() << main_dc->auth_state) : "unknown"); << (main_dc != nullptr ? (PSTRING() << main_dc->auth_state) : "unknown");

View File

@ -10,7 +10,6 @@
#include "td/telegram/net/AuthDataShared.h" #include "td/telegram/net/AuthDataShared.h"
#include "td/telegram/net/DcId.h" #include "td/telegram/net/DcId.h"
#include "td/telegram/net/NetQuery.h" #include "td/telegram/net/NetQuery.h"
#include "td/actor/actor.h" #include "td/actor/actor.h"
#include "td/utils/buffer.h" #include "td/utils/buffer.h"
@ -26,6 +25,7 @@ class DcAuthManager : public NetQueryCallback {
void add_dc(std::shared_ptr<AuthDataShared> auth_data); void add_dc(std::shared_ptr<AuthDataShared> auth_data);
void update_main_dc(DcId new_main_dc_id); void update_main_dc(DcId new_main_dc_id);
void destroy(Promise<> promise);
private: private:
struct DcInfo { struct DcInfo {
@ -43,9 +43,10 @@ class DcAuthManager : public NetQueryCallback {
ActorShared<> parent_; ActorShared<> parent_;
std::vector<DcInfo> dcs_; std::vector<DcInfo> dcs_;
bool was_auth_ = false; bool was_auth_{false};
DcId main_dc_id_; DcId main_dc_id_;
bool close_flag_ = false; bool close_flag_{false};
Promise<> destroy_promise_;
DcInfo &get_dc(int32 dc_id); DcInfo &get_dc(int32 dc_id);
DcInfo *find_dc(int32 dc_id); DcInfo *find_dc(int32 dc_id);
@ -55,6 +56,7 @@ class DcAuthManager : public NetQueryCallback {
void on_result(NetQueryPtr result) override; void on_result(NetQueryPtr result) override;
void dc_loop(DcInfo &dc); void dc_loop(DcInfo &dc);
void destroy_loop();
void loop() override; void loop() override;
}; };

View File

@ -127,12 +127,14 @@ Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) {
if (should_init) { if (should_init) {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_); std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
if (stop_flag_.load(std::memory_order_relaxed)) { if (stop_flag_.load(std::memory_order_relaxed) || need_destroy_auth_key_) {
return Status::Error("Closing"); return Status::Error("Closing");
} }
// init dc // init dc
dc.id_ = dc_id;
decltype(common_public_rsa_key_) public_rsa_key; decltype(common_public_rsa_key_) public_rsa_key;
bool is_cdn = false; bool is_cdn = false;
bool need_destroy_key = false;
if (dc_id.is_internal()) { if (dc_id.is_internal()) {
public_rsa_key = common_public_rsa_key_; public_rsa_key = common_public_rsa_key_;
} else { } else {
@ -150,18 +152,18 @@ Status NetQueryDispatcher::wait_dc_init(DcId dc_id, bool force) {
int32 upload_session_count = raw_dc_id != 2 && raw_dc_id != 4 ? 8 : 4; int32 upload_session_count = raw_dc_id != 2 && raw_dc_id != 4 ? 8 : 4;
int32 download_session_count = 2; int32 download_session_count = 2;
int32 download_small_session_count = 2; int32 download_small_session_count = 2;
dc.main_session_ = dc.main_session_ = create_actor<SessionMultiProxy>(PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":main",
create_actor<SessionMultiProxy>(PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":main", session_count, session_count, auth_data, raw_dc_id == main_dc_id_, use_pfs,
auth_data, raw_dc_id == main_dc_id_, use_pfs, false, false, is_cdn); false, false, is_cdn, need_destroy_key);
dc.upload_session_ = create_actor_on_scheduler<SessionMultiProxy>( dc.upload_session_ = create_actor_on_scheduler<SessionMultiProxy>(
PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":upload", slow_net_scheduler_id, upload_session_count, PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":upload", slow_net_scheduler_id, upload_session_count,
auth_data, false, use_pfs, false, true, is_cdn); auth_data, false, use_pfs, false, true, is_cdn, need_destroy_key);
dc.download_session_ = create_actor_on_scheduler<SessionMultiProxy>( dc.download_session_ = create_actor_on_scheduler<SessionMultiProxy>(
PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download", slow_net_scheduler_id, download_session_count, PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download", slow_net_scheduler_id, download_session_count,
auth_data, false, use_pfs, true, true, is_cdn); auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key);
dc.download_small_session_ = create_actor_on_scheduler<SessionMultiProxy>( dc.download_small_session_ = create_actor_on_scheduler<SessionMultiProxy>(
PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download_small", slow_net_scheduler_id, PSLICE() << "SessionMultiProxy:" << raw_dc_id << ":download_small", slow_net_scheduler_id,
download_small_session_count, auth_data, false, use_pfs, true, true, is_cdn); download_small_session_count, auth_data, false, use_pfs, true, true, is_cdn, need_destroy_key);
dc.is_inited_ = true; dc.is_inited_ = true;
if (dc_id.is_internal()) { if (dc_id.is_internal()) {
send_closure_later(dc_auth_manager_, &DcAuthManager::add_dc, std::move(auth_data)); send_closure_later(dc_auth_manager_, &DcAuthManager::add_dc, std::move(auth_data));
@ -212,6 +214,18 @@ void NetQueryDispatcher::update_session_count() {
} }
} }
} }
void NetQueryDispatcher::destroy_auth_keys(Promise<> promise) {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_);
LOG(INFO) << "Destory auth keys";
need_destroy_auth_key_ = true;
for (size_t i = 1; i < MAX_DC_COUNT; i++) {
if (is_dc_inited(narrow_cast<int32>(i)) && dcs_[i - 1].id_.is_internal()) {
send_closure_later(dcs_[i - 1].main_session_, &SessionMultiProxy::update_destroy_auth_key,
need_destroy_auth_key_);
}
}
send_closure_later(dc_auth_manager_, &DcAuthManager::destroy, std::move(promise));
}
void NetQueryDispatcher::update_use_pfs() { void NetQueryDispatcher::update_use_pfs() {
std::lock_guard<std::mutex> guard(main_dc_id_mutex_); std::lock_guard<std::mutex> guard(main_dc_id_mutex_);

View File

@ -11,6 +11,7 @@
#include "td/telegram/net/NetQuery.h" #include "td/telegram/net/NetQuery.h"
#include "td/actor/actor.h" #include "td/actor/actor.h"
#include "td/actor/PromiseFuture.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/ScopeGuard.h" #include "td/utils/ScopeGuard.h"
@ -46,6 +47,7 @@ class NetQueryDispatcher {
void stop(); void stop();
void update_session_count(); void update_session_count();
void destroy_auth_keys(Promise<> promise);
void update_use_pfs(); void update_use_pfs();
void update_mtproto_header(); void update_mtproto_header();
@ -57,9 +59,11 @@ class NetQueryDispatcher {
private: private:
std::atomic<bool> stop_flag_{false}; std::atomic<bool> stop_flag_{false};
bool need_destroy_auth_key_{false};
ActorOwn<NetQueryDelayer> delayer_; ActorOwn<NetQueryDelayer> delayer_;
ActorOwn<DcAuthManager> dc_auth_manager_; ActorOwn<DcAuthManager> dc_auth_manager_;
struct Dc { struct Dc {
DcId id_;
std::atomic<bool> is_valid_{false}; std::atomic<bool> is_valid_{false};
std::atomic<bool> is_inited_{false}; // TODO: cache in scheduler local storage :D std::atomic<bool> is_inited_{false}; // TODO: cache in scheduler local storage :D

View File

@ -108,10 +108,15 @@ class GenAuthKeyActor : public Actor {
} // namespace detail } // namespace detail
Session::Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared> shared_auth_data, int32 dc_id, Session::Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared> shared_auth_data, int32 dc_id,
bool is_main, bool use_pfs, bool is_cdn, const mtproto::AuthKey &tmp_auth_key, bool is_main, bool use_pfs, bool is_cdn, bool need_destroy, const mtproto::AuthKey &tmp_auth_key,
std::vector<mtproto::ServerSalt> server_salts) std::vector<mtproto::ServerSalt> server_salts)
: dc_id_(dc_id), is_main_(is_main), is_cdn_(is_cdn) { : dc_id_(dc_id), is_main_(is_main), is_cdn_(is_cdn) {
VLOG(dc) << "Start connection"; VLOG(dc) << "Start connection";
need_destroy_ = need_destroy;
if (need_destroy) {
use_pfs = false;
CHECK(!is_cdn);
}
shared_auth_data_ = std::move(shared_auth_data); shared_auth_data_ = std::move(shared_auth_data);
auth_data_.set_use_pfs(use_pfs); auth_data_.set_use_pfs(use_pfs);
@ -141,6 +146,10 @@ Session::Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared>
last_activity_timestamp_ = Time::now(); last_activity_timestamp_ = Time::now();
} }
bool Session::can_destroy_auth_key() {
return need_destroy_;
}
void Session::start_up() { void Session::start_up() {
class StateCallback : public StateManager::Callback { class StateCallback : public StateManager::Callback {
public: public:
@ -415,6 +424,9 @@ void Session::on_closed(Status status) {
auth_data_.drop_main_auth_key(); auth_data_.drop_main_auth_key();
on_auth_key_updated(); on_auth_key_updated();
on_session_failed(std::move(status)); on_session_failed(std::move(status));
} else if (need_destroy_) {
auth_data_.drop_main_auth_key();
on_auth_key_updated();
} }
} }
@ -774,6 +786,11 @@ void Session::on_message_info(uint64 id, int32 state, uint64 answer_id, int32 an
current_info_->connection->resend_answer(answer_id); current_info_->connection->resend_answer(answer_id);
} }
} }
Status Session::on_destroy_auth_key() {
auth_data_.drop_main_auth_key();
on_auth_key_updated();
return Status::Error("Close because of on_destroy_auth_key");
}
bool Session::has_queries() const { bool Session::has_queries() const {
return !pending_invoke_after_queries_.empty() || !pending_queries_.empty() || !sent_queries_.empty(); return !pending_invoke_after_queries_.empty() || !pending_queries_.empty() || !sent_queries_.empty();
@ -993,7 +1010,8 @@ bool Session::need_send_bind_key() {
return auth_data_.use_pfs() && !auth_data_.get_bind_flag() && auth_data_.get_tmp_auth_key().id() != tmp_auth_key_id_; return auth_data_.use_pfs() && !auth_data_.get_bind_flag() && auth_data_.get_tmp_auth_key().id() != tmp_auth_key_id_;
} }
bool Session::need_send_query() { bool Session::need_send_query() {
return !close_flag_ && (!auth_data_.use_pfs() || auth_data_.get_bind_flag()) && !pending_queries_.empty(); return !close_flag_ && (!auth_data_.use_pfs() || auth_data_.get_bind_flag()) && !pending_queries_.empty() &&
!can_destroy_auth_key();
} }
bool Session::connection_send_bind_key(ConnectionInfo *info) { bool Session::connection_send_bind_key(ConnectionInfo *info) {
CHECK(info->state != ConnectionInfo::State::Empty); CHECK(info->state != ConnectionInfo::State::Empty);
@ -1116,6 +1134,9 @@ void Session::create_gen_auth_key_actor(HandshakeId handshake_id) {
} }
void Session::auth_loop() { void Session::auth_loop() {
if (can_destroy_auth_key()) {
return;
}
if (auth_data_.need_main_auth_key()) { if (auth_data_.need_main_auth_key()) {
create_gen_auth_key_actor(MainAuthKeyHandshake); create_gen_auth_key_actor(MainAuthKeyHandshake);
} }
@ -1133,7 +1154,8 @@ void Session::loop() {
if (cached_connection_timestamp_ < Time::now_cached() - 10) { if (cached_connection_timestamp_ < Time::now_cached() - 10) {
cached_connection_.reset(); cached_connection_.reset();
} }
if (!is_main_ && !has_queries() && last_activity_timestamp_ < Time::now_cached() - ACTIVITY_TIMEOUT) { if (!is_main_ && !has_queries() && !need_destroy_ &&
last_activity_timestamp_ < Time::now_cached() - ACTIVITY_TIMEOUT) {
on_session_failed(Status::OK()); on_session_failed(Status::OK());
} }
@ -1179,6 +1201,11 @@ void Session::loop() {
connection_send_bind_key(&main_connection_); connection_send_bind_key(&main_connection_);
need_flush = true; need_flush = true;
} }
if (can_destroy_auth_key()) {
if (main_connection_.connection) {
main_connection_.connection->destroy_key();
}
}
} }
if (need_flush) { if (need_flush) {
connection_flush(&main_connection_); connection_flush(&main_connection_);

View File

@ -62,7 +62,7 @@ class Session final
}; };
Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared> shared_auth_data, int32 dc_id, bool is_main, Session(unique_ptr<Callback> callback, std::shared_ptr<AuthDataShared> shared_auth_data, int32 dc_id, bool is_main,
bool use_pfs, bool is_cdn, const mtproto::AuthKey &tmp_auth_key, bool use_pfs, bool is_cdn, bool need_destroy, const mtproto::AuthKey &tmp_auth_key,
std::vector<mtproto::ServerSalt> server_salts); std::vector<mtproto::ServerSalt> server_salts);
void send(NetQueryPtr &&query); void send(NetQueryPtr &&query);
void on_network(bool network_flag, uint32 network_generation); void on_network(bool network_flag, uint32 network_generation);
@ -101,6 +101,7 @@ class Session final
enum class Mode : int8 { Tcp, Http } mode_ = Mode::Tcp; enum class Mode : int8 { Tcp, Http } mode_ = Mode::Tcp;
bool is_main_; bool is_main_;
bool is_cdn_; bool is_cdn_;
bool need_destroy_;
bool was_on_network_ = false; bool was_on_network_ = false;
bool network_flag_ = false; bool network_flag_ = false;
uint32 network_generation_ = 0; uint32 network_generation_ = 0;
@ -193,6 +194,8 @@ class Session final
void on_message_info(uint64 id, int32 state, uint64 answer_id, int32 answer_size) override; void on_message_info(uint64 id, int32 state, uint64 answer_id, int32 answer_size) override;
Status on_destroy_auth_key() override;
void flush_pending_invoke_after_queries(); void flush_pending_invoke_after_queries();
bool has_queries() const; bool has_queries() const;
@ -221,6 +224,7 @@ class Session final
void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id = 0); void connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_query, uint64 message_id = 0);
bool need_send_bind_key(); bool need_send_bind_key();
bool need_send_query(); bool need_send_query();
bool can_destroy_auth_key();
bool connection_send_bind_key(ConnectionInfo *info); bool connection_send_bind_key(ConnectionInfo *info);
void on_result(NetQueryPtr query) override; void on_result(NetQueryPtr query) override;

View File

@ -18,14 +18,16 @@ SessionMultiProxy::SessionMultiProxy() = default;
SessionMultiProxy::~SessionMultiProxy() = default; SessionMultiProxy::~SessionMultiProxy() = default;
SessionMultiProxy::SessionMultiProxy(int32 session_count, std::shared_ptr<AuthDataShared> shared_auth_data, SessionMultiProxy::SessionMultiProxy(int32 session_count, std::shared_ptr<AuthDataShared> shared_auth_data,
bool is_main, bool use_pfs, bool allow_media_only, bool is_media, bool is_cdn) bool is_main, bool use_pfs, bool allow_media_only, bool is_media, bool is_cdn,
bool need_destroy_auth_key)
: session_count_(session_count) : session_count_(session_count)
, auth_data_(std::move(shared_auth_data)) , auth_data_(std::move(shared_auth_data))
, is_main_(is_main) , is_main_(is_main)
, use_pfs_(use_pfs) , use_pfs_(use_pfs)
, allow_media_only_(allow_media_only) , allow_media_only_(allow_media_only)
, is_media_(is_media) , is_media_(is_media)
, is_cdn_(is_cdn) { , is_cdn_(is_cdn)
, need_destroy_auth_key_(need_destroy_auth_key) {
if (allow_media_only_) { if (allow_media_only_) {
CHECK(is_media_); CHECK(is_media_);
} }
@ -52,6 +54,13 @@ void SessionMultiProxy::update_main_flag(bool is_main) {
send_closure(session, &SessionProxy::update_main_flag, is_main); send_closure(session, &SessionProxy::update_main_flag, is_main);
} }
} }
void SessionMultiProxy::update_destroy_auth_key(bool need_destroy_auth_key) {
need_destroy_auth_key_ = need_destroy_auth_key;
for (auto &session : sessions_) {
send_closure(session, &SessionProxy::update_destroy, need_destroy_auth_key_);
}
}
void SessionMultiProxy::update_session_count(int32 session_count) { void SessionMultiProxy::update_session_count(int32 session_count) {
update_options(session_count, use_pfs_); update_options(session_count, use_pfs_);
} }
@ -110,7 +119,8 @@ void SessionMultiProxy::init() {
string name = PSTRING() << "Session" << get_name().substr(Slice("SessionMulti").size()) string name = PSTRING() << "Session" << get_name().substr(Slice("SessionMulti").size())
<< format::cond(session_count_ > 1, format::concat("#", i)); << format::cond(session_count_ > 1, format::concat("#", i));
sessions_.push_back(create_actor<SessionProxy>(name, auth_data_, is_main_, allow_media_only_, is_media_, sessions_.push_back(create_actor<SessionProxy>(name, auth_data_, is_main_, allow_media_only_, is_media_,
get_pfs_flag(), is_main_ && i != 0, is_cdn_)); get_pfs_flag(), is_main_ && i != 0, is_cdn_,
need_destroy_auth_key_));
} }
} }

View File

@ -24,7 +24,7 @@ class SessionMultiProxy : public Actor {
SessionMultiProxy &operator=(const SessionMultiProxy &other) = delete; SessionMultiProxy &operator=(const SessionMultiProxy &other) = delete;
~SessionMultiProxy() override; ~SessionMultiProxy() override;
SessionMultiProxy(int32 session_count, std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool use_pfs, SessionMultiProxy(int32 session_count, std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool use_pfs,
bool allow_media_only, bool is_media, bool is_cdn); bool allow_media_only, bool is_media, bool is_cdn, bool need_destroy_auth_key);
void send(NetQueryPtr query); void send(NetQueryPtr query);
void update_main_flag(bool is_main); void update_main_flag(bool is_main);
@ -34,6 +34,8 @@ class SessionMultiProxy : public Actor {
void update_options(int32 session_count, bool use_pfs); void update_options(int32 session_count, bool use_pfs);
void update_mtproto_header(); void update_mtproto_header();
void update_destroy_auth_key(bool need_destroy_auth_key);
private: private:
size_t pos_ = 0; size_t pos_ = 0;
int32 session_count_ = 0; int32 session_count_ = 0;
@ -43,6 +45,7 @@ class SessionMultiProxy : public Actor {
bool allow_media_only_ = false; bool allow_media_only_ = false;
bool is_media_ = false; bool is_media_ = false;
bool is_cdn_ = false; bool is_cdn_ = false;
bool need_destroy_auth_key_ = false;
std::vector<ActorOwn<SessionProxy>> sessions_; std::vector<ActorOwn<SessionProxy>> sessions_;
void start_up() override; void start_up() override;

View File

@ -63,14 +63,15 @@ class SessionCallback : public Session::Callback {
}; };
SessionProxy::SessionProxy(std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool allow_media_only, SessionProxy::SessionProxy(std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool allow_media_only,
bool is_media, bool use_pfs, bool need_wait_for_key, bool is_cdn) bool is_media, bool use_pfs, bool need_wait_for_key, bool is_cdn, bool need_destroy)
: auth_data_(std::move(shared_auth_data)) : auth_data_(std::move(shared_auth_data))
, is_main_(is_main) , is_main_(is_main)
, allow_media_only_(allow_media_only) , allow_media_only_(allow_media_only)
, is_media_(is_media) , is_media_(is_media)
, use_pfs_(use_pfs) , use_pfs_(use_pfs)
, need_wait_for_key_(need_wait_for_key) , need_wait_for_key_(need_wait_for_key)
, is_cdn_(is_cdn) { , is_cdn_(is_cdn)
, need_destroy_(need_destroy) {
} }
void SessionProxy::start_up() { void SessionProxy::start_up() {
@ -91,10 +92,8 @@ void SessionProxy::start_up() {
}; };
auth_state_ = auth_data_->get_auth_state().first; auth_state_ = auth_data_->get_auth_state().first;
auth_data_->add_auth_key_listener(make_unique<Listener>(actor_shared(this))); auth_data_->add_auth_key_listener(make_unique<Listener>(actor_shared(this)));
if (is_main_ && !need_wait_for_key_) {
open_session(); open_session();
} }
}
void SessionProxy::tear_down() { void SessionProxy::tear_down() {
for (auto &query : pending_queries_) { for (auto &query : pending_queries_) {
@ -110,9 +109,7 @@ void SessionProxy::send(NetQueryPtr query) {
pending_queries_.emplace_back(std::move(query)); pending_queries_.emplace_back(std::move(query));
return; return;
} }
if (session_.empty()) {
open_session(true); open_session(true);
}
query->debug(PSTRING() << get_name() << ": sent to session"); query->debug(PSTRING() << get_name() << ": sent to session");
send_closure(session_, &Session::send, std::move(query)); send_closure(session_, &Session::send, std::move(query));
} }
@ -127,6 +124,12 @@ void SessionProxy::update_main_flag(bool is_main) {
open_session(); open_session();
} }
void SessionProxy::update_destroy(bool need_destroy) {
need_destroy_ = need_destroy;
close_session();
open_session();
}
void SessionProxy::on_failed() { void SessionProxy::on_failed() {
if (session_generation_ != get_link_token()) { if (session_generation_ != get_link_token()) {
return; return;
@ -148,9 +151,19 @@ void SessionProxy::close_session() {
session_generation_++; session_generation_++;
} }
void SessionProxy::open_session(bool force) { void SessionProxy::open_session(bool force) {
if (!force && !is_main_) { if (!session_.empty()) {
return; return;
} }
if (auth_state_ == AuthState::Empty && need_destroy_) {
return;
}
if (auth_state_ != AuthState::OK && need_wait_for_key_) {
return;
}
if (!is_main_ && pending_queries_.empty() && !need_destroy_) {
return;
}
CHECK(session_.empty()); CHECK(session_.empty());
auto dc_id = auth_data_->dc_id(); auto dc_id = auth_data_->dc_id();
string name = PSTRING() << "Session" << get_name().substr(Slice("SessionProxy").size()); string name = PSTRING() << "Session" << get_name().substr(Slice("SessionProxy").size());
@ -166,20 +179,12 @@ void SessionProxy::open_session(bool force) {
session_ = create_actor<Session>( session_ = create_actor<Session>(
name, name,
make_unique<SessionCallback>(actor_shared(this, session_generation_), dc_id, allow_media_only_, is_media_, hash), make_unique<SessionCallback>(actor_shared(this, session_generation_), dc_id, allow_media_only_, is_media_, hash),
auth_data_, int_dc_id, is_main_, use_pfs_, is_cdn_, tmp_auth_key_, server_salts_); auth_data_, int_dc_id, is_main_, use_pfs_, is_cdn_, need_destroy_, tmp_auth_key_, server_salts_);
} }
void SessionProxy::update_auth_state() { void SessionProxy::update_auth_state() {
auth_state_ = auth_data_->get_auth_state().first; auth_state_ = auth_data_->get_auth_state().first;
if (pending_queries_.empty() && !need_wait_for_key_) {
return;
}
if (auth_state_ != AuthState::OK) {
return;
}
if (session_.empty()) {
open_session(true); open_session(true);
}
for (auto &query : pending_queries_) { for (auto &query : pending_queries_) {
query->debug(PSTRING() << get_name() << ": sent to session"); query->debug(PSTRING() << get_name() << ": sent to session");
send_closure(session_, &Session::send, std::move(query)); send_closure(session_, &Session::send, std::move(query));

View File

@ -22,11 +22,12 @@ class SessionProxy : public Actor {
friend class SessionCallback; friend class SessionCallback;
SessionProxy(std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool allow_media_only, bool is_media, SessionProxy(std::shared_ptr<AuthDataShared> shared_auth_data, bool is_main, bool allow_media_only, bool is_media,
bool use_pfs, bool need_wait_for_key, bool is_cdn); bool use_pfs, bool need_wait_for_key, bool is_cdn, bool need_destroy);
void send(NetQueryPtr query); void send(NetQueryPtr query);
void update_main_flag(bool is_main); void update_main_flag(bool is_main);
void update_mtproto_header(); void update_mtproto_header();
void update_destroy(bool need_destroy);
private: private:
std::shared_ptr<AuthDataShared> auth_data_; std::shared_ptr<AuthDataShared> auth_data_;
@ -39,6 +40,7 @@ class SessionProxy : public Actor {
std::vector<mtproto::ServerSalt> server_salts_; std::vector<mtproto::ServerSalt> server_salts_;
bool need_wait_for_key_; bool need_wait_for_key_;
bool is_cdn_; bool is_cdn_;
bool need_destroy_;
ActorOwn<Session> session_; ActorOwn<Session> session_;
std::vector<NetQueryPtr> pending_queries_; std::vector<NetQueryPtr> pending_queries_;
uint64 session_generation_ = 1; uint64 session_generation_ = 1;