diff --git a/td/mtproto/RawConnection.cpp b/td/mtproto/RawConnection.cpp index c58a0e9f..7b5bd844 100644 --- a/td/mtproto/RawConnection.cpp +++ b/td/mtproto/RawConnection.cpp @@ -76,51 +76,69 @@ Status RawConnection::flush_read(const AuthKey &auth_key, Callback &callback) { } if (quick_ack != 0) { - auto it = quick_ack_to_token_.find(quick_ack); - if (it == quick_ack_to_token_.end()) { - LOG(WARNING) << Status::Error(PSLICE() << "Unknown " << tag("quick_ack", quick_ack)); - continue; - // TODO: return Status::Error(PSLICE() << "Unknown " << tag("quick_ack", quick_ack)); - } - auto token = it->second; - quick_ack_to_token_.erase(it); - callback.on_quick_ack(token); + on_quick_ack(quick_ack, callback); continue; } - MutableSlice data = packet.as_slice(); PacketInfo info; info.version = 2; - int32 error_code = 0; - TRY_STATUS(mtproto::Transport::read(data, auth_key, &info, &data, &error_code)); - - if (error_code) { - if (error_code == -429) { - if (stats_callback_) { - stats_callback_->on_mtproto_error(); + TRY_RESULT(read_result, mtproto::Transport::read(packet.as_slice(), auth_key, &info)); + switch (read_result.type()) { + case mtproto::Transport::ReadResult::Quickack: { + TRY_STATUS(on_quick_ack(read_result.quick_ack(), callback)); + break; + } + case mtproto::Transport::ReadResult::Error: { + TRY_STATUS(on_read_mtproto_error(read_result.error())); + break; + } + case mtproto::Transport::ReadResult::Packet: { + // If a packet was successfully decrypted, then it is ok to assume that the connection is alive + if (!auth_key.empty()) { + if (stats_callback_) { + stats_callback_->on_pong(); + } } - return Status::Error(500, PSLICE() << "Mtproto error: " << error_code); - } - if (error_code == -404) { - return Status::Error(-404, PSLICE() << "Mtproto error: " << error_code); - } - return Status::Error(PSLICE() << "Mtproto error: " << error_code); - } - // If a packet was successfully decrypted, then it is ok to assume that the connection is alive - if (!auth_key.empty()) { - if (stats_callback_) { - stats_callback_->on_pong(); + TRY_STATUS(callback.on_raw_packet(info, packet.from_slice(read_result.packet()))); + break; } + case mtproto::Transport::ReadResult::Nop: + break; } - - TRY_STATUS(callback.on_raw_packet(info, packet.from_slice(data))); } + TRY_STATUS(std::move(r)); return Status::OK(); } +Status RawConnection::on_read_mtproto_error(int32 error_code) { + if (error_code == -429) { + if (stats_callback_) { + stats_callback_->on_mtproto_error(); + } + return Status::Error(500, PSLICE() << "Mtproto error: " << error_code); + } + if (error_code == -404) { + return Status::Error(-404, PSLICE() << "Mtproto error: " << error_code); + } + return Status::Error(PSLICE() << "Mtproto error: " << error_code); +} + +Status RawConnection::on_quick_ack(uint32 quick_ack, Callback &callback) { + auto it = quick_ack_to_token_.find(quick_ack); + if (it == quick_ack_to_token_.end()) { + LOG(WARNING) << Status::Error(PSLICE() << "Unknown " << tag("quick_ack", quick_ack)); + return Status::OK(); + // TODO: return Status::Error(PSLICE() << "Unknown " << tag("quick_ack", quick_ack)); + } + auto token = it->second; + quick_ack_to_token_.erase(it); + callback.on_quick_ack(token); + return Status::OK(); +} + Status RawConnection::flush_write() { TRY_RESULT(size, socket_fd_.flush_write()); if (size > 0 && stats_callback_) { diff --git a/td/mtproto/RawConnection.h b/td/mtproto/RawConnection.h index 3c989c41..3660b1df 100644 --- a/td/mtproto/RawConnection.h +++ b/td/mtproto/RawConnection.h @@ -122,6 +122,9 @@ class RawConnection { Status flush_read(const AuthKey &auth_key, Callback &callback); Status flush_write(); + Status on_quick_ack(uint32 quick_ack, Callback &callback); + Status on_read_mtproto_error(int32 error_code); + Status do_flush(const AuthKey &auth_key, Callback &callback) TD_WARN_UNUSED_RESULT { if (has_error_) { return Status::Error("Connection has already failed"); diff --git a/td/mtproto/Transport.cpp b/td/mtproto/Transport.cpp index 738938c2..c68cb899 100644 --- a/td/mtproto/Transport.cpp +++ b/td/mtproto/Transport.cpp @@ -285,29 +285,41 @@ Result Transport::read_auth_key_id(Slice message) { return as(message.begin()); } -Status Transport::read(MutableSlice message, const AuthKey &auth_key, PacketInfo *info, MutableSlice *data, - int32 *error_code) { - if (message.size() < 8) { - if (message.size() >= 4) { - *error_code = as(message.begin()); - return Status::OK(); +Result Transport::read(MutableSlice message, const AuthKey &auth_key, PacketInfo *info) { + if (message.size() < 16) { + if (message.size() < 4) { + return Status::Error(PSLICE() << "Invalid mtproto message: smaller than 4 bytes [size=" << message.size() << "]"); } - return Status::Error(PSLICE() << "Invalid mtproto message: smaller than 8 bytes [size=" << message.size() << "]"); + + auto code = as(message.begin()); + if (code == 0) { + return ReadResult::make_nop(); + } else if (code == -1) { + if (message.size() >= 8) { + return ReadResult::make_quick_ack(as(message.begin() + 4)); + } + } else { + return ReadResult::make_error(code); + } + return Status::Error(PSLICE() << "Invalid small mtproto message"); } + info->auth_key_id = as(message.begin()); info->no_crypto_flag = info->auth_key_id == 0; + MutableSlice data; if (info->type == PacketInfo::EndToEnd) { - return read_e2e_crypto(message, auth_key, info, data); + TRY_STATUS(read_e2e_crypto(message, auth_key, info, &data)); } if (info->no_crypto_flag) { - return read_no_crypto(message, info, data); + TRY_STATUS(read_no_crypto(message, info, &data)); } else { if (auth_key.empty()) { return Status::Error("Failed to decrypt mtproto message: auth key is empty"); } - return read_crypto(message, auth_key, info, data); + TRY_STATUS(read_crypto(message, auth_key, info, &data)); } -} + return ReadResult::make_packet(data); +} // namespace mtproto size_t Transport::write(const Storer &storer, const AuthKey &auth_key, PacketInfo *info, MutableSlice dest) { if (info->type == PacketInfo::EndToEnd) { diff --git a/td/mtproto/Transport.h b/td/mtproto/Transport.h index 9184d86e..5591cfcd 100644 --- a/td/mtproto/Transport.h +++ b/td/mtproto/Transport.h @@ -137,6 +137,57 @@ struct PacketInfo { class Transport { public: + class ReadResult { + public: + enum Type { Packet, Nop, Error, Quickack }; + + static ReadResult make_nop() { + return {}; + } + static ReadResult make_error(int32 error_code) { + ReadResult res; + res.type_ = Error; + res.error_code_ = error_code; + return res; + } + static ReadResult make_packet(MutableSlice packet) { + CHECK(!packet.empty()); + ReadResult res; + res.type_ = Packet; + res.packet_ = packet; + return res; + } + static ReadResult make_quick_ack(uint32 quick_ack) { + ReadResult res; + res.type_ = Quickack; + res.quick_ack_ = quick_ack; + return res; + } + + Type type() const { + return type_; + } + + MutableSlice packet() const { + CHECK(type_ == Packet); + return packet_; + } + uint32 quick_ack() const { + CHECK(type_ == Quickack); + return quick_ack_; + } + int32 error() const { + CHECK(type_ == Error); + return error_code_; + } + + private: + Type type_ = Nop; + MutableSlice packet_; + int32 error_code_; + uint32 quick_ack_; + }; + static Result read_auth_key_id(Slice message); // Reads mtproto packet from [message] and saves into [data]. @@ -145,8 +196,7 @@ class Transport { // Returns size of mtproto packet. // If dest.size() >= size, the packet is also written into [dest]. // If auth_key is nonempty, encryption will be used. - static Status read(MutableSlice message, const AuthKey &auth_key, PacketInfo *info, MutableSlice *data, - int32 *error_code) TD_WARN_UNUSED_RESULT; + static Result read(MutableSlice message, const AuthKey &auth_key, PacketInfo *info) TD_WARN_UNUSED_RESULT; static size_t write(const Storer &storer, const AuthKey &auth_key, PacketInfo *info, MutableSlice dest = MutableSlice()); diff --git a/td/telegram/SecretChatActor.cpp b/td/telegram/SecretChatActor.cpp index d6374b43..bb5ae3b8 100644 --- a/td/telegram/SecretChatActor.cpp +++ b/td/telegram/SecretChatActor.cpp @@ -805,7 +805,7 @@ Result> SecretChatActor::decrypt(BufferSl BufferSlice encrypted_message_copy; int32 mtproto_version = -1; - int32 error_code = 0; + Result r_read_result; for (size_t i = 0; i < versions.size(); i++) { bool is_last = i + 1 == versions.size(); encrypted_message_copy = encrypted_message.copy(); @@ -817,19 +817,31 @@ Result> SecretChatActor::decrypt(BufferSl mtproto_version = versions[i]; info.version = mtproto_version; info.is_creator = auth_state_.x == 0; - auto status = mtproto::Transport::read(data, *auth_key, &info, &data, &error_code); - if (is_last) { - TRY_STATUS(std::move(status)); - } else if (status.is_error()) { - LOG(WARNING) << tag("mtproto", mtproto_version) << " decryption failed " << status; + r_read_result = mtproto::Transport::read(data, *auth_key, &info); + if (!is_last && r_read_result.is_error()) { + LOG(WARNING) << tag("mtproto", mtproto_version) << " decryption failed " << r_read_result.error(); continue; } break; } - - if (error_code) { - return Status::Error(PSLICE() << "Got mtproto error code: " << error_code); + TRY_RESULT(read_result, std::move(r_read_result)); + switch (read_result.type()) { + case mtproto::Transport::ReadResult::Quickack: { + return Status::Error("Got quickack instead of a message"); + } + case mtproto::Transport::ReadResult::Error: { + return Status::Error(PSLICE() << "Got mtproto error code instead of a message: " << read_result.error()); + } + case mtproto::Transport::ReadResult::Nop: { + return Status::Error("Got nop instead of a message"); + break; + } + case mtproto::Transport::ReadResult::Packet: { + data = read_result.packet(); + break; + } } + auto len = as(data.begin()); data = data.substr(4, len); if (!is_aligned_pointer<4>(data.data())) { diff --git a/tddb/td/db/binlog/BinlogEvent.cpp b/tddb/td/db/binlog/BinlogEvent.cpp index e4584e92..cf27bd12 100644 --- a/tddb/td/db/binlog/BinlogEvent.cpp +++ b/tddb/td/db/binlog/BinlogEvent.cpp @@ -14,7 +14,7 @@ int32 VERBOSITY_NAME(binlog) = VERBOSITY_NAME(DEBUG) + 8; Status BinlogEvent::init(BufferSlice &&raw_event, bool check_crc) { TlParser parser(raw_event.as_slice()); size_ = parser.fetch_int(); - CHECK(size_ == raw_event.size()); + CHECK(size_ == raw_event.size()) << size_ << " " << raw_event.size(); id_ = parser.fetch_long(); type_ = parser.fetch_int(); flags_ = parser.fetch_int();