Check message_thread_id parameter before using it.

This commit is contained in:
levlam 2022-11-02 14:05:23 +03:00
parent 36e41d6b7f
commit 8aac13eb4d
2 changed files with 111 additions and 17 deletions

View File

@ -3194,6 +3194,55 @@ class Client::TdOnCheckMessageCallback final : public TdQueryCallback {
OnSuccess on_success_;
};
template <class OnSuccess>
class Client::TdOnCheckMessageThreadCallback final : public TdQueryCallback {
public:
TdOnCheckMessageThreadCallback(Client *client, int64 chat_id, int64 message_thread_id, int64 reply_to_message_id,
PromisedQueryPtr query, OnSuccess on_success)
: client_(client)
, chat_id_(chat_id)
, message_thread_id_(message_thread_id)
, reply_to_message_id_(reply_to_message_id)
, query_(std::move(query))
, on_success_(std::move(on_success)) {
}
void on_result(object_ptr<td_api::Object> result) final {
if (result->get_id() == td_api::error::ID) {
auto error = move_object_as<td_api::error>(result);
if (error->code_ == 429) {
LOG(WARNING) << "Failed to get message thread " << message_thread_id_ << " in " << chat_id_;
}
return fail_query_with_error(std::move(query_), std::move(error), "Message thread not found");
}
CHECK(result->get_id() == td_api::message::ID);
auto full_message_id = client_->add_message(move_object_as<td_api::message>(result));
CHECK(full_message_id.chat_id == chat_id_);
CHECK(full_message_id.message_id == message_thread_id_);
const MessageInfo *message_info = client_->get_message(chat_id_, message_thread_id_);
CHECK(message_info != nullptr);
if (message_info->message_thread_id != message_thread_id_) {
return fail_query_with_error(std::move(query_), 400, "MESSAGE_THREAD_INVALID", "Message thread not found");
}
if (!message_info->is_topic_message) {
return fail_query_with_error(std::move(query_), 400, "MESSAGE_THREAD_INVALID",
"Message thread is not a forum topic thread");
}
on_success_(chat_id_, message_thread_id_, reply_to_message_id_, std::move(query_));
}
private:
Client *client_;
int64 chat_id_;
int64 message_thread_id_;
int64 reply_to_message_id_;
PromisedQueryPtr query_;
OnSuccess on_success_;
};
template <class OnSuccess>
class Client::TdOnCheckRemoteFileIdCallback final : public TdQueryCallback {
public:
@ -4381,6 +4430,30 @@ void Client::check_message(Slice chat_id_str, int64 message_id, bool allow_empty
});
}
template <class OnSuccess>
void Client::check_message_thread(int64 chat_id, int64 message_thread_id, int64 reply_to_message_id,
PromisedQueryPtr query, OnSuccess on_success) {
if (message_thread_id <= 0) {
return on_success(chat_id, 0, reply_to_message_id, std::move(query));
}
if (reply_to_message_id != 0) {
const MessageInfo *message_info = get_message(chat_id, reply_to_message_id);
CHECK(message_info != nullptr);
if (message_info->message_thread_id != message_thread_id) {
return fail_query_with_error(std::move(query), 400, "MESSAGE_THREAD_INVALID",
"Replied message is not in the specified message thread");
}
}
if (reply_to_message_id == message_thread_id) {
return on_success(chat_id, message_thread_id, reply_to_message_id, std::move(query));
}
send_request(make_object<td_api::getMessage>(chat_id, message_thread_id),
td::make_unique<TdOnCheckMessageThreadCallback<OnSuccess>>(
this, chat_id, message_thread_id, reply_to_message_id, std::move(query), std::move(on_success)));
}
template <class OnSuccess>
void Client::resolve_sticker_set(const td::string &sticker_set_name, PromisedQueryPtr query, OnSuccess on_success) {
if (sticker_set_name.empty()) {
@ -7542,17 +7615,24 @@ td::Status Client::process_send_media_group_query(PromisedQueryPtr &query) {
input_message_contents = std::move(input_message_contents),
reply_markup = std::move(reply_markup)](int64 chat_id, int64 reply_to_message_id,
PromisedQueryPtr query) mutable {
auto on_message_thread_checked =
[this, disable_notification, protect_content, input_message_contents = std::move(input_message_contents),
reply_markup = std::move(reply_markup)](int64 chat_id, int64 message_thread_id,
int64 reply_to_message_id, PromisedQueryPtr query) mutable {
auto it = yet_unsent_message_count_.find(chat_id);
if (it != yet_unsent_message_count_.end() && it->second > MAX_CONCURRENTLY_SENT_CHAT_MESSAGES) {
return query->set_retry_after_error(60);
}
send_request(
make_object<td_api::sendMessageAlbum>(chat_id, message_thread_id, reply_to_message_id,
send_request(make_object<td_api::sendMessageAlbum>(
chat_id, message_thread_id, reply_to_message_id,
get_message_send_options(disable_notification, protect_content),
std::move(input_message_contents), false),
td::make_unique<TdOnSendMessageAlbumCallback>(this, std::move(query)));
};
check_message_thread(chat_id, message_thread_id, reply_to_message_id, std::move(query),
std::move(on_message_thread_checked));
};
check_message(chat_id, reply_to_message_id, reply_to_message_id <= 0 || allow_sending_without_reply,
AccessRights::Write, "replied message", std::move(query), std::move(on_success));
});
@ -9112,16 +9192,24 @@ void Client::do_send_message(object_ptr<td_api::InputMessageContent> input_messa
input_message_content = std::move(input_message_content),
reply_markup = std::move(reply_markup)](int64 chat_id, int64 reply_to_message_id,
PromisedQueryPtr query) mutable {
auto on_message_thread_checked =
[this, disable_notification, protect_content, input_message_content = std::move(input_message_content),
reply_markup = std::move(reply_markup)](int64 chat_id, int64 message_thread_id,
int64 reply_to_message_id, PromisedQueryPtr query) mutable {
auto it = yet_unsent_message_count_.find(chat_id);
if (it != yet_unsent_message_count_.end() && it->second > MAX_CONCURRENTLY_SENT_CHAT_MESSAGES) {
return query->set_retry_after_error(60);
}
send_request(make_object<td_api::sendMessage>(chat_id, message_thread_id, reply_to_message_id,
send_request(
make_object<td_api::sendMessage>(chat_id, message_thread_id, reply_to_message_id,
get_message_send_options(disable_notification, protect_content),
std::move(reply_markup), std::move(input_message_content)),
td::make_unique<TdOnSendMessageCallback>(this, std::move(query)));
};
check_message_thread(chat_id, message_thread_id, reply_to_message_id, std::move(query),
std::move(on_message_thread_checked));
};
check_message(chat_id, reply_to_message_id, reply_to_message_id <= 0 || allow_sending_without_reply,
AccessRights::Write, "replied message", std::move(query), std::move(on_success));
});

View File

@ -250,6 +250,8 @@ class Client final : public WebhookActor::Callback {
template <class OnSuccess>
class TdOnCheckMessageCallback;
template <class OnSuccess>
class TdOnCheckMessageThreadCallback;
template <class OnSuccess>
class TdOnCheckRemoteFileIdCallback;
template <class OnSuccess>
class TdOnGetChatMemberCallback;
@ -288,6 +290,10 @@ class Client final : public WebhookActor::Callback {
void check_message(Slice chat_id_str, int64 message_id, bool allow_empty, AccessRights access_rights,
Slice message_type, PromisedQueryPtr query, OnSuccess on_success);
template <class OnSuccess>
void check_message_thread(int64 chat_id, int64 message_thread_id, int64 reply_to_message_id, PromisedQueryPtr query,
OnSuccess on_success);
template <class OnSuccess>
void resolve_sticker_set(const td::string &sticker_set_name, PromisedQueryPtr query, OnSuccess on_success);