// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/telegram/PollManager.h" #include "td/telegram/AccessRights.h" #include "td/telegram/AuthManager.h" #include "td/telegram/ChainId.h" #include "td/telegram/ContactsManager.h" #include "td/telegram/Dependencies.h" #include "td/telegram/DialogId.h" #include "td/telegram/Global.h" #include "td/telegram/logevent/LogEvent.h" #include "td/telegram/logevent/LogEventHelper.h" #include "td/telegram/MemoryManager.h" #include "td/telegram/MessagesManager.h" #include "td/telegram/misc.h" #include "td/telegram/PollId.hpp" #include "td/telegram/PollManager.hpp" #include "td/telegram/StateManager.h" #include "td/telegram/Td.h" #include "td/telegram/TdDb.h" #include "td/telegram/TdParameters.h" #include "td/telegram/telegram_api.hpp" #include "td/telegram/UpdatesManager.h" #include "td/db/binlog/BinlogEvent.h" #include "td/db/binlog/BinlogHelper.h" #include "td/db/SqliteKeyValue.h" #include "td/db/SqliteKeyValueAsync.h" #include "td/utils/algorithm.h" #include "td/utils/buffer.h" #include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/misc.h" #include "td/utils/Random.h" #include "td/utils/Slice.h" #include "td/utils/SliceBuilder.h" #include "td/utils/Status.h" #include "td/utils/tl_helpers.h" #include #include namespace td { class GetPollResultsQuery final : public Td::ResultHandler { Promise> promise_; PollId poll_id_; DialogId dialog_id_; public: explicit GetPollResultsQuery(Promise> &&promise) : promise_(std::move(promise)) { } void send(PollId poll_id, FullMessageId full_message_id) { poll_id_ = poll_id; dialog_id_ = full_message_id.get_dialog_id(); auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read); if (input_peer == nullptr) { LOG(INFO) << "Can't reget poll, because have no read access to " << dialog_id_; return promise_.set_value(nullptr); } auto message_id = full_message_id.get_message_id().get_server_message_id().get(); send_query( G()->net_query_creator().create(telegram_api::messages_getPollResults(std::move(input_peer), message_id))); } void on_result(BufferSlice packet) final { auto result_ptr = fetch_result(packet); if (result_ptr.is_error()) { return on_error(result_ptr.move_as_error()); } promise_.set_value(result_ptr.move_as_ok()); } void on_error(Status status) final { if (!td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "GetPollResultsQuery") && status.message() != "MESSAGE_ID_INVALID") { LOG(ERROR) << "Receive " << status << ", while trying to get results of " << poll_id_; } promise_.set_error(std::move(status)); } }; class GetPollVotersQuery final : public Td::ResultHandler { Promise> promise_; PollId poll_id_; DialogId dialog_id_; public: explicit GetPollVotersQuery(Promise> &&promise) : promise_(std::move(promise)) { } void send(PollId poll_id, FullMessageId full_message_id, BufferSlice &&option, const string &offset, int32 limit) { poll_id_ = poll_id; dialog_id_ = full_message_id.get_dialog_id(); auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read); if (input_peer == nullptr) { LOG(INFO) << "Can't get poll, because have no read access to " << dialog_id_; return promise_.set_error(Status::Error(400, "Chat is not accessible")); } CHECK(!option.empty()); int32 flags = telegram_api::messages_getPollVotes::OPTION_MASK; if (!offset.empty()) { flags |= telegram_api::messages_getPollVotes::OFFSET_MASK; } auto message_id = full_message_id.get_message_id().get_server_message_id().get(); send_query(G()->net_query_creator().create(telegram_api::messages_getPollVotes( flags, std::move(input_peer), message_id, std::move(option), offset, limit))); } void on_result(BufferSlice packet) final { auto result_ptr = fetch_result(packet); if (result_ptr.is_error()) { return on_error(result_ptr.move_as_error()); } promise_.set_value(result_ptr.move_as_ok()); } void on_error(Status status) final { if (!td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "GetPollVotersQuery") && status.message() != "MESSAGE_ID_INVALID") { LOG(ERROR) << "Receive " << status << ", while trying to get voters of " << poll_id_; } promise_.set_error(std::move(status)); } }; class SendVoteQuery final : public Td::ResultHandler { Promise> promise_; DialogId dialog_id_; public: explicit SendVoteQuery(Promise> &&promise) : promise_(std::move(promise)) { } void send(FullMessageId full_message_id, vector &&options, PollId poll_id, uint64 generation, NetQueryRef *query_ref) { dialog_id_ = full_message_id.get_dialog_id(); auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Read); if (input_peer == nullptr) { LOG(INFO) << "Can't set poll answer, because have no read access to " << dialog_id_; return on_error(Status::Error(400, "Can't access the chat")); } auto message_id = full_message_id.get_message_id().get_server_message_id().get(); auto query = G()->net_query_creator().create( telegram_api::messages_sendVote(std::move(input_peer), message_id, std::move(options)), {{poll_id}, {dialog_id_}}); *query_ref = query.get_weak(); send_query(std::move(query)); } void on_result(BufferSlice packet) final { auto result_ptr = fetch_result(packet); if (result_ptr.is_error()) { return on_error(result_ptr.move_as_error()); } auto result = result_ptr.move_as_ok(); LOG(INFO) << "Receive result for SendVoteQuery: " << to_string(result); promise_.set_value(std::move(result)); } void on_error(Status status) final { td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "SendVoteQuery"); promise_.set_error(std::move(status)); } }; class StopPollQuery final : public Td::ResultHandler { Promise promise_; DialogId dialog_id_; public: explicit StopPollQuery(Promise &&promise) : promise_(std::move(promise)) { } void send(FullMessageId full_message_id, unique_ptr &&reply_markup, PollId poll_id) { dialog_id_ = full_message_id.get_dialog_id(); auto input_peer = td_->messages_manager_->get_input_peer(dialog_id_, AccessRights::Edit); if (input_peer == nullptr) { LOG(INFO) << "Can't close poll, because have no edit access to " << dialog_id_; return on_error(Status::Error(400, "Can't access the chat")); } int32 flags = telegram_api::messages_editMessage::MEDIA_MASK; auto input_reply_markup = get_input_reply_markup(td_->contacts_manager_.get(), reply_markup); if (input_reply_markup != nullptr) { flags |= telegram_api::messages_editMessage::REPLY_MARKUP_MASK; } auto message_id = full_message_id.get_message_id().get_server_message_id().get(); auto poll = telegram_api::make_object(); poll->flags_ |= telegram_api::poll::CLOSED_MASK; auto input_media = telegram_api::make_object(0, std::move(poll), vector(), string(), Auto()); send_query(G()->net_query_creator().create( telegram_api::messages_editMessage(flags, false /*ignored*/, std::move(input_peer), message_id, string(), std::move(input_media), std::move(input_reply_markup), vector>(), 0), {{poll_id}, {dialog_id_}})); } void on_result(BufferSlice packet) final { auto result_ptr = fetch_result(packet); if (result_ptr.is_error()) { return on_error(result_ptr.move_as_error()); } auto result = result_ptr.move_as_ok(); LOG(INFO) << "Receive result for StopPollQuery: " << to_string(result); td_->updates_manager_->on_get_updates(std::move(result), std::move(promise_)); } void on_error(Status status) final { if (!td_->auth_manager_->is_bot() && status.message() == "MESSAGE_NOT_MODIFIED") { return promise_.set_value(Unit()); } td_->messages_manager_->on_get_dialog_error(dialog_id_, status, "StopPollQuery"); promise_.set_error(std::move(status)); } }; PollManager::PollManager(Td *td, ActorShared<> parent) : td_(td), parent_(std::move(parent)) { update_poll_timeout_.set_callback(on_update_poll_timeout_callback); update_poll_timeout_.set_callback_data(static_cast(this)); close_poll_timeout_.set_callback(on_close_poll_timeout_callback); close_poll_timeout_.set_callback_data(static_cast(this)); unload_poll_timeout_.set_callback(on_unload_poll_timeout_callback); unload_poll_timeout_.set_callback_data(static_cast(this)); } void PollManager::start_up() { class StateCallback final : public StateManager::Callback { public: explicit StateCallback(ActorId parent) : parent_(std::move(parent)) { } bool on_online(bool is_online) final { if (is_online) { send_closure(parent_, &PollManager::on_online); } return parent_.is_alive(); } private: ActorId parent_; }; send_closure(G()->state_manager(), &StateManager::add_callback, make_unique(actor_id(this))); } void PollManager::tear_down() { parent_.reset(); } PollManager::~PollManager() { Scheduler::instance()->destroy_on_scheduler(G()->get_gc_scheduler_id(), polls_, server_poll_messages_, other_poll_messages_, poll_voters_, loaded_from_database_polls_); } void PollManager::on_update_poll_timeout_callback(void *poll_manager_ptr, int64 poll_id_int) { if (G()->close_flag()) { return; } auto poll_manager = static_cast(poll_manager_ptr); send_closure_later(poll_manager->actor_id(poll_manager), &PollManager::on_update_poll_timeout, PollId(poll_id_int)); } void PollManager::on_close_poll_timeout_callback(void *poll_manager_ptr, int64 poll_id_int) { if (G()->close_flag()) { return; } auto poll_manager = static_cast(poll_manager_ptr); send_closure_later(poll_manager->actor_id(poll_manager), &PollManager::on_close_poll_timeout, PollId(poll_id_int)); } void PollManager::on_unload_poll_timeout_callback(void *poll_manager_ptr, int64 poll_id_int) { if (G()->close_flag()) { return; } auto poll_manager = static_cast(poll_manager_ptr); send_closure_later(poll_manager->actor_id(poll_manager), &PollManager::on_unload_poll_timeout, PollId(poll_id_int)); } bool PollManager::is_local_poll_id(PollId poll_id) { return poll_id.get() < 0 && poll_id.get() > std::numeric_limits::min(); } const PollManager::Poll *PollManager::get_poll(PollId poll_id) const { auto p = polls_.find(poll_id); if (p == polls_.end()) { return nullptr; } else { return p->second.get(); } } const PollManager::Poll *PollManager::get_poll(PollId poll_id) { auto p = polls_.find(poll_id); if (p == polls_.end()) { return nullptr; } else { schedule_poll_unload(poll_id); return p->second.get(); } } PollManager::Poll *PollManager::get_poll_editable(PollId poll_id) { auto p = polls_.find(poll_id); if (p == polls_.end()) { return nullptr; } else { schedule_poll_unload(poll_id); return p->second.get(); } } bool PollManager::have_poll(PollId poll_id) const { return get_poll(poll_id) != nullptr; } void PollManager::notify_on_poll_update(PollId poll_id) { auto server_it = server_poll_messages_.find(poll_id); if (server_it != server_poll_messages_.end()) { for (const auto &full_message_id : server_it->second) { td_->messages_manager_->on_external_update_message_content(full_message_id); } } auto other_it = other_poll_messages_.find(poll_id); if (other_it != other_poll_messages_.end()) { for (const auto &full_message_id : other_it->second) { td_->messages_manager_->on_external_update_message_content(full_message_id); } } } string PollManager::get_poll_database_key(PollId poll_id) { return PSTRING() << "poll" << poll_id.get(); } void PollManager::save_poll(const Poll *poll, PollId poll_id) { CHECK(!is_local_poll_id(poll_id)); poll->was_saved = true; if (!G()->parameters().use_message_db) { return; } LOG(INFO) << "Save " << poll_id << " to database"; CHECK(poll != nullptr); G()->td_db()->get_sqlite_pmc()->set(get_poll_database_key(poll_id), log_event_store(*poll).as_slice().str(), Auto()); } void PollManager::on_load_poll_from_database(PollId poll_id, string value) { CHECK(poll_id.is_valid()); loaded_from_database_polls_.insert(poll_id); LOG(INFO) << "Successfully loaded " << poll_id << " of size " << value.size() << " from database"; // G()->td_db()->get_sqlite_pmc()->erase(get_poll_database_key(poll_id), Auto()); // return; CHECK(!have_poll(poll_id)); if (!value.empty()) { auto poll = make_unique(); auto status = log_event_parse(*poll, value); if (status.is_error()) { LOG(FATAL) << status << ": " << format::as_hex_dump<4>(Slice(value)); } for (auto &user_id : poll->recent_voter_user_ids) { td_->contacts_manager_->have_user_force(user_id); } if (!poll->is_closed && poll->close_date != 0) { if (poll->close_date <= G()->server_time()) { poll->is_closed = true; } else { CHECK(!is_local_poll_id(poll_id)); close_poll_timeout_.set_timeout_in(poll_id.get(), poll->close_date - G()->server_time() + 1e-3); } } polls_[poll_id] = std::move(poll); } } bool PollManager::have_poll_force(PollId poll_id) { return get_poll_force(poll_id) != nullptr; } PollManager::Poll *PollManager::get_poll_force(PollId poll_id) { auto poll = get_poll_editable(poll_id); if (poll != nullptr) { return poll; } if (!G()->parameters().use_message_db) { return nullptr; } if (!poll_id.is_valid() || loaded_from_database_polls_.count(poll_id)) { return nullptr; } LOG(INFO) << "Trying to load " << poll_id << " from database"; on_load_poll_from_database(poll_id, G()->td_db()->get_sqlite_sync_pmc()->get(get_poll_database_key(poll_id))); return get_poll_editable(poll_id); } td_api::object_ptr PollManager::get_poll_option_object(const PollOption &poll_option) { return td_api::make_object(poll_option.text, poll_option.voter_count, 0, poll_option.is_chosen, false); } vector PollManager::get_vote_percentage(const vector &voter_counts, int32 total_voter_count) { int32 sum = 0; for (auto voter_count : voter_counts) { CHECK(0 <= voter_count); CHECK(voter_count <= std::numeric_limits::max() - sum); sum += voter_count; } if (total_voter_count > sum) { if (sum != 0) { LOG(ERROR) << "Have total_voter_count = " << total_voter_count << ", but votes sum = " << sum << ": " << voter_counts; } total_voter_count = sum; } vector result(voter_counts.size(), 0); if (total_voter_count == 0) { return result; } if (total_voter_count != sum) { // just round to the nearest for (size_t i = 0; i < result.size(); i++) { result[i] = static_cast((static_cast(voter_counts[i]) * 200 + total_voter_count) / total_voter_count / 2); } return result; } // make sure that options with equal votes have equal percent and total sum is less than 100% int32 percent_sum = 0; vector gap(voter_counts.size(), 0); for (size_t i = 0; i < result.size(); i++) { auto multiplied_voter_count = static_cast(voter_counts[i]) * 100; result[i] = static_cast(multiplied_voter_count / total_voter_count); CHECK(0 <= result[i] && result[i] <= 100); gap[i] = static_cast(static_cast(result[i] + 1) * total_voter_count - multiplied_voter_count); CHECK(0 <= gap[i] && gap[i] <= total_voter_count); percent_sum += result[i]; } CHECK(0 <= percent_sum && percent_sum <= 100); if (percent_sum == 100) { return result; } // now we need to choose up to (100 - percent_sum) options with a minimum total gap, such that // any two options with the same voter_count are chosen or not chosen simultaneously struct Option { int32 pos = -1; int32 count = 0; }; FlatHashMap options; for (size_t i = 0; i < result.size(); i++) { auto &option = options[voter_counts[i] + 1]; if (option.pos == -1) { option.pos = narrow_cast(i); } option.count++; } vector