diff --git a/td/mtproto/CryptoStorer.h b/td/mtproto/CryptoStorer.h index 06527585a..355e2505a 100644 --- a/td/mtproto/CryptoStorer.h +++ b/td/mtproto/CryptoStorer.h @@ -103,6 +103,32 @@ class CancelVectorImpl { vector> storers_; }; +class InvokeAfter { + public: + explicit InvokeAfter(Span ids): ids_(ids){ + } + template + void store(StorerT &storer) const { + if (ids_.empty()) { + return; + } + if (ids_.size() == 1) { + storer.store_int(static_cast(0xcb9f372d)); + storer.store_long(static_cast(ids_[0])); + return; + } + // invokeAfterMsgs#3dc4b4f0 {X:Type} msg_ids:Vector query:!X = X; + storer.store_int(static_cast(0x3dc4b4f0)); + storer.store_int(static_cast(0x1cb5c415)); + storer.store_int(narrow_cast(ids_.size())); + for (auto id : ids_) { + storer.store_long(static_cast(id)); + } + } + private: + Span ids_; +}; + class QueryImpl { public: QueryImpl(const MtprotoQuery &query, Slice header) : query_(query), header_(header) { @@ -112,23 +138,9 @@ class QueryImpl { void do_store(StorerT &storer) const { storer.store_binary(query_.message_id); storer.store_binary(query_.seq_no); - Slice invoke_header; -// TODO(refactor): -// invokeAfterMsg#cb9f372d {X:Type} msg_id:long query:!X = X; -// This code makes me very sad. -// InvokeAfterMsg is not even in mtproto_api. It is in telegram_api. -#pragma pack(push, 4) - struct { - uint32 constructor_id; - uint64 invoke_after_id; - } invoke_data; -#pragma pack(pop) - if (query_.invoke_after_id != 0) { - invoke_data.constructor_id = 0xcb9f372d; - invoke_data.invoke_after_id = query_.invoke_after_id; - invoke_header = Slice(reinterpret_cast(&invoke_data), sizeof(invoke_data)); - } + InvokeAfter invoke_after(query_.invoke_after_ids); + auto invoke_after_storer = create_default_storer(invoke_after); Slice data = query_.packet.as_slice(); mtproto_api::gzip_packed packed(data); @@ -136,9 +148,8 @@ class QueryImpl { auto gzip_storer = create_storer(packed); const Storer &data_storer = query_.gzip_flag ? static_cast(gzip_storer) : static_cast(plain_storer); - auto invoke_header_storer = create_storer(invoke_header); auto header_storer = create_storer(header_); - auto suff_storer = create_storer(invoke_header_storer, data_storer); + auto suff_storer = create_storer(invoke_after_storer, data_storer); auto all_storer = create_storer(header_storer, suff_storer); storer.store_binary(static_cast(all_storer.size())); diff --git a/td/mtproto/MtprotoQuery.h b/td/mtproto/MtprotoQuery.h index 980f60773..b04c3c61e 100644 --- a/td/mtproto/MtprotoQuery.h +++ b/td/mtproto/MtprotoQuery.h @@ -17,7 +17,7 @@ struct MtprotoQuery { int32 seq_no; BufferSlice packet; bool gzip_flag; - uint64 invoke_after_id; + std::vector invoke_after_ids; bool use_quick_ack; }; diff --git a/td/mtproto/SessionConnection.cpp b/td/mtproto/SessionConnection.cpp index 3a7124f08..1495e1c13 100644 --- a/td/mtproto/SessionConnection.cpp +++ b/td/mtproto/SessionConnection.cpp @@ -765,7 +765,7 @@ void SessionConnection::send_crypto(const Storer &storer, uint64 quick_ack_token } Result SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, int64 message_id, - uint64 invoke_after_id, bool use_quick_ack) { + vector invoke_after_ids, bool use_quick_ack) { CHECK(mode_ != Mode::HttpLongPoll); // "LongPoll connection is only for http_wait" if (message_id == 0) { message_id = auth_data_->next_message_id(Time::now_cached()); @@ -774,9 +774,9 @@ Result SessionConnection::send_query(BufferSlice buffer, bool gzip_flag, if (to_send_.empty()) { send_before(Time::now_cached() + QUERY_DELAY); } - to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, invoke_after_id, use_quick_ack}); + to_send_.push_back(MtprotoQuery{message_id, seq_no, std::move(buffer), gzip_flag, std::move(invoke_after_ids), use_quick_ack}); VLOG(mtproto) << "Invoke query " << message_id << " of size " << to_send_.back().packet.size() << " with seq_no " - << seq_no << " after " << invoke_after_id << (use_quick_ack ? " with quick ack" : ""); + << seq_no << " after " << invoke_after_ids << (use_quick_ack ? " with quick ack" : ""); return message_id; } @@ -817,7 +817,7 @@ std::pair SessionConnection::encrypted_bind(int64 perm_key, CHECK(size == real_size); MtprotoQuery query{ - auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, 0, false}; + auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, {}, false}; PacketStorer query_storer(query, Slice()); PacketInfo info; diff --git a/td/mtproto/SessionConnection.h b/td/mtproto/SessionConnection.h index 16c350321..6ca7524bf 100644 --- a/td/mtproto/SessionConnection.h +++ b/td/mtproto/SessionConnection.h @@ -82,7 +82,7 @@ class SessionConnection final // Interface Result TD_WARN_UNUSED_RESULT send_query(BufferSlice buffer, bool gzip_flag, int64 message_id = 0, - uint64 invoke_after_id = 0, bool use_quick_ack = false); + std::vector invoke_after_id = {}, bool use_quick_ack = false); std::pair encrypted_bind(int64 perm_key, int64 nonce, int32 expires_at); void get_state_info(int64 message_id); diff --git a/td/telegram/MessagesManager.cpp b/td/telegram/MessagesManager.cpp index 35b7d7afd..59c84709b 100644 --- a/td/telegram/MessagesManager.cpp +++ b/td/telegram/MessagesManager.cpp @@ -341,7 +341,7 @@ class GetPinnedDialogsActor final : public NetActorOnce { auto query = G()->net_query_creator().create(telegram_api::messages_getPinnedDialogs(folder_id.get())); auto result = query.get_weak(); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_id); + std::move(query), actor_shared(this), ChainIds{sequence_id}); return result; } @@ -796,7 +796,7 @@ class GetDialogListActor final : public NetActorOnce { telegram_api::messages_getDialogs(flags, false /*ignored*/, folder_id.get(), offset_date, offset_message_id.get(), std::move(input_peer), limit, 0)); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_id); + std::move(query), actor_shared(this), ChainIds{sequence_id}); } void on_result(BufferSlice packet) final { @@ -1853,7 +1853,7 @@ class ToggleDialogIsBlockedActor final : public NetActorOnce { auto query = is_blocked ? G()->net_query_creator().create(telegram_api::contacts_block(std::move(input_peer))) : G()->net_query_creator().create(telegram_api::contacts_unblock(std::move(input_peer))); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3053,7 +3053,7 @@ class SaveDefaultSendAsActor final : public NetActorOnce { telegram_api::messages_saveDefaultSendAs(std::move(input_peer), std::move(send_as_input_peer))); query->debug("send to MessagesManager::MultiSequenceDispatcher"); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3165,7 +3165,7 @@ class SendMessageActor final : public NetActorOnce { *send_query_ref = query.get_weak(); query->debug("send to MessagesManager::MultiSequenceDispatcher"); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3350,7 +3350,7 @@ class SendMultiMediaActor final : public NetActorOnce { // no quick ack, because file reference errors are very likely to happen query->debug("send to MessagesManager::MultiSequenceDispatcher"); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3476,7 +3476,7 @@ class SendMediaActor final : public NetActorOnce { *send_query_ref = query.get_weak(); query->debug("send to MessagesManager::MultiSequenceDispatcher"); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3641,7 +3641,7 @@ class SendScheduledMessageActor final : public NetActorOnce { query->debug("send to MessagesManager::MultiSequenceDispatcher"); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3723,7 +3723,7 @@ class EditMessageActor final : public NetActorOnce { query->debug("send to MessagesManager::MultiSequenceDispatcher"); send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { @@ -3854,7 +3854,7 @@ class ForwardMessagesActor final : public NetActorOnce { PromiseCreator::Ignore()); } send_closure(td_->messages_manager_->sequence_dispatcher_, &MultiSequenceDispatcher::send_with_callback, - std::move(query), actor_shared(this), sequence_dispatcher_id); + std::move(query), actor_shared(this), ChainIds{sequence_dispatcher_id}); } void on_result(BufferSlice packet) final { diff --git a/td/telegram/SequenceDispatcher.cpp b/td/telegram/SequenceDispatcher.cpp index 4a129cfd1..505033e5b 100644 --- a/td/telegram/SequenceDispatcher.cpp +++ b/td/telegram/SequenceDispatcher.cpp @@ -168,7 +168,11 @@ void SequenceDispatcher::loop() { if (last_sent_i_ != std::numeric_limits::max() && data_[last_sent_i_].state_ == State::Wait) { invoke_after = data_[last_sent_i_].net_query_ref_; } - data_[next_i_].query_->set_invoke_after(invoke_after); + if (!invoke_after.empty()) { + data_[next_i_].query_->set_invoke_after({invoke_after}); + } else { + data_[next_i_].query_->set_invoke_after({}); + } data_[next_i_].query_->last_timeout_ = 0; VLOG(net_query) << "Send " << data_[next_i_].query_; @@ -243,8 +247,11 @@ void SequenceDispatcher::close_silent() { /*** MultiSequenceDispatcher ***/ void MultiSequenceDispatcherOld::send_with_callback(NetQueryPtr query, ActorShared callback, - uint64 sequence_id) { - CHECK(sequence_id != 0); + td::Span chains) { + CHECK(all_of(chains, [](auto chain_id) {return chain_id != 0;})); + CHECK(!chains.empty()); + auto sequence_id = chains[0]; + auto it_ok = dispatchers_.emplace(sequence_id, Data{0, ActorOwn()}); auto &data = it_ok.first->second; if (it_ok.second) { @@ -273,14 +280,14 @@ void MultiSequenceDispatcherOld::ready_to_close() { class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew { public: - void send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id) final { - LOG(ERROR) << "send " << query; + void send_with_callback(NetQueryPtr query, ActorShared callback, td::Span chains) final { + CHECK(all_of(chains, [](auto chain_id) {return chain_id != 0;})); Node node; node.net_query = std::move(query); node.net_query->debug("Waiting at SequenceDispatcher"); node.net_query_ref = node.net_query.get_weak(); node.callback = std::move(callback); - scheduler_.create_task({ChainId{sequence_id}}, std::move(node)); + scheduler_.create_task(chains, std::move(node)); loop(); } @@ -335,11 +342,9 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew { while (true) { auto o_task = scheduler_.start_next_task(); if (!o_task) { - LOG(ERROR) << " no more tasks " << scheduler_; break; } auto task = o_task.unwrap(); - LOG(ERROR) << " next task = " << task.task_id; auto &node = *scheduler_.get_task_extra(task.task_id); CHECK(!node.net_query.empty()); @@ -348,15 +353,10 @@ class MultiSequenceDispatcherNewImpl final : public MultiSequenceDispatcherNew { for (auto parent_id : task.parents) { auto &parent_node = *scheduler_.get_task_extra(parent_id); parents.push_back(parent_node.net_query_ref); + CHECK(!parent_node.net_query_ref.empty()); } - if (parents.empty()) { - query->set_invoke_after({}); - } else if (parents.size() == 1) { - query->set_invoke_after(parents[0]); - } else if (parents.size() > 1){ - LOG(FATAL) << "TODO: support invokeAfterMsgs"; - } + query->set_invoke_after(std::move(parents)); query->last_timeout_ = 0; // TODO: flood VLOG(net_query) << "Send " << query; query->debug("send to Td::send_with_callback"); diff --git a/td/telegram/SequenceDispatcher.h b/td/telegram/SequenceDispatcher.h index 4b8e9d996..5feedaa6b 100644 --- a/td/telegram/SequenceDispatcher.h +++ b/td/telegram/SequenceDispatcher.h @@ -75,7 +75,7 @@ class SequenceDispatcher final : public NetQueryCallback { class MultiSequenceDispatcherOld final : public SequenceDispatcher::Parent { public: - void send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id); + void send_with_callback(NetQueryPtr query, ActorShared callback, Span chains); static ActorOwn create(td::Slice name) { return create_actor(name); } @@ -90,12 +90,14 @@ class MultiSequenceDispatcherOld final : public SequenceDispatcher::Parent { void ready_to_close() final; }; +using ChainId = uint64; +using ChainIds = std::vector; class MultiSequenceDispatcherNew : public NetQueryCallback { public: - virtual void send_with_callback(NetQueryPtr query, ActorShared callback, uint64 sequence_id) = 0; + virtual void send_with_callback(NetQueryPtr query, ActorShared callback, Span chains) = 0; static ActorOwn create(Slice name); }; -using MultiSequenceDispatcher = MultiSequenceDispatcherOld; +using MultiSequenceDispatcher = MultiSequenceDispatcherNew; } // namespace td diff --git a/td/telegram/net/NetQuery.h b/td/telegram/net/NetQuery.h index 90c66190a..d04469e3a 100644 --- a/td/telegram/net/NetQuery.h +++ b/td/telegram/net/NetQuery.h @@ -196,11 +196,11 @@ class NetQuery final : public TsListNode { message_id_ = message_id; } - NetQueryRef invoke_after() const { + Span invoke_after() const { return invoke_after_; } - void set_invoke_after(NetQueryRef ref) { - invoke_after_ = ref; + void set_invoke_after(std::vector refs) { + invoke_after_ = std::move(refs); } void set_session_rand(uint32 session_rand) { session_rand_ = session_rand; @@ -289,7 +289,7 @@ class NetQuery final : public TsListNode { BufferSlice answer_; int32 tl_constructor_ = 0; - NetQueryRef invoke_after_; + std::vector invoke_after_; uint32 session_rand_ = 0; bool may_be_lost_ = false; diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index 5daba7798..38c106aeb 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -992,14 +992,17 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer return return_query(std::move(net_query)); } - uint64 invoke_after_id = 0; - NetQueryRef invoke_after = net_query->invoke_after(); - if (!invoke_after.empty()) { - invoke_after_id = invoke_after->message_id(); - if (invoke_after->session_id() != auth_data_.get_session_id() || invoke_after_id == 0) { + Span invoke_after = net_query->invoke_after(); + std::vector invoke_after_ids; + for (auto &ref : invoke_after) { + auto invoke_after_id = ref->message_id(); + if (ref->session_id() != auth_data_.get_session_id() || invoke_after_id == 0) { net_query->set_error_resend_invoke_after(); return return_query(std::move(net_query)); } + invoke_after_ids.push_back(invoke_after_id); + } + if (!invoke_after.empty()) { if (!unknown_queries_.empty()) { pending_invoke_after_queries_.push_back(std::move(net_query)); return; @@ -1010,7 +1013,7 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer if (!immediately_fail_query) { auto r_message_id = info->connection_->send_query(net_query->query().clone(), net_query->gzip_flag() == NetQuery::GzipFlag::On, - message_id, invoke_after_id, static_cast(net_query->quick_ack_promise_)); + message_id, invoke_after_ids, static_cast(net_query->quick_ack_promise_)); net_query->on_net_write(net_query->query().size()); @@ -1024,7 +1027,7 @@ void Session::connection_send_query(ConnectionInfo *info, NetQueryPtr &&net_quer } } VLOG(net_query) << "Send query to connection " << net_query << " [msg_id:" << format::as_hex(message_id) << "]" - << tag("invoke_after", format::as_hex(invoke_after_id)); + << tag("invoke_after", td::transform(invoke_after_ids, [](auto id){return format::as_hex(id);})); net_query->set_message_id(message_id); net_query->cancel_slot_.clear_event(); LOG_CHECK(sent_queries_.find(message_id) == sent_queries_.end()) << message_id;