SequenceDispatcher: support multiple chains

This commit is contained in:
Arseny Smirnov 2022-01-28 17:06:54 +03:00
parent 355c2950ad
commit 452f60be0b
9 changed files with 79 additions and 63 deletions

View File

@ -103,6 +103,32 @@ class CancelVectorImpl {
vector<PacketStorer<CancelImpl>> storers_;
};
class InvokeAfter {
public:
explicit InvokeAfter(Span<uint64> ids): ids_(ids){
}
template <class StorerT>
void store(StorerT &storer) const {
if (ids_.empty()) {
return;
}
if (ids_.size() == 1) {
storer.store_int(static_cast<int32>(0xcb9f372d));
storer.store_long(static_cast<int64>(ids_[0]));
return;
}
// invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector<long> query:!X = X;
storer.store_int(static_cast<int32>(0x3dc4b4f0));
storer.store_int(static_cast<int32>(0x1cb5c415));
storer.store_int(narrow_cast<int>(ids_.size()));
for (auto id : ids_) {
storer.store_long(static_cast<int64>(id));
}
}
private:
Span<uint64> ids_;
};
class QueryImpl {
public:
QueryImpl(const MtprotoQuery &query, Slice header) : query_(query), header_(header) {
@ -112,23 +138,9 @@ class QueryImpl {
void do_store(StorerT &storer) const {
storer.store_binary(query_.message_id);
storer.store_binary(query_.seq_no);
Slice invoke_header;
// TODO(refactor):
// invokeAfterMsg#cb9f372d {X:Type} msg_id:long query:!X = X;
// This code makes me very sad.
// InvokeAfterMsg is not even in mtproto_api. It is in telegram_api.
#pragma pack(push, 4)
struct {
uint32 constructor_id;
uint64 invoke_after_id;
} invoke_data;
#pragma pack(pop)
if (query_.invoke_after_id != 0) {
invoke_data.constructor_id = 0xcb9f372d;
invoke_data.invoke_after_id = query_.invoke_after_id;
invoke_header = Slice(reinterpret_cast<const uint8 *>(&invoke_data), sizeof(invoke_data));
}
InvokeAfter invoke_after(query_.invoke_after_ids);
auto invoke_after_storer = create_default_storer(invoke_after);
Slice data = query_.packet.as_slice();
mtproto_api::gzip_packed packed(data);
@ -136,9 +148,8 @@ class QueryImpl {
auto gzip_storer = create_storer(packed);
const Storer &data_storer =
query_.gzip_flag ? static_cast<const Storer &>(gzip_storer) : static_cast<const Storer &>(plain_storer);
auto invoke_header_storer = create_storer(invoke_header);
auto header_storer = create_storer(header_);
auto suff_storer = create_storer(invoke_header_storer, data_storer);
auto suff_storer = create_storer(invoke_after_storer, data_storer);
auto all_storer = create_storer(header_storer, suff_storer);
storer.store_binary(static_cast<uint32>(all_storer.size()));

View File

@ -17,7 +17,7 @@ struct MtprotoQuery {
int32 seq_no;
BufferSlice packet;
bool gzip_flag;
uint64 invoke_after_id;
std::vector<uint64> invoke_after_ids;
bool use_quick_ack;
};

View File

@ -765,7 +765,7 @@ void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token
}
Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, int64 message_id,
uint64 invoke_after_id, bool use_quick_ack) {
vector<uint64> invoke_after_ids, bool use_quick_ack) {
CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait"
if (message_id == 0) {
message_id = auth_data_->next_message_id(Time::now_cached());
@ -774,9 +774,9 @@ Result<uint64> SessionConnection::send_query(BufferSlice buffer, bool gzip_flag,
if (to_send_.empty()) {
send_before(Time::now_cached() + QUERY_DELAY);
}
to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, invoke_after_id, use_quick_ack});
to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, std::move(invoke_after_ids), use_quick_ack});
VLOG(mtproto) << "Invoke query " << message_id << " of size " << to_send_.back().packet.size() << " with seq_no "
<< seq_no << " after " << invoke_after_id << (use_quick_ack ? " with quick ack" : "");
<< seq_no << " after " << invoke_after_ids << (use_quick_ack ? " with quick ack" : "");
return message_id;
}
@ -817,7 +817,7 @@ std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key,
CHECK(size == real_size);
MtprotoQuery query{
auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, 0, false};
auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, {}, false};
PacketStorer<QueryImpl> query_storer(query, Slice());
PacketInfo info;

View File

@ -82,7 +82,7 @@ class SessionConnection final
// Interface
Result<uint64> TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, int64 message_id = 0,
uint64 invoke_after_id = 0, bool use_quick_ack = false);
std::vector<uint64> invoke_after_id = {}, bool use_quick_ack = false);
std::pair<uint64, BufferSlice> encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at);
void get_state_info(int64 message_id);

View File

@ -341,7 +341,7 @@ class GetPinnedDialogsActor final : public NetActorOnce {
auto query = G()->net_query_creator().create(telegram_api::messages_getPinnedDialogs(folder_id.get()));
auto result = query.get_weak();
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_id);
std::move(query), actor_shared(this), ChainIds{sequence_id});
return result;
}
@ -796,7 +796,7 @@ class GetDialogListActor final : public NetActorOnce {
telegram_api::messages_getDialogs(flags, false /*ignored*/, folder_id.get(), offset_date,
offset_message_id.get(), std::move(input_peer), limit, 0));
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_id);
std::move(query), actor_shared(this), ChainIds{sequence_id});
}
void on_result(BufferSlice packet) final {
@ -1853,7 +1853,7 @@ class ToggleDialogIsBlockedActor final : public NetActorOnce {
auto query = is_blocked ? G()->net_query_creator().create(telegram_api::contacts_block(std::move(input_peer)))
: G()->net_query_creator().create(telegram_api::contacts_unblock(std::move(input_peer)));
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3053,7 +3053,7 @@ class SaveDefaultSendAsActor final : public NetActorOnce {
telegram_api::messages_saveDefaultSendAs(std::move(input_peer), std::move(send_as_input_peer)));
query->debug("send to MessagesManager::MultiSequenceDispatcher");
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3165,7 +3165,7 @@ class SendMessageActor final : public NetActorOnce {
*send_query_ref = query.get_weak();
query->debug("send to MessagesManager::MultiSequenceDispatcher");
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3350,7 +3350,7 @@ class SendMultiMediaActor final : public NetActorOnce {
// no quick ack, because file reference errors are very likely to happen
query->debug("send to MessagesManager::MultiSequenceDispatcher");
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3476,7 +3476,7 @@ class SendMediaActor final : public NetActorOnce {
*send_query_ref = query.get_weak();
query->debug("send to MessagesManager::MultiSequenceDispatcher");
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3641,7 +3641,7 @@ class SendScheduledMessageActor final : public NetActorOnce {
query->debug("send to MessagesManager::MultiSequenceDispatcher");
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3723,7 +3723,7 @@ class EditMessageActor final : public NetActorOnce {
query->debug("send to MessagesManager::MultiSequenceDispatcher");
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {
@ -3854,7 +3854,7 @@ class ForwardMessagesActor final : public NetActorOnce {
PromiseCreator::Ignore());
}
send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback,
std::move(query), actor_shared(this), sequence_dispatcher_id);
std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id});
}
void on_result(BufferSlice packet) final {

View File

@ -168,7 +168,11 @@ void SequenceDispatcher::loop() {
if (last_sent_i_ != std::numeric_limits<size_t>::max() && data_[last_sent_i_].state_ == State::Wait) {
invoke_after = data_[last_sent_i_].net_query_ref_;
}
data_[next_i_].query_->set_invoke_after(invoke_after);
if (!invoke_after.empty()) {
data_[next_i_].query_->set_invoke_after({invoke_after});
} else {
data_[next_i_].query_->set_invoke_after({});
}
data_[next_i_].query_->last_timeout_ = 0;
VLOG(net_query) << "Send " << data_[next_i_].query_;
@ -243,8 +247,11 @@ void SequenceDispatcher::close_silent() {
/*** MultiSequenceDispatcher ***/
void MultiSequenceDispatcherOld::send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback,
uint64 sequence_id) {
CHECK(sequence_id != 0);
td::Span<uint64> chains) {
CHECK(all_of(chains, [](auto chain_id) {return chain_id != 0;}));
CHECK(!chains.empty());
auto sequence_id = chains[0];
auto it_ok = dispatchers_.emplace(sequence_id, Data{0, ActorOwn<SequenceDispatcher>()});
auto &data = it_ok.first->second;
if (it_ok.second) {
@ -273,14 +280,14 @@ void MultiSequenceDispatcherOld::ready_to_close() {
class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew {
public:
void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, uint64 sequence_id) final {
LOG(ERROR) << "send " << query;
void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, td::Span<uint64> chains) final {
CHECK(all_of(chains, [](auto chain_id) {return chain_id != 0;}));
Node node;
node.net_query = std::move(query);
node.net_query->debug("Waiting at SequenceDispatcher");
node.net_query_ref = node.net_query.get_weak();
node.callback = std::move(callback);
scheduler_.create_task({ChainId{sequence_id}}, std::move(node));
scheduler_.create_task(chains, std::move(node));
loop();
}
@ -335,11 +342,9 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew {
while (true) {
auto o_task = scheduler_.start_next_task();
if (!o_task) {
LOG(ERROR) << " no more tasks " << scheduler_;
break;
}
auto task = o_task.unwrap();
LOG(ERROR) << " next task = " << task.task_id;
auto &node = *scheduler_.get_task_extra(task.task_id);
CHECK(!node.net_query.empty());
@ -348,15 +353,10 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew {
for (auto parent_id : task.parents) {
auto &parent_node = *scheduler_.get_task_extra(parent_id);
parents.push_back(parent_node.net_query_ref);
CHECK(!parent_node.net_query_ref.empty());
}
if (parents.empty()) {
query->set_invoke_after({});
} else if (parents.size() == 1) {
query->set_invoke_after(parents[0]);
} else if (parents.size() > 1){
LOG(FATAL) << "TODO: support invokeAfterMsgs";
}
query->set_invoke_after(std::move(parents));
query->last_timeout_ = 0; // TODO: flood
VLOG(net_query) << "Send " << query;
query->debug("send to Td::send_with_callback");

View File

@ -75,7 +75,7 @@ class SequenceDispatcher final : public NetQueryCallback {
class MultiSequenceDispatcherOld final : public SequenceDispatcher::Parent {
public:
void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, uint64 sequence_id);
void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, Span<uint64> chains);
static ActorOwn<MultiSequenceDispatcherOld> create(td::Slice name) {
return create_actor<MultiSequenceDispatcherOld>(name);
}
@ -90,12 +90,14 @@ class MultiSequenceDispatcherOld final : public SequenceDispatcher::Parent {
void ready_to_close() final;
};
using ChainId = uint64;
using ChainIds = std::vector<ChainId>;
class MultiSequenceDispatcherNew : public NetQueryCallback {
public:
virtual void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, uint64 sequence_id) = 0;
virtual void send_with_callback(NetQueryPtr query, ActorShared<NetQueryCallback> callback, Span<uint64> chains) = 0;
static ActorOwn<MultiSequenceDispatcherNew> create(Slice name);
};
using MultiSequenceDispatcher = MultiSequenceDispatcherOld;
using MultiSequenceDispatcher = MultiSequenceDispatcherNew;
} // namespace td

View File

@ -196,11 +196,11 @@ class NetQuery final : public TsListNode<NetQueryDebug> {
message_id_ = message_id;
}
NetQueryRef invoke_after() const {
Span<NetQueryRef> invoke_after() const {
return invoke_after_;
}
void set_invoke_after(NetQueryRef ref) {
invoke_after_ = ref;
void set_invoke_after(std::vector<NetQueryRef> refs) {
invoke_after_ = std::move(refs);
}
void set_session_rand(uint32 session_rand) {
session_rand_ = session_rand;
@ -289,7 +289,7 @@ class NetQuery final : public TsListNode<NetQueryDebug> {
BufferSlice answer_;
int32 tl_constructor_ = 0;
NetQueryRef invoke_after_;
std::vector<NetQueryRef> invoke_after_;
uint32 session_rand_ = 0;
bool may_be_lost_ = false;

View File

@ -992,14 +992,17 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
return return_query(std::move(net_query));
}
uint64 invoke_after_id = 0;
NetQueryRef invoke_after = net_query->invoke_after();
if (!invoke_after.empty()) {
invoke_after_id = invoke_after->message_id();
if (invoke_after->session_id() != auth_data_.get_session_id() || invoke_after_id == 0) {
Span<NetQueryRef> invoke_after = net_query->invoke_after();
std::vector<uint64> invoke_after_ids;
for (auto &ref : invoke_after) {
auto invoke_after_id = ref->message_id();
if (ref->session_id() != auth_data_.get_session_id() || invoke_after_id == 0) {
net_query->set_error_resend_invoke_after();
return return_query(std::move(net_query));
}
invoke_after_ids.push_back(invoke_after_id);
}
if (!invoke_after.empty()) {
if (!unknown_queries_.empty()) {
pending_invoke_after_queries_.push_back(std::move(net_query));
return;
@ -1010,7 +1013,7 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
if (!immediately_fail_query) {
auto r_message_id =
info->connection_->send_query(net_query->query().clone(), net_query->gzip_flag() == NetQuery::GzipFlag::On,
message_id, invoke_after_id, static_cast<bool>(net_query->quick_ack_promise_));
message_id, invoke_after_ids, static_cast<bool>(net_query->quick_ack_promise_));
net_query->on_net_write(net_query->query().size());
@ -1024,7 +1027,7 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer
}
}
VLOG(net_query) << "Send query to connection " << net_query << " [msg_id:" << format::as_hex(message_id) << "]"
<< tag("invoke_after", format::as_hex(invoke_after_id));
<< tag("invoke_after", td::transform(invoke_after_ids, [](auto id){return format::as_hex(id);}));
net_query->set_message_id(message_id);
net_query->cancel_slot_.clear_event();
LOG_CHECK(sent_queries_.find(message_id) == sent_queries_.end()) << message_id;