a6ab35e616
GitOrigin-RevId: 56fb21194e995cbb37b3ebb5caed689c10320b8d
970 lines
33 KiB
C++
970 lines
33 KiB
C++
//
|
|
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
|
|
//
|
|
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
|
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
|
//
|
|
#include "td/mtproto/SessionConnection.h"
|
|
|
|
#include "td/mtproto/AuthData.h"
|
|
#include "td/mtproto/AuthKey.h"
|
|
#include "td/mtproto/CryptoStorer.h"
|
|
#include "td/mtproto/Handshake.h"
|
|
#include "td/mtproto/NoCryptoStorer.h"
|
|
|
|
#include "td/mtproto/HttpTransport.h"
|
|
#include "td/mtproto/TcpTransport.h"
|
|
#include "td/mtproto/Transport.h"
|
|
|
|
#include "td/utils/format.h"
|
|
#include "td/utils/Gzip.h"
|
|
#include "td/utils/logging.h"
|
|
#include "td/utils/misc.h"
|
|
#include "td/utils/Random.h"
|
|
#include "td/utils/ScopeGuard.h"
|
|
#include "td/utils/Time.h"
|
|
#include "td/utils/tl_parsers.h"
|
|
|
|
#include "td/mtproto/mtproto_api.h"
|
|
#include "td/mtproto/mtproto_api.hpp"
|
|
|
|
#include <algorithm>
|
|
#include <iterator>
|
|
#include <type_traits>
|
|
|
|
namespace td {
|
|
namespace mtproto_api {
|
|
|
|
const int32 msg_container::ID;
|
|
|
|
class rpc_result {
|
|
public:
|
|
static const int32 ID = -212046591;
|
|
};
|
|
|
|
} // namespace mtproto_api
|
|
|
|
namespace mtproto {
|
|
/**
|
|
* TODO-list.
|
|
*
|
|
* 1. Should I check input salt?
|
|
*
|
|
* 1. Cancellation of rpc request
|
|
* input:
|
|
* - rpc_drop_answer#58e4a740 req_msg_id:long = RpcDropAnswer;
|
|
* output:
|
|
* - rpc_answer_unknown#5e2ad36e = RpcDropAnswer;
|
|
* no ack
|
|
* - rpc_answer_dropped_running#cd78e586 = RpcDropAnswer;
|
|
* same answer will be returned to original query
|
|
* ack
|
|
* - rpc_answer_dropped#a43ad8b7 msg_id:long seq_no:int bytes:int = RpcDropAnswer;
|
|
* ack
|
|
* Alternative is destroy session
|
|
*
|
|
* 5. Destroy session
|
|
* (?) when session will be destroyed otherwise
|
|
* Must be call in different session
|
|
* input:
|
|
* - destroy_session#e7512126 session_id:long = DestroySessionRes;
|
|
* output:
|
|
* - destroy_session_ok#e22045fc session_id:long = DestroySessionRes;
|
|
* - destroy_session_none#62d350c9 session_id:long = DestroySessionRes;
|
|
*
|
|
|
|
* DONE:
|
|
* 3. Ping pong
|
|
* input:
|
|
* - ping#7abe77ec ping_id:long = Pong;
|
|
* - pong#347773c5 msg_id:long ping_id:long = Pong;
|
|
*
|
|
* 4. Ping + deferred connection closure
|
|
* input:
|
|
* - ping_delay_disconnect#f3427b8c ping_id:long disconnect_delay:int = Pong;
|
|
*
|
|
* 6. New session creation
|
|
* Notification about new session.
|
|
* It is reasonable to store unique_id with current session, in order to process duplicated notifications once.
|
|
*
|
|
* Causes all older than first_msg_id to be re-sent.
|
|
* Also there is a gap in updates, so getDifference MUST be sent
|
|
* output:
|
|
* - new_session_created#9ec20908 first_msg_id:long unique_id:long server_salt:long = NewSession
|
|
*
|
|
*
|
|
* 7. Containers
|
|
* I should pack output messages as containers
|
|
* - msg_container#73f1f8dc messages:vector message = MessageContainer;
|
|
* message msg_id:long seqno:int bytes:int body:Object = Message;
|
|
*
|
|
* 8. Packed Object
|
|
* I should pack big output messages with gzip
|
|
* - gzip_packed#3072cfa1 packed_data:string = Object;
|
|
*
|
|
* 9. Ack
|
|
* I should actually send acks
|
|
* (?) Does updates need ack
|
|
* - msgs_ack#62d6b459 msg_ids:Vector long = MsgsAck;
|
|
*
|
|
* 10. Errors
|
|
* output:
|
|
* - bad_msg_notification#a7eff811 bad_msg_id:long bad_msg_seqno:int error_code:int = BadMsgNotification;
|
|
* - bad_server_salt#edab447b bad_msg_id:long bad_msg_seqno:int error_code:int new_server_salt:long =
|
|
* BadMsgNotification;
|
|
*
|
|
* error codes:
|
|
* 16: msg_id is too low. -- lite resend. It will be automatially packed in a container. I hope.
|
|
* 17: msg_id is too high. -- fail connection.
|
|
* 18: msg_id % 4 != 0. -- Error and fail connection.
|
|
* 19: container msg_id is the same as msg_id of a previously received message. MUST NEVER HAPPENS. Error and fail
|
|
* connection.
|
|
*
|
|
* 20: message is to old -- full resend. (or fail query, if we are afraid of double send)
|
|
*
|
|
* 32: seq_no is too low. (msg_id1 < msg_id2 <==> seq_no1 < seq_no2). Error and fail connection
|
|
* 33: seq_no is too hight. Error and fail connection.
|
|
* 34: (?) an even msg_seqno expected (irrelevant message), but odd received. (Fail and call a developer...)
|
|
* 35: (?) odd msg_seqno expected (relevant message), but even received (Fail and call a developer)
|
|
*
|
|
* 48: incorrect server salt (in bad_server_salt message)
|
|
*
|
|
* 64: (?) invalid container
|
|
*
|
|
* 2. Get future salts
|
|
* input:
|
|
* - get_future_salts#b921bd04 num:int = FutureSalts;
|
|
* output:
|
|
* - future_salts#ae500895 req_msg_id:long now:int salts:vector future_salt = FutureSalts;
|
|
* future_salt#0949d9dc valid_since:int valid_until:int salt:long = FutureSalt;
|
|
*
|
|
* 1. ping-pong
|
|
* 3. Delayed ack.
|
|
* 4. Delayed packet.
|
|
* HTTP transport (support of several connections)
|
|
* get future salts
|
|
*
|
|
* 11. Explicit request of messages states
|
|
* input:
|
|
* - msgs_state_req#da69fb52 msg_ids:Vector long = MsgsStateReq;
|
|
* 12. States of messages in answer
|
|
* output:
|
|
* - msgs_state_info#04deb57d req_msg_id:long info:string = MsgsStateInfo;
|
|
* info contain one byte for each queired msg_id
|
|
* bytes:
|
|
* 1: unknown message_id
|
|
* 2: message not received (msg_id in stored range)
|
|
* 3: message not receiver (msg_id is too high)
|
|
* 4: message received. No extra ack will be sent
|
|
* +8: message already acknowledged
|
|
* +16: message not requiring acknolegement
|
|
* +32: RPC query contained in message being processed or processing already complete
|
|
* +64: content-related response to message already generated
|
|
* +128: other party knows for a fact that message is already received
|
|
*
|
|
* 13. Voluntary Communication of Status of Messages
|
|
* output:
|
|
* - msgs_all_info#8cc0d131 msg_ids:Vector long info:string = MsgsAllInfo
|
|
*
|
|
*/
|
|
|
|
class OnPacket {
|
|
const MsgInfo &info_;
|
|
SessionConnection *connection_;
|
|
Status *status_;
|
|
|
|
public:
|
|
OnPacket(const MsgInfo &info, SessionConnection *connection, Status *status)
|
|
: info_(info), connection_(connection), status_(status) {
|
|
}
|
|
|
|
template <class T>
|
|
void operator()(const T &func) const {
|
|
*status_ = connection_->on_packet(info_, func);
|
|
}
|
|
};
|
|
|
|
/*** SessionConnection ***/
|
|
BufferSlice SessionConnection::as_buffer_slice(Slice packet) {
|
|
return current_buffer_slice->from_slice(packet);
|
|
}
|
|
|
|
Status SessionConnection::parse_message(TlParser &parser, MsgInfo *info, Slice *packet, bool crypto_flag) {
|
|
// msg_id:long seqno:int bytes:int
|
|
parser.check_len(sizeof(int64) + (crypto_flag ? sizeof(int32) : 0) + sizeof(int32));
|
|
if (parser.get_error() != nullptr) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: " << parser.get_error());
|
|
}
|
|
info->message_id = parser.fetch_long_unsafe();
|
|
if (crypto_flag) {
|
|
info->seq_no = parser.fetch_int_unsafe();
|
|
}
|
|
uint32 bytes = parser.fetch_int_unsafe();
|
|
|
|
if (bytes % sizeof(int32) != 0) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: size of message [" << bytes
|
|
<< "] is not divisible by 4");
|
|
}
|
|
|
|
*packet = parser.fetch_string_raw<Slice>(bytes);
|
|
if (parser.get_error() != nullptr) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::message: " << parser.get_error());
|
|
}
|
|
|
|
info->size = bytes;
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet_container(const MsgInfo &info, Slice packet) {
|
|
auto old_container_id = container_id_;
|
|
container_id_ = info.message_id;
|
|
SCOPE_EXIT {
|
|
container_id_ = old_container_id;
|
|
};
|
|
|
|
TlParser parser(packet);
|
|
parser.fetch_int();
|
|
int32 size = parser.fetch_int();
|
|
if (parser.get_error()) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_container: " << parser.get_error());
|
|
}
|
|
for (int i = 0; i < size; i++) {
|
|
TRY_STATUS(parse_packet(parser));
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet_rpc_result(const MsgInfo &info, Slice packet) {
|
|
TlParser parser(packet);
|
|
parser.fetch_int();
|
|
uint64 req_msg_id = parser.fetch_long();
|
|
if (parser.get_error()) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_result: " << parser.get_error());
|
|
}
|
|
|
|
auto object_begin_pos = packet.size() - parser.get_left_len();
|
|
int32 id = parser.fetch_int();
|
|
if (id == mtproto_api::rpc_error::ID) {
|
|
mtproto_api::rpc_error rpc_error(parser);
|
|
if (parser.get_error()) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::rpc_error: " << parser.get_error());
|
|
}
|
|
return on_packet(info, req_msg_id, rpc_error);
|
|
} else if (id == mtproto_api::gzip_packed::ID) {
|
|
mtproto_api::gzip_packed gzip(parser);
|
|
if (parser.get_error()) {
|
|
return Status::Error(PSLICE() << "Failed to parse mtproto_api::gzip_packed: " << parser.get_error());
|
|
}
|
|
// yep, gzip in rpc_result
|
|
BufferSlice object = gzdecode(gzip.packed_data_);
|
|
// send header no more optimization
|
|
return callback_->on_message_result_ok(req_msg_id, std::move(object), info.size);
|
|
}
|
|
|
|
return callback_->on_message_result_ok(req_msg_id, as_buffer_slice(packet.substr(object_begin_pos)), info.size);
|
|
}
|
|
|
|
template <class T>
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const T &packet) {
|
|
LOG(ERROR) << "Unsupported: " << to_string(packet);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::rpc_error &rpc_error) {
|
|
return on_packet(info, 0, rpc_error);
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, uint64 req_msg_id, const mtproto_api::rpc_error &rpc_error) {
|
|
VLOG(mtproto) << "ERROR [code:" << rpc_error.error_code_ << "] [msg:" << rpc_error.error_message_.str().c_str()
|
|
<< "]";
|
|
if (req_msg_id != 0) {
|
|
callback_->on_message_result_error(req_msg_id, rpc_error.error_code_, as_buffer_slice(rpc_error.error_message_));
|
|
} else {
|
|
LOG(WARNING) << "rpc_error as update: [" << rpc_error.error_code_ << "][" << rpc_error.error_message_ << "]";
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::new_session_created &new_session_created) {
|
|
VLOG(mtproto) << "NEW_SESSION_CREATED: [first_msg_id:" << format::as_hex(new_session_created.first_msg_id_)
|
|
<< "] [unique_id:" << format::as_hex(new_session_created.unique_id_)
|
|
<< "] [server_salt:" << format::as_hex(new_session_created.server_salt_) << "]";
|
|
callback_->on_session_created(new_session_created.unique_id_, new_session_created.first_msg_id_);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info,
|
|
const mtproto_api::bad_msg_notification &bad_msg_notification) {
|
|
MsgInfo bad_info{info.session_id, bad_msg_notification.bad_msg_id_, bad_msg_notification.bad_msg_seqno_, 0};
|
|
enum Code {
|
|
MsgIdTooLow = 16,
|
|
MsgIdTooHigh = 17,
|
|
MsgIdMod4 = 18,
|
|
MsgIdCollision = 19,
|
|
|
|
MsgIdTooOld = 20,
|
|
|
|
SeqNoTooLow = 32,
|
|
SeqNoTooHigh = 33,
|
|
SeqNoNotEven = 34,
|
|
SeqNoNotOdd = 35,
|
|
|
|
InvalidContainer = 64
|
|
};
|
|
Slice common = " BUG! CALL FOR A DEVELOPER! Session will be closed";
|
|
switch (bad_msg_notification.error_code_) {
|
|
case MsgIdTooLow: {
|
|
LOG(WARNING) << bad_info << ": MessageId is too low. Message will be re-sent";
|
|
// time will be updated automagically
|
|
on_message_failed(bad_info.message_id, Status::Error("MessageId is too low"));
|
|
break;
|
|
}
|
|
case MsgIdTooHigh: {
|
|
LOG(ERROR) << bad_info << ": MessageId is too high. Session will be closed";
|
|
// All this queries will be re-sent by parent
|
|
to_send_.clear();
|
|
callback_->on_session_failed(Status::Error("MessageId is too high"));
|
|
return Status::Error("MessageId is too high");
|
|
}
|
|
case MsgIdMod4: {
|
|
LOG(ERROR) << bad_info << ": MessageId is not divisible by 4." << common;
|
|
return Status::Error("MessageId is not divisible by 4");
|
|
}
|
|
case MsgIdCollision: {
|
|
LOG(ERROR) << bad_info << ": Container and older message MessageId collision." << common;
|
|
return Status::Error("Container and older message MessageId collision");
|
|
}
|
|
|
|
case MsgIdTooOld: {
|
|
LOG(WARNING) << bad_info << ": MessageId is too old. Message will be re-sent";
|
|
on_message_failed(bad_info.message_id, Status::Error("MessageId is too old"));
|
|
break;
|
|
}
|
|
|
|
case SeqNoTooLow: {
|
|
LOG(ERROR) << bad_info << ": SeqNo is too low." << common;
|
|
return Status::Error("SeqNo is too low");
|
|
}
|
|
case SeqNoTooHigh: {
|
|
LOG(ERROR) << bad_info << ": SeqNo is too high." << common;
|
|
return Status::Error("SeqNo is too high");
|
|
}
|
|
case SeqNoNotEven: {
|
|
LOG(ERROR) << bad_info << ": SeqNo is not even for an irrelevant message." << common;
|
|
return Status::Error("SeqNo is not even for an irrelevant message");
|
|
}
|
|
case SeqNoNotOdd: {
|
|
LOG(ERROR) << bad_info << ": SeqNo is not odd for an irrelevant message." << common;
|
|
return Status::Error("SeqNo is not odd for an irrelevant message");
|
|
}
|
|
|
|
case InvalidContainer: {
|
|
LOG(ERROR) << bad_info << ": Invalid Contailer." << common;
|
|
return Status::Error("Invalid Contailer");
|
|
}
|
|
|
|
default: {
|
|
LOG(ERROR) << bad_info << ": Unknown error [code:" << bad_msg_notification.error_code_ << "]." << common;
|
|
return Status::Error("Unknown error code");
|
|
}
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::bad_server_salt &bad_server_salt) {
|
|
MsgInfo bad_info{info.session_id, bad_server_salt.bad_msg_id_, bad_server_salt.bad_msg_seqno_, 0};
|
|
VLOG(mtproto) << "BAD_SERVER_SALT: " << bad_info;
|
|
auth_data_->set_server_salt(bad_server_salt.new_server_salt_, Time::now_cached());
|
|
callback_->on_server_salt_updated();
|
|
|
|
on_message_failed(bad_info.message_id, Status::Error("Bad server salt"));
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_ack &msgs_ack) {
|
|
for (auto id : msgs_ack.msg_ids_) {
|
|
callback_->on_message_ack(id);
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::gzip_packed &gzip_packed) {
|
|
BufferSlice res = gzdecode(gzip_packed.packed_data_);
|
|
auto guard = set_buffer_slice(&res);
|
|
return on_slice_packet(info, res.as_slice());
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::pong &pong) {
|
|
VLOG(mtproto) << "PONG";
|
|
last_pong_at_ = Time::now_cached();
|
|
return callback_->on_pong();
|
|
}
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::future_salts &salts) {
|
|
VLOG(mtproto) << "FUTURE_SALTS";
|
|
std::vector<ServerSalt> new_salts;
|
|
for (auto &it : salts.salts_) {
|
|
new_salts.push_back(
|
|
ServerSalt{it->salt_, static_cast<double>(it->valid_since_), static_cast<double>(it->valid_until_)});
|
|
}
|
|
auth_data_->set_future_salts(new_salts, Time::now_cached());
|
|
callback_->on_server_salt_updated();
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_msgs_state_info(const std::vector<int64> &ids, Slice info) {
|
|
if (ids.size() != info.size()) {
|
|
return Status::Error(PSLICE() << tag("ids.size()", ids.size()) << "!=" << tag("info.size()", info.size()));
|
|
}
|
|
size_t i = 0;
|
|
for (auto id : ids) {
|
|
callback_->on_message_info(id, info[i], 0, 0);
|
|
i++;
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_state_info &msgs_state_info) {
|
|
auto it = service_queries_.find(msgs_state_info.req_msg_id_);
|
|
if (it == service_queries_.end()) {
|
|
return Status::Error("Unknown msgs_state_info");
|
|
}
|
|
SCOPE_EXIT {
|
|
service_queries_.erase(it);
|
|
};
|
|
if (it->second.type != ServiceQuery::GetStateInfo) {
|
|
return Status::Error("Got msg_state_info in response not to GetStateInfo");
|
|
}
|
|
return on_msgs_state_info(it->second.message_ids, msgs_state_info.info_);
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msgs_all_info &msgs_all_info) {
|
|
return on_msgs_state_info(msgs_all_info.msg_ids_, msgs_all_info.info_);
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info, const mtproto_api::msg_detailed_info &msg_detailed_info) {
|
|
callback_->on_message_info(msg_detailed_info.msg_id_, msg_detailed_info.status_, msg_detailed_info.answer_msg_id_,
|
|
msg_detailed_info.bytes_);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_packet(const MsgInfo &info,
|
|
const mtproto_api::msg_new_detailed_info &msg_new_detailed_info) {
|
|
callback_->on_message_info(0, 0, msg_new_detailed_info.answer_msg_id_, msg_new_detailed_info.bytes_);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_slice_packet(const MsgInfo &info, Slice packet) {
|
|
if (info.seq_no & 1) {
|
|
send_ack(info.message_id);
|
|
}
|
|
TlParser parser(packet);
|
|
tl_object_ptr<mtproto_api::Object> object = mtproto_api::Object::fetch(parser);
|
|
if (parser.get_error()) {
|
|
// msg_container is not real tl object
|
|
if (packet.size() >= 4 && as<int32>(packet.begin()) == mtproto_api::msg_container::ID) {
|
|
return on_packet_container(info, packet);
|
|
}
|
|
if (packet.size() >= 4 && as<int32>(packet.begin()) == mtproto_api::rpc_result::ID) {
|
|
return on_packet_rpc_result(info, packet);
|
|
}
|
|
|
|
// It is an update... I hope.
|
|
auto status = auth_data_->check_update(info.message_id);
|
|
if (status.is_error()) {
|
|
VLOG(mtproto) << "Skip update " << info.message_id << " from " << get_name() << " created in "
|
|
<< (Time::now() - created_at_) << ": " << status;
|
|
return Status::OK();
|
|
} else {
|
|
VLOG(mtproto) << "Got update from " << get_name() << " created in " << (Time::now() - created_at_)
|
|
<< " in container " << container_id_ << " from session " << auth_data_->session_id_
|
|
<< " with message_id " << info.message_id << ", main_message_id = " << main_message_id_
|
|
<< ", seq_no = " << info.seq_no << " and original size " << info.size;
|
|
return callback_->on_message_result_ok(0, as_buffer_slice(packet), info.size);
|
|
}
|
|
}
|
|
|
|
Status status;
|
|
downcast_call(*object, OnPacket(info, this, &status));
|
|
return status;
|
|
}
|
|
|
|
Status SessionConnection::parse_packet(TlParser &parser) {
|
|
MsgInfo info;
|
|
Slice packet;
|
|
TRY_STATUS(parse_message(parser, &info, &packet, true));
|
|
return on_slice_packet(info, packet);
|
|
}
|
|
|
|
Status SessionConnection::on_main_packet(const PacketInfo &info, Slice packet) {
|
|
// Update pong here too. Real pong can be delayed by lots of big packets
|
|
last_pong_at_ = Time::now_cached();
|
|
|
|
if (!connected_flag_) {
|
|
connected_flag_ = true;
|
|
callback_->on_connected();
|
|
}
|
|
|
|
VLOG(raw_mtproto) << "Got packet: [session_id:" << format::as_hex(info.session_id) << "] "
|
|
<< format::as_hex_dump<4>(packet);
|
|
if (info.no_crypto_flag) {
|
|
return Status::Error("Unencrypted packet");
|
|
}
|
|
|
|
TlParser parser(packet);
|
|
TRY_STATUS(parse_packet(parser));
|
|
parser.fetch_end();
|
|
if (parser.get_error()) {
|
|
return Status::Error(PSLICE() << "Failed to parse packet: " << parser.get_error());
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
void SessionConnection::on_message_failed(uint64 id, Status status) {
|
|
callback_->on_message_failed(id, std::move(status));
|
|
|
|
if (id == last_ping_message_id_ || id == last_ping_container_id_) {
|
|
// restart ping immediately
|
|
last_ping_at_ = 0;
|
|
last_ping_message_id_ = 0;
|
|
last_ping_container_id_ = 0;
|
|
}
|
|
|
|
auto cit = container_to_service_msg_.find(id);
|
|
if (cit != container_to_service_msg_.end()) {
|
|
for (auto nid : cit->second) {
|
|
on_message_failed_inner(nid);
|
|
}
|
|
} else {
|
|
on_message_failed_inner(id);
|
|
}
|
|
}
|
|
void SessionConnection::on_message_failed_inner(uint64 id) {
|
|
auto it = service_queries_.find(id);
|
|
if (it == service_queries_.end()) {
|
|
return;
|
|
}
|
|
switch (it->second.type) {
|
|
case ServiceQuery::ResendAnswer: {
|
|
for (auto message_id : it->second.message_ids) {
|
|
resend_answer(message_id);
|
|
}
|
|
break;
|
|
}
|
|
case ServiceQuery::GetStateInfo: {
|
|
for (auto message_id : it->second.message_ids) {
|
|
get_state_info(message_id);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
service_queries_.erase(id);
|
|
}
|
|
|
|
bool SessionConnection::must_flush_packet() {
|
|
flush_packet_at_ = 0;
|
|
|
|
// we need key to send just something
|
|
if (!auth_data_->has_auth_key(Time::now_cached())) {
|
|
return false;
|
|
}
|
|
|
|
// transport is ready
|
|
if (!raw_connection_->can_send()) {
|
|
return false;
|
|
}
|
|
|
|
bool has_salt = auth_data_->has_salt(Time::now_cached());
|
|
// do not send anything in long poll connection before we have salt
|
|
if (mode_ == Mode::HttpLongPoll && !has_salt) {
|
|
return false;
|
|
}
|
|
|
|
// http_wait
|
|
if (mode_ == Mode::HttpLongPoll) {
|
|
return true;
|
|
}
|
|
// queries and acks (+ resend & get_info)
|
|
if (has_salt && force_send_at_ != 0) {
|
|
if (Time::now_cached() > force_send_at_) {
|
|
return true;
|
|
} else {
|
|
relax_timeout_at(&flush_packet_at_, force_send_at_);
|
|
}
|
|
}
|
|
|
|
// ping
|
|
if (has_salt) {
|
|
if (must_ping()) {
|
|
return true;
|
|
}
|
|
relax_timeout_at(&flush_packet_at_, last_ping_at_ + ping_must_delay());
|
|
}
|
|
// get_future_salt
|
|
if (!has_salt) {
|
|
auto get_future_salts_at = last_get_future_salt_at_ + 60;
|
|
if (get_future_salts_at < Time::now_cached()) {
|
|
return true;
|
|
}
|
|
relax_timeout_at(&flush_packet_at_, get_future_salts_at);
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
Status SessionConnection::before_write() {
|
|
while (must_flush_packet()) {
|
|
flush_packet();
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_raw_packet(const td::mtproto::PacketInfo &info, BufferSlice packet) {
|
|
auto old_main_message_id = main_message_id_;
|
|
main_message_id_ = info.message_id;
|
|
SCOPE_EXIT {
|
|
main_message_id_ = old_main_message_id;
|
|
};
|
|
|
|
if (info.no_crypto_flag) {
|
|
return Status::Error("Unexpected unencrypted packet");
|
|
}
|
|
|
|
bool time_difference_was_updated = false;
|
|
auto status =
|
|
auth_data_->check_packet(info.session_id, info.message_id, Time::now_cached(), time_difference_was_updated);
|
|
if (time_difference_was_updated) {
|
|
callback_->on_server_time_difference_updated();
|
|
}
|
|
if (status.is_error()) {
|
|
if (status.code() == 1) {
|
|
LOG(WARNING) << "Packet ignored " << status;
|
|
send_ack(info.message_id);
|
|
return Status::OK();
|
|
} else {
|
|
return status;
|
|
}
|
|
}
|
|
|
|
auto guard = set_buffer_slice(&packet);
|
|
TRY_STATUS(on_main_packet(info, packet.as_slice()));
|
|
return Status::OK();
|
|
}
|
|
|
|
Status SessionConnection::on_quick_ack(uint64 quick_ack_token) {
|
|
callback_->on_message_ack(quick_ack_token);
|
|
return Status::OK();
|
|
}
|
|
SessionConnection::SessionConnection(Mode mode, std::unique_ptr<RawConnection> raw_connection, AuthData *auth_data,
|
|
DhCallback *dh_callback)
|
|
: raw_connection_(std::move(raw_connection)), auth_data_(auth_data), dh_callback_(dh_callback) {
|
|
state_ = Init;
|
|
mode_ = mode;
|
|
created_at_ = Time::now();
|
|
}
|
|
|
|
Fd &SessionConnection::get_pollable() {
|
|
return raw_connection_->get_pollable();
|
|
}
|
|
|
|
Status SessionConnection::init() {
|
|
CHECK(state_ == Init);
|
|
last_pong_at_ = Time::now_cached();
|
|
state_ = Run;
|
|
return Status::OK();
|
|
}
|
|
|
|
void SessionConnection::set_online(bool online_flag) {
|
|
online_flag_ = online_flag;
|
|
if (online_flag_) {
|
|
last_pong_at_ = Time::now_cached() - ping_disconnect_delay() + rtt(); // disconnect if no ping in 1 second
|
|
} else {
|
|
last_pong_at_ = Time::now_cached();
|
|
}
|
|
last_ping_at_ = 0;
|
|
last_ping_message_id_ = 0;
|
|
last_ping_container_id_ = 0;
|
|
}
|
|
|
|
void SessionConnection::do_close(Status status) {
|
|
state_ = Closed;
|
|
callback_->on_before_close();
|
|
raw_connection_->close();
|
|
// NB: this could be destroyed after on_closed
|
|
callback_->on_closed(std::move(status));
|
|
}
|
|
|
|
void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token) {
|
|
CHECK(state_ != Closed);
|
|
raw_connection_->send_crypto(storer, auth_data_->get_session_id(), auth_data_->get_server_salt(Time::now_cached()),
|
|
auth_data_->get_auth_key(), quick_ack_token);
|
|
}
|
|
|
|
Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, int64 message_id,
|
|
uint64 invoke_after_id, bool use_quick_ack) {
|
|
CHECK(mode_ != Mode::HttpLongPoll) << "LongPoll connection is only for http_wait";
|
|
if (message_id == 0) {
|
|
message_id = auth_data_->next_message_id(Time::now_cached());
|
|
}
|
|
auto seq_no = auth_data_->next_seq_no(true);
|
|
if (to_send_.empty()) {
|
|
send_before(Time::now_cached() + QUERY_DELAY);
|
|
}
|
|
to_send_.push_back(Query{message_id, seq_no, std::move(buffer), gzip_flag, invoke_after_id, use_quick_ack});
|
|
|
|
return message_id;
|
|
}
|
|
|
|
void SessionConnection::get_state_info(int64 message_id) {
|
|
if (to_get_state_info_.empty()) {
|
|
send_before(Time::now_cached());
|
|
}
|
|
to_get_state_info_.push_back(message_id);
|
|
}
|
|
|
|
void SessionConnection::resend_answer(int64 message_id) {
|
|
if (to_resend_answer_.empty()) {
|
|
send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
|
|
}
|
|
to_resend_answer_.push_back(message_id);
|
|
}
|
|
void SessionConnection::cancel_answer(int64 message_id) {
|
|
if (to_cancel_answer_.empty()) {
|
|
send_before(Time::now_cached() + RESEND_ANSWER_DELAY);
|
|
}
|
|
to_cancel_answer_.push_back(message_id);
|
|
}
|
|
|
|
std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key, int64 nonce, int32 expire_at) {
|
|
int64 temp_key = auth_data_->get_tmp_auth_key().id();
|
|
|
|
mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->session_id_, expire_at);
|
|
auto object_storer = create_storer(object);
|
|
auto object_packet = BufferWriter{object_storer.size(), 0, 0};
|
|
object_storer.store(object_packet.as_slice().ubegin());
|
|
|
|
Query query{auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, 0, false};
|
|
PacketStorer<QueryImpl> query_storer(query, Slice());
|
|
|
|
mtproto::PacketInfo info;
|
|
info.version = 1;
|
|
info.no_crypto_flag = false;
|
|
info.salt = Random::secure_int64();
|
|
info.session_id = Random::secure_int64();
|
|
|
|
auto packet = BufferWriter{mtproto::Transport::write(query_storer, auth_data_->get_main_auth_key(), &info), 0, 0};
|
|
mtproto::Transport::write(query_storer, auth_data_->get_main_auth_key(), &info, packet.as_slice());
|
|
return std::make_pair(query.message_id, packet.as_buffer_slice());
|
|
}
|
|
|
|
void SessionConnection::send_ack(uint64 message_id) {
|
|
VLOG(mtproto) << "Send ack: [msg_id:" << format::as_hex(message_id) << "]";
|
|
if (to_ack_.empty()) {
|
|
send_before(Time::now_cached() + ACK_DELAY);
|
|
}
|
|
auto ack = static_cast<int64>(message_id);
|
|
// an easiest way to eliminate duplicated acks for gzipped packets
|
|
if (to_ack_.empty() || to_ack_.back() != ack) {
|
|
to_ack_.push_back(ack);
|
|
}
|
|
}
|
|
|
|
// don't send ping in poll mode.
|
|
bool SessionConnection::may_ping() const {
|
|
return last_ping_at_ == 0 || (mode_ != Mode::HttpLongPoll && last_ping_at_ + ping_may_delay() < Time::now_cached());
|
|
}
|
|
|
|
bool SessionConnection::must_ping() const {
|
|
return last_ping_at_ == 0 || (mode_ != Mode::HttpLongPoll && last_ping_at_ + ping_must_delay() < Time::now_cached());
|
|
}
|
|
|
|
void SessionConnection::flush_packet() {
|
|
bool has_salt = auth_data_->has_salt(Time::now_cached());
|
|
// ping
|
|
uint64 container_id = 0;
|
|
int64 ping_id = 0;
|
|
if (has_salt && may_ping()) {
|
|
ping_id = ++cur_ping_id_;
|
|
last_ping_at_ = Time::now_cached();
|
|
}
|
|
|
|
// http_wait
|
|
int max_delay = -1;
|
|
int max_after = -1;
|
|
int max_wait = -1;
|
|
if (mode_ == Mode::HttpLongPoll) {
|
|
max_delay = HTTP_MAX_DELAY;
|
|
max_after = HTTP_MAX_AFTER;
|
|
max_wait = min(http_max_wait(),
|
|
static_cast<int>(1000 * max(0.1, ping_disconnect_delay() + last_pong_at_ - Time::now_cached() - 1)));
|
|
} else if (mode_ == Mode::Http) {
|
|
max_delay = HTTP_MAX_DELAY;
|
|
max_after = HTTP_MAX_AFTER;
|
|
max_wait = 0;
|
|
}
|
|
|
|
// future salts
|
|
int future_salt_n = 0;
|
|
if (mode_ != Mode::HttpLongPoll) {
|
|
if (auth_data_->need_future_salts(Time::now_cached()) && last_get_future_salt_at_ + 60 < Time::now_cached()) {
|
|
last_get_future_salt_at_ = Time::now_cached();
|
|
future_salt_n = 64;
|
|
}
|
|
}
|
|
|
|
size_t send_till = 0, send_size = 0;
|
|
// send at most 1020 queries, of total size 2^15
|
|
// don't send anything if have no salt
|
|
if (has_salt) {
|
|
while (send_till < to_send_.size() && send_till < 1020 && send_size < (1 << 15)) {
|
|
send_size += to_send_[send_till].packet.size();
|
|
send_till++;
|
|
}
|
|
}
|
|
std::vector<Query> queries;
|
|
if (send_till == to_send_.size()) {
|
|
queries = std::move(to_send_);
|
|
} else if (send_till != 0) {
|
|
queries.reserve(send_till);
|
|
std::move(to_send_.begin(), to_send_.begin() + send_till, std::back_inserter(queries));
|
|
to_send_.erase(to_send_.begin(), to_send_.begin() + send_till);
|
|
}
|
|
|
|
if (queries.empty() && to_ack_.empty() && ping_id == 0 && max_delay < 0 && future_salt_n == 0 &&
|
|
to_resend_answer_.empty() && to_cancel_answer_.empty() && to_get_state_info_.empty()) {
|
|
force_send_at_ = 0;
|
|
return;
|
|
}
|
|
|
|
VLOG(mtproto) << "Sent packet: " << tag("query_count", queries.size()) << tag("ack_cnt", to_ack_.size())
|
|
<< tag("ping", ping_id != 0) << tag("http_wait", max_delay >= 0)
|
|
<< tag("future_salt", future_salt_n > 0) << tag("get_info", to_get_state_info_.size())
|
|
<< tag("resend", to_resend_answer_.size()) << tag("cancel", to_cancel_answer_.size())
|
|
<< tag("auth_id", auth_data_->get_auth_key().id());
|
|
|
|
auto cut_tail = [](auto &v, size_t size, Slice name) {
|
|
if (size >= v.size()) {
|
|
return std::move(v);
|
|
}
|
|
LOG(WARNING) << "Too much ids in container: " << v.size() << " " << name;
|
|
std::decay_t<decltype(v)> res(std::make_move_iterator(v.end() - size), std::make_move_iterator(v.end()));
|
|
v.resize(v.size() - size);
|
|
return res;
|
|
};
|
|
|
|
// no more than 8192 ids per container..
|
|
auto to_resend_answer = cut_tail(to_resend_answer_, 8192, "resend_answer");
|
|
uint64 resend_answer_id = 0;
|
|
CHECK(queries.size() <= 1020);
|
|
auto to_cancel_answer = cut_tail(to_cancel_answer_, 1020 - queries.size(), "cancel_answer");
|
|
auto to_get_state_info = cut_tail(to_get_state_info_, 8192, "get_state_info");
|
|
uint64 get_state_info_id = 0;
|
|
auto to_ack = cut_tail(to_ack_, 8192, "ack");
|
|
uint64 ping_message_id = 0;
|
|
|
|
bool use_quick_ack =
|
|
std::any_of(queries.begin(), queries.end(), [](const auto &query) { return query.use_quick_ack; });
|
|
|
|
{
|
|
uint64 parent_message_id = 0;
|
|
auto storer = PacketStorer<CryptoImpl>(
|
|
queries, auth_data_->header(), std::move(to_ack), ping_id, ping_disconnect_delay() + 2, max_delay, max_after,
|
|
max_wait, future_salt_n, to_get_state_info, to_resend_answer, to_cancel_answer, auth_data_, &container_id,
|
|
&get_state_info_id, &resend_answer_id, &ping_message_id, &parent_message_id);
|
|
|
|
auto quick_ack_token = use_quick_ack ? parent_message_id : 0;
|
|
send_crypto(storer, quick_ack_token);
|
|
}
|
|
|
|
if (resend_answer_id) {
|
|
service_queries_.insert({resend_answer_id, ServiceQuery{ServiceQuery::ResendAnswer, std::move(to_resend_answer)}});
|
|
}
|
|
if (get_state_info_id) {
|
|
service_queries_.insert(
|
|
{get_state_info_id, ServiceQuery{ServiceQuery::GetStateInfo, std::move(to_get_state_info)}});
|
|
}
|
|
if (ping_id != 0) {
|
|
last_ping_container_id_ = container_id;
|
|
last_ping_message_id_ = ping_message_id;
|
|
}
|
|
|
|
if (container_id != 0) {
|
|
vector<uint64> ids = transform(queries, [](const Query &x) { return static_cast<uint64>(x.message_id); });
|
|
|
|
// some acks may be lost here. Nobody will resend them if something goes wrong with query.
|
|
// It is mostly problem for server. We will just drop this answers in next connection
|
|
//
|
|
// get future salt too.
|
|
// So I will re-ask salt if have no answer in 60 second.
|
|
callback_->on_container_sent(container_id, std::move(ids));
|
|
|
|
if (resend_answer_id) {
|
|
container_to_service_msg_[container_id].push_back(resend_answer_id);
|
|
}
|
|
if (get_state_info_id) {
|
|
container_to_service_msg_[container_id].push_back(get_state_info_id);
|
|
}
|
|
}
|
|
|
|
to_ack_.clear();
|
|
if (to_send_.empty()) {
|
|
force_send_at_ = 0;
|
|
}
|
|
}
|
|
|
|
void SessionConnection::send_before(double tm) {
|
|
if (force_send_at_ == 0 || force_send_at_ > tm) {
|
|
force_send_at_ = tm;
|
|
}
|
|
}
|
|
|
|
Status SessionConnection::do_flush() {
|
|
CHECK(state_ != Closed);
|
|
if (state_ == Init) {
|
|
TRY_STATUS(init());
|
|
}
|
|
if (!auth_data_->has_auth_key(Time::now_cached())) {
|
|
return Status::Error("No auth key");
|
|
}
|
|
|
|
TRY_STATUS(raw_connection_->flush(auth_data_->get_auth_key(), *this));
|
|
|
|
// check last pong
|
|
if (last_pong_at_ != 0 && last_pong_at_ + ping_disconnect_delay() < Time::now_cached()) {
|
|
raw_connection_->stats_callback()->on_error();
|
|
return Status::Error("No pong :(");
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
double SessionConnection::flush(SessionConnection::Callback *callback) {
|
|
callback_ = callback;
|
|
wakeup_at_ = 0;
|
|
auto status = do_flush();
|
|
// check error
|
|
if (status.is_error()) {
|
|
do_close(std::move(status));
|
|
return 0;
|
|
}
|
|
|
|
// wakeup_at
|
|
// two independent timeouts
|
|
// 1. close connection after PING_DISCONNECT_DELAY after last_pong.
|
|
// 2. the one returned by must_flush_packet
|
|
relax_timeout_at(&wakeup_at_, last_pong_at_ + ping_disconnect_delay() + 0.002);
|
|
// CHECK(wakeup_at > Time::now_cached());
|
|
|
|
relax_timeout_at(&wakeup_at_, flush_packet_at_);
|
|
return wakeup_at_;
|
|
}
|
|
|
|
void SessionConnection::force_close(SessionConnection::Callback *callback) {
|
|
CHECK(state_ != Closed);
|
|
callback_ = callback;
|
|
do_close(Status::OK());
|
|
}
|
|
} // namespace mtproto
|
|
} // namespace td
|