diff --git a/benchmark/bench_crypto.cpp b/benchmark/bench_crypto.cpp index 6536fa613..369506884 100644 --- a/benchmark/bench_crypto.cpp +++ b/benchmark/bench_crypto.cpp @@ -71,13 +71,42 @@ class AesBench : public td::Benchmark { state.init(td::as_slice(key), true); td::MutableSlice data_slice(data, DATA_SIZE); for (int i = 0; i <= n; i++) { - for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += 16) { - state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset); + size_t step = 16; + for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += step) { + state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset, (int)step); } } } }; +class AesCtrBench : public td::Benchmark { + public: + alignas(64) unsigned char data[DATA_SIZE]; + td::UInt256 key; + td::UInt128 iv; + + std::string get_description() const override { + return PSTRING() << "AES CTR OpenSSL [" << (DATA_SIZE >> 10) << "KB]"; + } + + void start_up() override { + for (int i = 0; i < DATA_SIZE; i++) { + data[i] = 123; + } + td::Random::secure_bytes(key.raw, sizeof(key)); + td::Random::secure_bytes(iv.raw, sizeof(iv)); + } + + void run(int n) override { + td::AesCtrState state; + state.init(as_slice(key), as_slice(iv)); + td::MutableSlice data_slice(data, DATA_SIZE); + for (int i = 0; i < n; i++) { + state.encrypt(data_slice, data_slice); + } + } +}; + class AESBench : public td::Benchmark { public: alignas(64) unsigned char data[DATA_SIZE]; @@ -227,6 +256,7 @@ class Crc64Bench : public td::Benchmark { int main() { td::init_openssl_threads(); + td::bench(AesCtrBench()); td::bench(AesBench()); td::bench(AESBench()); diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index 9a37cb4c7..465703b87 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -273,30 +273,28 @@ void AesState::init(Slice key, bool encrypt) { if (encrypt) { CHECK(1 == EVP_EncryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr)); - AES_set_encrypt_key(key.ubegin(), 256, &impl_->key); } else { CHECK(1 == EVP_DecryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr)); - AES_set_decrypt_key(key.ubegin(), 256, &impl_->key); } EVP_CIPHER_CTX_set_padding(impl_->ctx, 0); impl_->encrypt = encrypt; } -void AesState::encrypt(const uint8 *src, uint8 *dst) { +void AesState::encrypt(const uint8 *src, uint8 *dst, int size) { CHECK(impl_->encrypt); CHECK(impl_->ctx); + CHECK(size % 16 == 0); int len; - CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, 16)); - CHECK(len == 16); - //AES_encrypt(src, dst, &impl_->key); + CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, size)); + CHECK(len == size); } -void AesState::decrypt(const uint8 *src, uint8 *dst) { +void AesState::decrypt(const uint8 *src, uint8 *dst, int size) { CHECK(!impl_->encrypt); CHECK(impl_->ctx); + CHECK(size % 16 == 0); int len; - CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, 16)); - LOG_CHECK(len == 16) << len; - //AES_decrypt(src, dst, &impl_->key); + CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, size)); + CHECK(len == size); } static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) { @@ -363,35 +361,63 @@ class AesCtrState::Impl { CHECK(key.size() == 32); CHECK(iv.size() == 16); static_assert(AES_BLOCK_SIZE == 16, ""); - if (AES_set_encrypt_key(key.ubegin(), 256, &aes_key) < 0) { - LOG(FATAL) << "Failed to set encrypt key"; - } + aes_state.init(key, true); counter.as_mutable_slice().copy_from(iv); - current_pos = 0; + fill(); } void encrypt(Slice from, MutableSlice to) { - CHECK(to.size() >= from.size()); - for (size_t i = 0; i < from.size(); i++) { - if (current_pos == 0) { - AES_encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), &aes_key); - uint8 *ptr = counter.as_mutable_slice().ubegin(); - for (int j = 15; j >= 0; j--) { - if (++ptr[j] != 0) { - break; - } + auto *src = from.ubegin(); + auto *dst = to.ubegin(); + auto n = from.size(); + while (n != 0) { + if (current.empty()) { + if (N != 1) { + counter.as_mutable_slice().copy_from(counter.as_slice().substr((N - 1) * AES_BLOCK_SIZE)); } + inc(counter.as_mutable_slice().ubegin()); + fill(); } - to[i] = static_cast(from[i] ^ encrypted_counter[current_pos]); - current_pos = (current_pos + 1) & 15; + size_t min_n = td::min(n, current.size()); + auto curr = current.ubegin(); + for (size_t i = 0; i < min_n; i++) { + dst[i] = src[i] ^ curr[i]; + } + n -= min_n; + src += min_n; + dst += min_n; + current.remove_prefix(min_n); } } private: - AES_KEY aes_key; - SecureString counter{AES_BLOCK_SIZE}; - SecureString encrypted_counter{AES_BLOCK_SIZE}; - uint8 current_pos; + AesState aes_state; + + static constexpr size_t N = 32; + SecureString counter{AES_BLOCK_SIZE * N}; + SecureString encrypted_counter{AES_BLOCK_SIZE * N}; + td::Slice current; + + void inc(uint8 *ptr) { + for (int j = 15; j >= 0; j--) { + if (++ptr[j] != 0) { + break; + } + } + } + void fill() { + auto *src = counter.as_slice().ubegin(); + auto *dst = counter.as_mutable_slice().ubegin() + AES_BLOCK_SIZE; + for (size_t i = 0; i + 1 < N; i++) { + memcpy(dst, src, AES_BLOCK_SIZE); + inc(dst); + src += AES_BLOCK_SIZE; + dst += AES_BLOCK_SIZE; + } + + aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), (int)counter.size()); + current = encrypted_counter.as_slice(); + } }; AesCtrState::AesCtrState() = default; diff --git a/tdutils/td/utils/crypto.h b/tdutils/td/utils/crypto.h index 776b1c85a..fafe6d205 100644 --- a/tdutils/td/utils/crypto.h +++ b/tdutils/td/utils/crypto.h @@ -28,8 +28,8 @@ struct AesState { AesState &operator=(AesState &&from); ~AesState(); void init(Slice key, bool encrypt); - void encrypt(const uint8 *src, uint8 *dst); - void decrypt(const uint8 *src, uint8 *dst); + void encrypt(const uint8 *src, uint8 *dst, int size); + void decrypt(const uint8 *src, uint8 *dst, int size); private: class Impl; diff --git a/tdutils/test/crypto.cpp b/tdutils/test/crypto.cpp index f0b602970..ca15638a2 100644 --- a/tdutils/test/crypto.cpp +++ b/tdutils/test/crypto.cpp @@ -32,8 +32,8 @@ TEST(Crypto, Aes) { td::AesState decryptor; decryptor.init(as_slice(key), false); - encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin()); - decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin()); + encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin(), 16); + decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin(), 16); CHECK(decrypted == plaintext); CHECK(decrypted != encrypted);