diff --git a/td/mtproto/Transport.cpp b/td/mtproto/Transport.cpp index 7865ee16..3d8027bb 100644 --- a/td/mtproto/Transport.cpp +++ b/td/mtproto/Transport.cpp @@ -128,18 +128,16 @@ Status Transport::read_crypto_impl(int X, MutableSlice message, const AuthKey &a auto *prefix = reinterpret_cast(header->data); *prefix_ptr = prefix; size_t data_size = prefix->message_data_length + sizeof(PrefixT); - bool is_length_ok = prefix->message_data_length % 4 == 0; + bool is_length_ok = true; UInt128 real_message_key; if (info->version == 1) { + is_length_ok &= !info->check_mod4 || prefix->message_data_length % 4 == 0; auto expected_size = calc_crypto_size(data_size); is_length_ok = (is_length_ok & (expected_size == message.size())) != 0; auto check_size = data_size * is_length_ok + tail_size * (1 - is_length_ok); std::tie(info->message_ack, real_message_key) = calc_message_ack_and_key(*header, check_size); } else { - size_t pad_size = tail_size - data_size; - is_length_ok = (is_length_ok & (tail_size - sizeof(PrefixT) >= prefix->message_data_length) & (12 <= pad_size) & - (pad_size <= 1024)) != 0; std::tie(info->message_ack, real_message_key) = calc_message_key2(auth_key, X, to_decrypt); } @@ -153,9 +151,29 @@ Status Transport::read_crypto_impl(int X, MutableSlice message, const AuthKey &a << format::as_hex_dump(header->message_key) << "] [expected=" << format::as_hex_dump(real_message_key) << "]"); } - if (!is_length_ok) { - return Status::Error(PSLICE() << "Invalid mtproto message: invalid length " << tag("total_size", message.size()) - << tag("message_data_length", prefix->message_data_length)); + + if (info->version == 2) { + if (info->check_mod4 && prefix->message_data_length % 4 != 0) { + return Status::Error(PSLICE() << "Invalid mtproto message: invalid length (not divisible by four)" + << tag("total_size", message.size()) + << tag("message_data_length", prefix->message_data_length)); + } + if (tail_size - sizeof(PrefixT) < prefix->message_data_length) { + return Status::Error(PSLICE() << "Invalid mtproto message: invalid length (message_data_length is too big)" + << tag("total_size", message.size()) + << tag("message_data_length", prefix->message_data_length)); + } + size_t pad_size = tail_size - data_size; + if (pad_size < 12 || pad_size > 1024) { + return Status::Error(PSLICE() << "Invalid mtproto message: invalid length (invalid padding length)" + << tag("padding_size", pad_size) << tag("total_size", message.size()) + << tag("message_data_length", prefix->message_data_length)); + } + } else { + if (!is_length_ok) { + return Status::Error(PSLICE() << "Invalid mtproto message: invalid length " << tag("total_size", message.size()) + << tag("message_data_length", prefix->message_data_length)); + } } *data = MutableSlice(header->data, data_size); diff --git a/td/mtproto/Transport.h b/td/mtproto/Transport.h index bee0fdda..8b9ec8c3 100644 --- a/td/mtproto/Transport.h +++ b/td/mtproto/Transport.h @@ -132,9 +132,10 @@ struct PacketInfo { uint64 message_id; int32 seq_no; - int32 version = 1; + int32 version{1}; bool no_crypto_flag; - bool is_creator = false; + bool is_creator{false}; + bool check_mod4{true}; }; class Transport { diff --git a/td/mtproto/crypto.cpp b/td/mtproto/crypto.cpp index 1495b3e9..16788584 100644 --- a/td/mtproto/crypto.cpp +++ b/td/mtproto/crypto.cpp @@ -334,7 +334,7 @@ std::pair DhHandshake::gen_key() { return std::pair(key_id, std::move(key)); } -int64 DhHandshake::calc_key_id(const string &auth_key) { +int64 DhHandshake::calc_key_id(Slice auth_key) { UInt<160> auth_key_sha1; sha1(auth_key, auth_key_sha1.raw); return as(auth_key_sha1.raw + 12); diff --git a/td/mtproto/crypto.h b/td/mtproto/crypto.h index 285a1e4a..fd0960a3 100644 --- a/td/mtproto/crypto.h +++ b/td/mtproto/crypto.h @@ -83,7 +83,7 @@ class DhHandshake { std::pair gen_key(); - static int64 calc_key_id(const string &auth_key); + static int64 calc_key_id(Slice auth_key); enum Flags { HasConfig = 1, HasGA = 2 }; diff --git a/td/telegram/DeviceTokenManager.cpp b/td/telegram/DeviceTokenManager.cpp index 26fbb52a..c9c7f971 100644 --- a/td/telegram/DeviceTokenManager.cpp +++ b/td/telegram/DeviceTokenManager.cpp @@ -254,9 +254,7 @@ void DeviceTokenManager::register_device(tl_object_ptr devi info.encryption_key.resize(ENCRYPTION_KEY_LENGTH); while (true) { Random::secure_bytes(info.encryption_key); - uint8 sha1_buf[20]; - sha1(info.encryption_key, sha1_buf); - info.encryption_key_id = as(sha1_buf + 12); + info.encryption_key_id = DhHandshake::calc_key_id(info.encryption_key); if (info.encryption_key_id <= -MIN_ENCRYPTION_KEY_ID || info.encryption_key_id >= MIN_ENCRYPTION_KEY_ID) { // ensure that encryption key ID never collide with anything break; diff --git a/td/telegram/NotificationManager.cpp b/td/telegram/NotificationManager.cpp index d7e22148..7297bc51 100644 --- a/td/telegram/NotificationManager.cpp +++ b/td/telegram/NotificationManager.cpp @@ -16,6 +16,9 @@ #include "td/telegram/Td.h" #include "td/telegram/TdDb.h" +#include "td/mtproto/Transport.h" +#include "td/mtproto/AuthKey.h" + #include "td/utils/as.h" #include "td/utils/base64.h" #include "td/utils/format.h" @@ -2134,6 +2137,53 @@ Result NotificationManager::get_push_receiver_id(string payload) { return Status::Error(200, "Unsupported push notification"); } +Result NotificationManager::decrypt_push(int64 encryption_key_id, string encryption_key, string push) { + auto r_json_value = json_decode(push); + if (r_json_value.is_error()) { + return Status::Error(400, "Failed to parse payload as JSON object"); + } + + auto json_value = r_json_value.move_as_ok(); + if (json_value.type() != JsonValue::Type::Object) { + return Status::Error(400, "Expected JSON object"); + } + + for (auto &field_value : json_value.get_object()) { + if (field_value.first == "p") { + auto encrypted_payload = std::move(field_value.second); + if (encrypted_payload.type() != JsonValue::Type::String) { + return Status::Error(400, "Expected encrypted payload as a String"); + } + Slice data = encrypted_payload.get_string(); + if (data.size() < 12) { + return Status::Error(400, "Encrypted payload is too small"); + } + auto r_decoded = base64url_decode(data); + if (r_decoded.is_error()) { + return Status::Error(400, "Failed to base64url-decode payload"); + } + return decrypt_push_payload(encryption_key_id, std::move(encryption_key), r_decoded.move_as_ok()); + } + } + return Status::Error(400, "No 'p'(payload) field found in push"); +} + +Result NotificationManager::decrypt_push_payload(int64 encryption_key_id, string encryption_key, + string payload) { + mtproto::AuthKey auth_key(encryption_key_id, std::move(encryption_key)); + mtproto::PacketInfo packet_info; + packet_info.version = 2; + packet_info.type = mtproto::PacketInfo::EndToEnd; + packet_info.is_creator = true; + packet_info.check_mod4 = false; + + TRY_RESULT(result, mtproto::Transport::read(payload, auth_key, &packet_info)); + if (result.type() != mtproto::Transport::ReadResult::Packet) { + return Status::Error(400, "Wrong packet type"); + } + return result.packet().str(); +} + void NotificationManager::before_get_difference() { if (is_disabled()) { return; diff --git a/td/telegram/NotificationManager.h b/td/telegram/NotificationManager.h index bf72e617..34f2ba45 100644 --- a/td/telegram/NotificationManager.h +++ b/td/telegram/NotificationManager.h @@ -91,8 +91,9 @@ class NotificationManager : public Actor { void process_push_notification(string payload, Promise &&promise); - static Result get_push_receiver_id(string payload); - + static Result get_push_receiver_id(string push); + static Result decrypt_push(int64 encryption_key_id, string encryption_key, string push); + static Result decrypt_push_payload(int64 encryption_key_id, string encryption_key, string payload); void before_get_difference(); void after_get_difference(); diff --git a/tdutils/td/utils/as.h b/tdutils/td/utils/as.h index 815e777d..dd0d57ec 100644 --- a/tdutils/td/utils/as.h +++ b/tdutils/td/utils/as.h @@ -38,6 +38,9 @@ class As { std::memcpy(&res, ptr_, sizeof(T)); return res; } + bool operator==(const As &other) const { + return this->operator T() == other.operator T(); + } private: void *ptr_; diff --git a/test/mtproto.cpp b/test/mtproto.cpp index 6ed3397c..7d1e7452 100644 --- a/test/mtproto.cpp +++ b/test/mtproto.cpp @@ -15,6 +15,7 @@ #include "td/mtproto/HandshakeConnection.h" #include "td/mtproto/PingConnection.h" #include "td/mtproto/RawConnection.h" +#include "td/mtproto/Transport.h" #include "td/net/GetHostByNameActor.h" #include "td/net/Socks5.h" @@ -23,7 +24,10 @@ #include "td/telegram/ConfigManager.h" #include "td/telegram/net/DcId.h" #include "td/telegram/net/PublicRsaKeyShared.h" +#include "td/telegram/NotificationManager.h" +#include "td/utils/as.h" +#include "td/utils/base64.h" #include "td/utils/logging.h" #include "td/utils/port/IPAddress.h" #include "td/utils/port/SocketFd.h" @@ -426,3 +430,27 @@ TEST(Mtproto, socks5) { } sched.finish(); } + +TEST(Mtproto, notifications) { + string push = + "eyJwIjoiSkRnQ3NMRWxEaWhyVWRRN1pYM3J1WVU4TlRBMFhMb0N6UWRNdzJ1cWlqMkdRbVR1WXVvYXhUeFJHaG1QQm8yVElYZFBzX2N3b2RIb3lY" + "b2drVjM1dVl0UzdWeElNX1FNMDRKMG1mV3ZZWm4zbEtaVlJ0aFVBNGhYUWlaN0pfWDMyZDBLQUlEOWgzRnZwRjNXUFRHQWRaVkdFYzg3bnFPZ3hD" + "NUNMRkM2SU9fZmVqcEpaV2RDRlhBWWpwc1k2aktrbVNRdFZ1MzE5ZW04UFVieXZudFpfdTNud2hjQ0czMk96TGp4S1kyS1lzU21JZm1GMzRmTmw1" + "QUxaa2JvY2s2cE5rZEdrak9qYmRLckJyU0ZtWU8tQ0FsRE10dEplZFFnY1U5bVJQdU80b1d2NG5sb1VXS19zSlNTaXdIWEZyb1pWTnZTeFJ0Z1dN" + "ZyJ9"; + string key = + "uBa5yu01a-nJJeqsR3yeqMs6fJLYXjecYzFcvS6jIwS3nefBIr95LWrTm-IbRBNDLrkISz1Sv0KYpDzhU8WFRk1D0V_" + "qyO7XsbDPyrYxRBpGxofJUINSjb1uCxoSdoh1_F0UXEA2fWWKKVxL0DKUQssZfbVj3AbRglsWpH-jDK1oc6eBydRiS3i4j-" + "H0yJkEMoKRgaF9NaYI4u26oIQ-Ez46kTVU-R7e3acdofOJKm7HIKan_5ZMg82Dvec2M6vc_" + "I54Vs28iBx8IbBO1y5z9WSScgW3JCvFFKP2MXIu7Jow5-cpUx6jXdzwRUb9RDApwAFKi45zpv8eb3uPCDAmIQ"; + string decrypted_payload = + "fwAAAHsibG9jX2tleSI6Ik1FU1NBR0VfVEVYVCIsImxvY19hcmdzIjpbIkFyc2VueSBTbWlybm92IiwiYWJjZGVmZyJdLCJjdXN0b20iOnsibXNn" + "X2lkIjoiNTkwMDQ3IiwiZnJvbV9pZCI6IjYyODE0In0sImJhZGdlIjoiNDA5In0"; + push = base64url_decode(push).move_as_ok(); + key = base64url_decode(key).move_as_ok(); + decrypted_payload = base64url_decode(decrypted_payload).move_as_ok(); + + auto key_id = DhHandshake::calc_key_id(key); + ASSERT_EQ(key_id, NotificationManager::get_push_receiver_id(push).ok()); + ASSERT_EQ(decrypted_payload, NotificationManager::decrypt_push(key_id, key, push).ok()); +}