diff --git a/td/mtproto/Handshake.cpp b/td/mtproto/Handshake.cpp index ae4ca5c1..37fc5d93 100644 --- a/td/mtproto/Handshake.cpp +++ b/td/mtproto/Handshake.cpp @@ -56,7 +56,7 @@ Result AuthKeyHandshake::fill_data_with_hash(uint8 *data_with_hash, cons } Status AuthKeyHandshake::on_res_pq(Slice message, Callback *connection, PublicRsaKeyInterface *public_rsa_key) { - TRY_RESULT(res_pq, fetch_result(message)); + TRY_RESULT(res_pq, fetch_result(message, false)); if (res_pq->nonce_ != nonce) { return Status::Error("Nonce mismatch"); } @@ -116,7 +116,7 @@ Status AuthKeyHandshake::on_res_pq(Slice message, Callback *connection, PublicRs } Status AuthKeyHandshake::on_server_dh_params(Slice message, Callback *connection, DhCallback *dh_callback) { - TRY_RESULT(server_dh_params, fetch_result(message)); + TRY_RESULT(server_dh_params, fetch_result(message, false)); switch (server_dh_params->get_id()) { case mtproto_api::server_DH_params_ok::ID: break; @@ -215,7 +215,7 @@ Status AuthKeyHandshake::on_server_dh_params(Slice message, Callback *connection } Status AuthKeyHandshake::on_dh_gen_response(Slice message, Callback *connection) { - TRY_RESULT(answer, fetch_result(message)); + TRY_RESULT(answer, fetch_result(message, false)); switch (answer->get_id()) { case mtproto_api::dh_gen_ok::ID: state_ = Finish; diff --git a/td/mtproto/NoCryptoStorer.h b/td/mtproto/NoCryptoStorer.h index 2dc92abf..7836d734 100644 --- a/td/mtproto/NoCryptoStorer.h +++ b/td/mtproto/NoCryptoStorer.h @@ -7,22 +7,33 @@ #pragma once #include "td/mtproto/PacketStorer.h" +#include "td/utils/Random.h" + namespace td { namespace mtproto { class NoCryptoImpl { public: - NoCryptoImpl(uint64 message_id, const Storer &data) : message_id(message_id), data(data) { + NoCryptoImpl(uint64 message_id, const Storer &data, bool need_pad = true) : message_id(message_id), data(data) { + if (need_pad) { + auto data_size = data.size(); + auto pad_size = (data_size + 15) / 16 * 16 - data_size; + pad_size += 16 * (static_cast(Random::secure_int32()) % 16); + pad_.resize(pad_size); + Random::secure_bytes(pad_); + } } template void do_store(T &storer) const { storer.store_binary(message_id); - storer.store_binary(static_cast(data.size())); + storer.store_binary(static_cast(data.size() + pad_.size())); storer.store_storer(data); + storer.store_slice(pad_); } private: uint64 message_id; const Storer &data; + std::string pad_; }; } // namespace mtproto } // namespace td diff --git a/td/mtproto/utils.h b/td/mtproto/utils.h index 9bbb706d..2e6b442b 100644 --- a/td/mtproto/utils.h +++ b/td/mtproto/utils.h @@ -31,11 +31,13 @@ struct Query { } // namespace mtproto template -Result fetch_result(Slice message) { +Result fetch_result(Slice message, bool check_end = true) { TlParser parser(message); auto result = T::fetch_result(parser); - parser.fetch_end(); + if (check_end) { + parser.fetch_end(); + } const char *error = parser.get_error(); if (error != nullptr) { LOG(ERROR) << "Can't parse: " << format::as_hex_dump<4>(message); @@ -46,11 +48,13 @@ Result fetch_result(Slice message) { } template -Result fetch_result(const BufferSlice &message) { +Result fetch_result(const BufferSlice &message, bool check_end = true) { TlBufferParser parser(&message); auto result = T::fetch_result(parser); - parser.fetch_end(); + if (check_end) { + parser.fetch_end(); + } const char *error = parser.get_error(); if (error != nullptr) { LOG(ERROR) << "Can't parse: " << format::as_hex_dump<4>(message.as_slice());