Check that storer unsafe doesn't overflows.

GitOrigin-RevId: ffbdcbbba7d26688b59cda00318d02fc06e402dc
This commit is contained in:
levlam 2018-07-06 23:33:11 +03:00
parent 4c404f3a68
commit 90f0f006b4
10 changed files with 33 additions and 14 deletions

View File

@ -50,7 +50,8 @@ Result<size_t> AuthKeyHandshake::fill_data_with_hash(uint8 *data_with_hash, cons
return Status::Error("Too big data"); return Status::Error("Too big data");
} }
as<int32>(data_ptr) = data.get_id(); as<int32>(data_ptr) = data.get_id();
tl_store_unsafe(data, data_ptr + 4); auto real_size = tl_store_unsafe(data, data_ptr + 4);
CHECK(real_size == data_size);
sha1(Slice(data_ptr, data_size + 4), data_with_hash); sha1(Slice(data_ptr, data_size + 4), data_with_hash);
return data_size + 20 + 4; return data_size + 20 + 4;
} }
@ -193,7 +194,8 @@ Status AuthKeyHandshake::on_server_dh_params(Slice message, Callback *connection
string encrypted_data_str(encrypted_data_size_with_pad, 0); string encrypted_data_str(encrypted_data_size_with_pad, 0);
MutableSlice encrypted_data = encrypted_data_str; MutableSlice encrypted_data = encrypted_data_str;
as<int32>(encrypted_data.begin() + 20) = data.get_id(); as<int32>(encrypted_data.begin() + 20) = data.get_id();
tl_store_unsafe(data, encrypted_data.ubegin() + 20 + 4); auto real_size = tl_store_unsafe(data, encrypted_data.ubegin() + 20 + 4);
CHECK(real_size + 4 == data_size);
sha1(Slice(encrypted_data.ubegin() + 20, data_size), encrypted_data.ubegin()); sha1(Slice(encrypted_data.ubegin() + 20, data_size), encrypted_data.ubegin());
Random::secure_bytes(encrypted_data.ubegin() + encrypted_data_size, Random::secure_bytes(encrypted_data.ubegin() + encrypted_data_size,
encrypted_data_size_with_pad - encrypted_data_size); encrypted_data_size_with_pad - encrypted_data_size);
@ -230,8 +232,10 @@ Status AuthKeyHandshake::on_dh_gen_response(Slice message, Callback *connection)
return Status::OK(); return Status::OK();
} }
void AuthKeyHandshake::send(Callback *connection, const Storer &storer) { void AuthKeyHandshake::send(Callback *connection, const Storer &storer) {
auto writer = BufferWriter{storer.size(), 0, 0}; auto size = storer.size();
storer.store(writer.as_slice().ubegin()); auto writer = BufferWriter{size, 0, 0};
auto real_size = storer.store(writer.as_slice().ubegin());
CHECK(real_size == size);
last_query_ = writer.as_buffer_slice(); last_query_ = writer.as_buffer_slice();
return do_send(connection, create_storer(last_query_.as_slice())); return do_send(connection, create_storer(last_query_.as_slice()));
} }

View File

@ -743,8 +743,10 @@ std::pair<uint64, BufferSlice> SessionConnection::encrypted_bind(int64 perm_key,
mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->session_id_, expire_at); mtproto_api::bind_auth_key_inner object(nonce, temp_key, perm_key, auth_data_->session_id_, expire_at);
auto object_storer = create_storer(object); auto object_storer = create_storer(object);
auto object_packet = BufferWriter{object_storer.size(), 0, 0}; auto size = object_storer.size();
object_storer.store(object_packet.as_slice().ubegin()); auto object_packet = BufferWriter{size, 0, 0};
auto real_size = object_storer.store(object_packet.as_slice().ubegin());
CHECK(size == real_size);
Query query{auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, 0, false}; Query query{auth_data_->next_message_id(Time::now_cached()), 0, object_packet.as_buffer_slice(), false, 0, false};
PacketStorer<QueryImpl> query_storer(query, Slice()); PacketStorer<QueryImpl> query_storer(query, Slice());

View File

@ -194,14 +194,16 @@ size_t Transport::write_no_crypto(const Storer &storer, PacketInfo *info, Mutabl
} }
auto &header = as<NoCryptoHeader>(dest.begin()); auto &header = as<NoCryptoHeader>(dest.begin());
header.auth_key_id = 0; header.auth_key_id = 0;
storer.store(header.data); auto real_size = storer.store(header.data);
CHECK(real_size == storer.size());
return size; return size;
} }
template <class HeaderT> template <class HeaderT>
void Transport::write_crypto_impl(int X, const Storer &storer, const AuthKey &auth_key, PacketInfo *info, void Transport::write_crypto_impl(int X, const Storer &storer, const AuthKey &auth_key, PacketInfo *info,
HeaderT *header, size_t data_size) { HeaderT *header, size_t data_size) {
storer.store(header->data); auto real_data_size = storer.store(header->data);
CHECK(real_data_size == data_size);
VLOG(raw_mtproto) << "SEND" << format::as_hex_dump<4>(Slice(header->data, data_size)); VLOG(raw_mtproto) << "SEND" << format::as_hex_dump<4>(Slice(header->data, data_size));
// LOG(ERROR) << "SEND" << format::as_hex_dump<4>(Slice(header->data, data_size)) << info->version; // LOG(ERROR) << "SEND" << format::as_hex_dump<4>(Slice(header->data, data_size)) << info->version;

View File

@ -1300,6 +1300,7 @@ string as_key(const T &object) {
TlStorerUnsafe storer(key.ubegin()); TlStorerUnsafe storer(key.ubegin());
storer.store_int(T::KEY_MAGIC); storer.store_int(T::KEY_MAGIC);
object.as_key().store(storer); object.as_key().store(storer);
CHECK(storer.get_buf() == key.uend());
return key.str(); return key.str();
} }

View File

@ -14,7 +14,8 @@ NetQueryCreator::Ptr NetQueryCreator::create(uint64 id, const Storer &storer, Dc
NetQuery::AuthFlag auth_flag, NetQuery::GzipFlag gzip_flag, NetQuery::AuthFlag auth_flag, NetQuery::GzipFlag gzip_flag,
double total_timeout_limit) { double total_timeout_limit) {
BufferSlice slice(storer.size()); BufferSlice slice(storer.size());
storer.store(slice.as_slice().ubegin()); auto real_size = storer.store(slice.as_slice().ubegin());
CHECK(real_size == slice.size());
// TODO: magic constant // TODO: magic constant
if (slice.size() < (1 << 8)) { if (slice.size() < (1 << 8)) {

View File

@ -461,7 +461,8 @@ void Session::on_session_created(uint64 unique_id, uint64 first_id) {
telegram_api::updatesTooLong too_long_; telegram_api::updatesTooLong too_long_;
auto storer = create_storer(too_long_); auto storer = create_storer(too_long_);
BufferSlice packet(storer.size()); BufferSlice packet(storer.size());
storer.store(packet.as_slice().ubegin()); auto real_size = storer.store(packet.as_slice().ubegin());
CHECK(real_size == packet.size());
return_query(G()->net_query_creator().create_result(0, std::move(packet))); return_query(G()->net_query_creator().create_result(0, std::move(packet)));
} }

View File

@ -19,7 +19,7 @@ class Storer {
Storer &operator=(Storer &&) = default; Storer &operator=(Storer &&) = default;
virtual ~Storer() = default; virtual ~Storer() = default;
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual size_t store(uint8 *ptr) const = 0; virtual size_t store(uint8 *ptr) const TD_WARN_UNUSED_RESULT = 0;
}; };
} // namespace td } // namespace td

View File

@ -189,11 +189,13 @@ string serialize(const T &object) {
MutableSlice data = ptr.as_slice(); MutableSlice data = ptr.as_slice();
TlStorerUnsafe storer(data.ubegin()); TlStorerUnsafe storer(data.ubegin());
store(object, storer); store(object, storer);
CHECK(storer.get_buf() == data.uend());
key.assign(data.begin(), data.size()); key.assign(data.begin(), data.size());
} else { } else {
MutableSlice data = key; MutableSlice data = key;
TlStorerUnsafe storer(data.ubegin()); TlStorerUnsafe storer(data.ubegin());
store(object, storer); store(object, storer);
CHECK(storer.get_buf() == data.uend());
} }
return key; return key;
} }

View File

@ -272,6 +272,9 @@ size_t tl_calc_length(const T &data) {
return storer_calc_length.get_length(); return storer_calc_length.get_length();
} }
template <class T>
size_t tl_store_unsafe(const T &data, unsigned char *dst) TD_WARN_UNUSED_RESULT;
template <class T> template <class T>
size_t tl_store_unsafe(const T &data, unsigned char *dst) { size_t tl_store_unsafe(const T &data, unsigned char *dst) {
TlStorerUnsafe storer_unsafe(dst); TlStorerUnsafe storer_unsafe(dst);

View File

@ -867,7 +867,8 @@ class Master : public Actor {
config.version_ = 12; config.version_ = 12;
auto storer = TLObjectStorer<my_api::messages_dhConfig>(config); auto storer = TLObjectStorer<my_api::messages_dhConfig>(config);
BufferSlice answer(storer.size()); BufferSlice answer(storer.size());
storer.store(answer.as_slice().ubegin()); auto real_size = storer.store(answer.as_slice().ubegin());
CHECK(real_size == answer.size());
net_query->set_ok(std::move(answer)); net_query->set_ok(std::move(answer));
send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query)); send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query));
} }
@ -891,7 +892,8 @@ class Master : public Actor {
my_api::encryptedChat encrypted_chat(123, 321, 0, 1, 2, BufferSlice(), request_encryption.key_fingerprint_); my_api::encryptedChat encrypted_chat(123, 321, 0, 1, 2, BufferSlice(), request_encryption.key_fingerprint_);
auto storer = TLObjectStorer<my_api::encryptedChat>(encrypted_chat); auto storer = TLObjectStorer<my_api::encryptedChat>(encrypted_chat);
BufferSlice answer(storer.size()); BufferSlice answer(storer.size());
storer.store(answer.as_slice().ubegin()); auto real_size = storer.store(answer.as_slice().ubegin());
CHECK(real_size == answer.size());
net_query->set_ok(std::move(answer)); net_query->set_ok(std::move(answer));
send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query)); send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query));
send_closure(alice_, &SecretChatProxy::start_test); send_closure(alice_, &SecretChatProxy::start_test);
@ -935,7 +937,8 @@ class Master : public Actor {
sent_message.date_ = 0; sent_message.date_ = 0;
auto storer = TLObjectStorer<my_api::messages_sentEncryptedMessage>(sent_message); auto storer = TLObjectStorer<my_api::messages_sentEncryptedMessage>(sent_message);
BufferSlice answer(storer.size()); BufferSlice answer(storer.size());
storer.store(answer.as_slice().ubegin()); auto real_size = storer.store(answer.as_slice().ubegin());
CHECK(real_size == answer.size());
net_query->set_ok(std::move(answer)); net_query->set_ok(std::move(answer));
send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query)); send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query));