diff --git a/td/mtproto/SessionConnection.cpp b/td/mtproto/SessionConnection.cpp index ad044ac05..b9a40ea40 100644 --- a/td/mtproto/SessionConnection.cpp +++ b/td/mtproto/SessionConnection.cpp @@ -25,6 +25,7 @@ #include "td/utils/SliceBuilder.h" #include "td/utils/Time.h" #include "td/utils/tl_parsers.h" +#include "td/utils/TlDowncastHelper.h" #include "td/mtproto/mtproto_api.h" #include "td/mtproto/mtproto_api.hpp" @@ -171,22 +172,6 @@ namespace mtproto { * */ -class OnPacket { - const MsgInfo &info_; - SessionConnection *connection_; - Status *status_; - - public: - OnPacket(const MsgInfo &info, SessionConnection *connection, Status *status) - : info_(info), connection_(connection), status_(status) { - } - - template - void operator()(const T &func) const { - *status_ = connection_->on_packet(info_, func); - } -}; - unique_ptr SessionConnection::move_as_raw_connection() { return std::move(raw_connection_); } @@ -230,7 +215,6 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet) }; TlParser parser(packet); - parser.fetch_int(); int32 size = parser.fetch_int(); if (parser.get_error()) { return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error()); @@ -243,7 +227,6 @@ Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet) Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) { TlParser parser(packet); - parser.fetch_int(); uint64 req_msg_id = parser.fetch_long(); if (parser.get_error()) { return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error()); @@ -275,7 +258,7 @@ Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet return callback_->on_message_result_ok(req_msg_id, std::move(object), info.size); } default: - packet.remove_prefix(4 + sizeof(req_msg_id)); + packet.remove_prefix(sizeof(req_msg_id)); return callback_->on_message_result_ok(req_msg_id, as_buffer_slice(packet), info.size); } } @@ -285,12 +268,15 @@ Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) { LOG(ERROR) << "Unsupported: " << to_string(packet); return Status::OK(); } + Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_ok &destroy_auth_key) { return on_destroy_auth_key(destroy_auth_key); } + Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_none &destroy_auth_key) { return on_destroy_auth_key(destroy_auth_key); } + Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::destroy_auth_key_fail &destroy_auth_key) { return on_destroy_auth_key(destroy_auth_key); } @@ -479,52 +465,66 @@ Status SessionConnection::on_slice_packet(const MsgInfo &info, Slice packet) { if (info.seq_no & 1) { send_ack(info.message_id); } - TlParser parser(packet); - tl_object_ptr object = mtproto_api::Object::fetch(parser); - parser.fetch_end(); - if (parser.get_error()) { - // msg_container is not real tl object - if (packet.size() >= 4 && as(packet.begin()) == mtproto_api::msg_container::ID) { - return on_packet_container(info, packet); - } - if (packet.size() >= 4 && as(packet.begin()) == mtproto_api::rpc_result::ID) { - return on_packet_rpc_result(info, packet); - } - - // It is an update... I hope. - auto status = auth_data_->check_update(info.message_id); - auto recheck_status = auth_data_->recheck_update(info.message_id); - if (recheck_status.is_error() && recheck_status.code() == 2) { - LOG(WARNING) << "Receive very old update from " << get_name() << " created in " << (Time::now() - created_at_) - << " in container " << container_id_ << " from session " << auth_data_->get_session_id() - << " with message_id " << info.message_id << ", main_message_id = " << main_message_id_ - << ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status << ' ' - << recheck_status; - } - if (status.is_error()) { - if (status.code() == 2) { - LOG(WARNING) << "Receive too old update from " << get_name() << " created in " << (Time::now() - created_at_) - << " in container " << container_id_ << " from session " << auth_data_->get_session_id() - << " with message_id " << info.message_id << ", main_message_id = " << main_message_id_ - << ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status; - callback_->on_session_failed(Status::Error("Receive too old update")); - return status; - } - VLOG(mtproto) << "Skip update " << info.message_id << " of size " << info.size << " with seq_no " << info.seq_no - << " from " << get_name() << " created in " << (Time::now() - created_at_) << ": " << status; - return Status::OK(); - } else { - VLOG(mtproto) << "Got update from " << get_name() << " created in " << (Time::now() - created_at_) - << " in container " << container_id_ << " from session " << auth_data_->get_session_id() - << " with message_id " << info.message_id << ", main_message_id = " << main_message_id_ - << ", seq_no = " << info.seq_no << " and original size " << info.size; - return callback_->on_update(as_buffer_slice(packet)); - } + if (packet.size() < 4) { + callback_->on_session_failed(Status::Error("Receive too small packet")); + return Status::Error(PSLICE() << "Receive packet of size " << packet.size()); } + int32 constructor_id = as(packet.begin()); + if (constructor_id == mtproto_api::msg_container::ID) { + return on_packet_container(info, packet.substr(4)); + } + if (constructor_id == mtproto_api::rpc_result::ID) { + return on_packet_rpc_result(info, packet.substr(4)); + } + + TlDowncastHelper helper(constructor_id); Status status; - downcast_call(*object, OnPacket(info, this, &status)); - return status; + bool is_mtproto_api = downcast_call(static_cast(helper), [&](auto &dummy) { + // a constructor from mtproto_api + using Type = std::decay_t; + TlParser parser(packet.substr(4)); + auto object = Type::fetch(parser); + parser.fetch_end(); + if (parser.get_error()) { + status = parser.get_status(); + } else { + status = this->on_packet(info, static_cast(*object)); + } + }); + if (is_mtproto_api) { + return status; + } + + // It is an update... I hope. + status = auth_data_->check_update(info.message_id); + auto recheck_status = auth_data_->recheck_update(info.message_id); + if (recheck_status.is_error() && recheck_status.code() == 2) { + LOG(WARNING) << "Receive very old update from " << get_name() << " created in " << (Time::now() - created_at_) + << " in container " << container_id_ << " from session " << auth_data_->get_session_id() + << " with message_id " << info.message_id << ", main_message_id = " << main_message_id_ + << ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status << ' ' + << recheck_status; + } + if (status.is_error()) { + if (status.code() == 2) { + LOG(WARNING) << "Receive too old update from " << get_name() << " created in " << (Time::now() - created_at_) + << " in container " << container_id_ << " from session " << auth_data_->get_session_id() + << " with message_id " << info.message_id << ", main_message_id = " << main_message_id_ + << ", seq_no = " << info.seq_no << " and original size " << info.size << ": " << status; + callback_->on_session_failed(Status::Error("Receive too old update")); + return status; + } + VLOG(mtproto) << "Skip update " << info.message_id << " of size " << info.size << " with seq_no " << info.seq_no + << " from " << get_name() << " created in " << (Time::now() - created_at_) << ": " << status; + return Status::OK(); + } else { + VLOG(mtproto) << "Got update from " << get_name() << " created in " << (Time::now() - created_at_) + << " in container " << container_id_ << " from session " << auth_data_->get_session_id() + << " with message_id " << info.message_id << ", main_message_id = " << main_message_id_ + << ", seq_no = " << info.seq_no << " and original size " << info.size; + return callback_->on_update(as_buffer_slice(packet)); + } } Status SessionConnection::parse_packet(TlParser &parser) { @@ -579,6 +579,7 @@ void SessionConnection::on_message_failed(uint64 id, Status status) { on_message_failed_inner(id); } } + void SessionConnection::on_message_failed_inner(uint64 id) { auto it = service_queries_.find(id); if (it == service_queries_.end()) { diff --git a/td/mtproto/SessionConnection.h b/td/mtproto/SessionConnection.h index 9ef287ca9..3da23a4fc 100644 --- a/td/mtproto/SessionConnection.h +++ b/td/mtproto/SessionConnection.h @@ -203,8 +203,6 @@ class SessionConnection final SessionConnection::Callback *callback_ = nullptr; BufferSlice *current_buffer_slice_; - friend class OnPacket; - BufferSlice as_buffer_slice(Slice packet); auto set_buffer_slice(BufferSlice *buffer_slice) TD_WARN_UNUSED_RESULT { auto old_buffer_slice = current_buffer_slice_;