Optimize SessionConnection::on_slice_packet using TlDowncastHelper.

This commit is contained in:
levlam 2021-09-12 19:46:12 +03:00
parent 47d3806c62
commit 45ebe775c5
2 changed files with 63 additions and 64 deletions

View File

@ -25,6 +25,7 @@
#include "td/utils/SliceBuilder.h"
#include "td/utils/Time.h"
#include "td/utils/tl_parsers.h"
#include "td/utils/TlDowncastHelper.h"
#include "td/mtproto/mtproto_api.h"
#include "td/mtproto/mtproto_api.hpp"
@ -171,22 +172,6 @@ namespace mtproto {
*
*/
class OnPacket {
const MsgInfo &info_;
SessionConnection *connection_;
Status *status_;
public:
OnPacket(const MsgInfo &info, SessionConnection *connection, Status *status)
: info_(info), connection_(connection), status_(status) {
}
template <class T>
void operator()(const T &func) const {
*status_ = connection_->on_packet(info_, func);
}
};
unique_ptr<RawConnection> SessionConnection::move_as_raw_connection() {
return std::move(raw_connection_);
}
@ -230,7 +215,6 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet)
};
TlParser parser(packet);
parser.fetch_int();
int32 size = parser.fetch_int();
if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error());
@ -243,7 +227,6 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet)
Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) {
TlParser parser(packet);
parser.fetch_int();
uint64 req_msg_id = parser.fetch_long();
if (parser.get_error()) {
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error());
@ -275,7 +258,7 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet
return callback_->on_message_result_ok(req_msg_id, std::move(object), info.size);
}
default:
packet.remove_prefix(4 + sizeof(req_msg_id));
packet.remove_prefix(sizeof(req_msg_id));
return callback_->on_message_result_ok(req_msg_id, as_buffer_slice(packet), info.size);
}
}
@ -285,12 +268,15 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) {
LOG(ERROR) << "Unsupported: " << to_string(packet);
return Status::OK();
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) {
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) {
return on_destroy_auth_key(destroy_auth_key);
}
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) {
return on_destroy_auth_key(destroy_auth_key);
}
@ -479,52 +465,66 @@ Status SessionConnection::on_slice_packet(const MsgInfo &info, Slice packet) {
if (info.seq_no & 1) {
send_ack(info.message_id);
}
TlParser parser(packet);
tl_object_ptr<mtproto_api::Object> object = mtproto_api::Object::fetch(parser);
parser.fetch_end();
if (parser.get_error()) {
// msg_container is not real tl object
if (packet.size() >= 4 && as<int32>(packet.begin()) == mtproto_api::msg_container::ID) {
return on_packet_container(info, packet);
}
if (packet.size() >= 4 && as<int32>(packet.begin()) == mtproto_api::rpc_result::ID) {
return on_packet_rpc_result(info, packet);
}
// It is an update... I hope.
auto status = auth_data_->check_update(info.message_id);
auto recheck_status = auth_data_->recheck_update(info.message_id);
if (recheck_status.is_error() && recheck_status.code() == 2) {
LOG(WARNING) << "Receive very old update from " << get_name() << " created in " << (Time::now() - created_at_)
<< " in container " << container_id_ << " from session " << auth_data_->get_session_id()
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
<< ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status << ' '
<< recheck_status;
}
if (status.is_error()) {
if (status.code() == 2) {
LOG(WARNING) << "Receive too old update from " << get_name() << " created in " << (Time::now() - created_at_)
<< " in container " << container_id_ << " from session " << auth_data_->get_session_id()
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
<< ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status;
callback_->on_session_failed(Status::Error("Receive too old update"));
return status;
}
VLOG(mtproto) << "Skip update " << info.message_id << " of size " << info.size << " with seq_no " << info.seq_no
<< " from " << get_name() << " created in " << (Time::now() - created_at_) << ": " << status;
return Status::OK();
} else {
VLOG(mtproto) << "Got update from " << get_name() << " created in " << (Time::now() - created_at_)
<< " in container " << container_id_ << " from session " << auth_data_->get_session_id()
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
<< ", seq_no = " << info.seq_no << " and original size " << info.size;
return callback_->on_update(as_buffer_slice(packet));
}
if (packet.size() < 4) {
callback_->on_session_failed(Status::Error("Receive too small packet"));
return Status::Error(PSLICE() << "Receive packet of size " << packet.size());
}
int32 constructor_id = as<int32>(packet.begin());
if (constructor_id == mtproto_api::msg_container::ID) {
return on_packet_container(info, packet.substr(4));
}
if (constructor_id == mtproto_api::rpc_result::ID) {
return on_packet_rpc_result(info, packet.substr(4));
}
TlDowncastHelper<mtproto_api::Object> helper(constructor_id);
Status status;
downcast_call(*object, OnPacket(info, this, &status));
return status;
bool is_mtproto_api = downcast_call(static_cast<mtproto_api::Object &>(helper), [&](auto &dummy) {
// a constructor from mtproto_api
using Type = std::decay_t<decltype(dummy)>;
TlParser parser(packet.substr(4));
auto object = Type::fetch(parser);
parser.fetch_end();
if (parser.get_error()) {
status = parser.get_status();
} else {
status = this->on_packet(info, static_cast<const Type &>(*object));
}
});
if (is_mtproto_api) {
return status;
}
// It is an update... I hope.
status = auth_data_->check_update(info.message_id);
auto recheck_status = auth_data_->recheck_update(info.message_id);
if (recheck_status.is_error() && recheck_status.code() == 2) {
LOG(WARNING) << "Receive very old update from " << get_name() << " created in " << (Time::now() - created_at_)
<< " in container " << container_id_ << " from session " << auth_data_->get_session_id()
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
<< ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status << ' '
<< recheck_status;
}
if (status.is_error()) {
if (status.code() == 2) {
LOG(WARNING) << "Receive too old update from " << get_name() << " created in " << (Time::now() - created_at_)
<< " in container " << container_id_ << " from session " << auth_data_->get_session_id()
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
<< ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status;
callback_->on_session_failed(Status::Error("Receive too old update"));
return status;
}
VLOG(mtproto) << "Skip update " << info.message_id << " of size " << info.size << " with seq_no " << info.seq_no
<< " from " << get_name() << " created in " << (Time::now() - created_at_) << ": " << status;
return Status::OK();
} else {
VLOG(mtproto) << "Got update from " << get_name() << " created in " << (Time::now() - created_at_)
<< " in container " << container_id_ << " from session " << auth_data_->get_session_id()
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
<< ", seq_no = " << info.seq_no << " and original size " << info.size;
return callback_->on_update(as_buffer_slice(packet));
}
}
Status SessionConnection::parse_packet(TlParser &parser) {
@ -579,6 +579,7 @@ void SessionConnection::on_message_failed(uint64 id, Status status) {
on_message_failed_inner(id);
}
}
void SessionConnection::on_message_failed_inner(uint64 id) {
auto it = service_queries_.find(id);
if (it == service_queries_.end()) {

View File

@ -203,8 +203,6 @@ class SessionConnection final
SessionConnection::Callback *callback_ = nullptr;
BufferSlice *current_buffer_slice_;
friend class OnPacket;
BufferSlice as_buffer_slice(Slice packet);
auto set_buffer_slice(BufferSlice *buffer_slice) TD_WARN_UNUSED_RESULT {
auto old_buffer_slice = current_buffer_slice_;