1
0
This repository has been archived on 2020-05-25. You can view files and clone it, but cannot push or open issues or pull requests.
tdlib-fork/td/telegram/net/AuthDataShared.cpp
Arseny Smirnov 600bbcf3ca Protect G()->td_db() usage in AuthDataShared with guard
GitOrigin-RevId: d9e8a79c9db9b061de81f98fcd66f085d4efcaaa
2018-03-13 16:40:02 +03:00

120 lines
3.7 KiB
C++

//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/telegram/net/AuthDataShared.h"
#include "td/telegram/Global.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/port/RwMutex.h"
#include "td/utils/tl_helpers.h"
#include <algorithm>
namespace td {
class AuthDataSharedImpl : public AuthDataShared {
public:
AuthDataSharedImpl(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key, std::shared_ptr<Guard> guard)
: dc_id_(dc_id), public_rsa_key_(std::move(public_rsa_key)), guard_(std::move(guard)) {
log_auth_key(get_auth_key());
}
DcId dc_id() const override {
return dc_id_;
}
const std::shared_ptr<PublicRsaKeyShared> &public_rsa_key() override {
return public_rsa_key_;
}
mtproto::AuthKey get_auth_key() override {
string dc_key = G()->td_db()->get_binlog_pmc()->get(auth_key_key());
mtproto::AuthKey res;
if (!dc_key.empty()) {
unserialize(res, dc_key).ensure();
}
return res;
}
using AuthDataShared::get_auth_state;
std::pair<AuthState, bool> get_auth_state() override {
// TODO (perf):
auto auth_key = get_auth_key();
AuthState state = get_auth_state(auth_key);
return std::make_pair(state, auth_key.was_auth_flag());
}
void set_auth_key(const mtproto::AuthKey &auth_key) override {
G()->td_db()->get_binlog_pmc()->set(auth_key_key(), serialize(auth_key));
log_auth_key(auth_key);
notify();
}
// TODO: extract it from G()
void update_server_time_difference(double diff) override {
G()->update_server_time_difference(diff);
}
double get_server_time_difference() override {
return G()->get_server_time_difference();
}
void add_auth_key_listener(unique_ptr<Listener> listener) override {
if (listener->notify()) {
auto lock = rw_mutex_.lock_write();
auth_key_listeners_.push_back(std::move(listener));
}
}
void set_future_salts(const std::vector<mtproto::ServerSalt> &future_salts) override {
G()->td_db()->get_binlog_pmc()->set(future_salts_key(), serialize(future_salts));
}
std::vector<mtproto::ServerSalt> get_future_salts() override {
string future_salts = G()->td_db()->get_binlog_pmc()->get(future_salts_key());
std::vector<mtproto::ServerSalt> res;
if (!future_salts.empty()) {
unserialize(res, future_salts).ensure();
}
return res;
}
private:
DcId dc_id_;
std::vector<unique_ptr<Listener>> auth_key_listeners_;
std::shared_ptr<PublicRsaKeyShared> public_rsa_key_;
std::shared_ptr<Guard> guard_;
RwMutex rw_mutex_;
string auth_key_key() {
return PSTRING() << "auth" << dc_id_.get_raw_id();
}
string future_salts_key() {
return PSTRING() << "salt" << dc_id_.get_raw_id();
}
void notify() {
auto lock = rw_mutex_.lock_read();
auto it = std::remove_if(auth_key_listeners_.begin(), auth_key_listeners_.end(),
[&](auto &listener) { return !listener->notify(); });
auth_key_listeners_.erase(it, auth_key_listeners_.end());
}
void log_auth_key(const mtproto::AuthKey &auth_key) {
LOG(WARNING) << dc_id_ << " " << tag("auth_key_id", auth_key.id()) << tag("state", get_auth_state(auth_key));
}
};
// Factory: builds an AuthDataSharedImpl for the given DC and hands it out through the
// AuthDataShared interface. Ownership of public_rsa_key and guard is moved into the new object.
std::shared_ptr<AuthDataShared> AuthDataShared::create(DcId dc_id, std::shared_ptr<PublicRsaKeyShared> public_rsa_key,
                                                       std::shared_ptr<Guard> guard) {
  std::shared_ptr<AuthDataShared> auth_data =
      std::make_shared<AuthDataSharedImpl>(dc_id, std::move(public_rsa_key), std::move(guard));
  return auth_data;
}
} // namespace td