From 6a7dfc4f01c57eec3d7e1f83f3ec41ad51b0c241 Mon Sep 17 00:00:00 2001 From: levlam Date: Mon, 15 Jun 2020 23:20:44 +0300 Subject: [PATCH] Minor AES improvements. GitOrigin-RevId: 138384ad375735b9e889172cae919368c9976456 --- benchmark/bench_crypto.cpp | 54 +++++++++++++++++++++++++++++-------- tdutils/td/utils/crypto.cpp | 38 +++++++++++++------------- tdutils/td/utils/crypto.h | 5 ++++ tdutils/test/crypto.cpp | 6 ++--- 4 files changed, 69 insertions(+), 34 deletions(-) diff --git a/benchmark/bench_crypto.cpp b/benchmark/bench_crypto.cpp index 781372d0a..4d4edd0fd 100644 --- a/benchmark/bench_crypto.cpp +++ b/benchmark/bench_crypto.cpp @@ -79,14 +79,14 @@ class AesEcbBench : public td::Benchmark { } }; -class AesIgeBench : public td::Benchmark { +class AesIgeEncryptBench : public td::Benchmark { public: alignas(64) unsigned char data[DATA_SIZE]; td::UInt256 key; td::UInt256 iv; std::string get_description() const override { - return PSTRING() << "AES IGE OpenSSL [" << (DATA_SIZE >> 10) << "KB]"; + return PSTRING() << "AES IGE OpenSSL encrypt [" << (DATA_SIZE >> 10) << "KB]"; } void start_up() override { @@ -107,6 +107,34 @@ class AesIgeBench : public td::Benchmark { } }; +class AesIgeDecryptBench : public td::Benchmark { + public: + alignas(64) unsigned char data[DATA_SIZE]; + td::UInt256 key; + td::UInt256 iv; + + std::string get_description() const override { + return PSTRING() << "AES IGE OpenSSL decrypt [" << (DATA_SIZE >> 10) << "KB]"; + } + + void start_up() override { + for (int i = 0; i < DATA_SIZE; i++) { + data[i] = 123; + } + td::Random::secure_bytes(key.raw, sizeof(key)); + td::Random::secure_bytes(iv.raw, sizeof(iv)); + } + + void run(int n) override { + td::MutableSlice data_slice(data, DATA_SIZE); + td::AesIgeState state; + state.init(as_slice(key), as_slice(iv), false); + for (int i = 0; i < n; i++) { + state.decrypt(data_slice, data_slice); + } + } +}; + class AesCtrBench : public td::Benchmark { public: alignas(64) unsigned char data[DATA_SIZE]; @@ -163,17 +191,17 @@ class AesCbcBench : public td::Benchmark { class AesIgeShortBench : public td::Benchmark { public: - static constexpr int DATA_SIZE = 16; - alignas(64) unsigned char data[DATA_SIZE]; + static constexpr int SHORT_DATA_SIZE = 16; + alignas(64) unsigned char data[SHORT_DATA_SIZE]; td::UInt256 key; td::UInt256 iv; std::string get_description() const override { - return PSTRING() << "AES IGE OpenSSL [" << (DATA_SIZE) << "B]"; + return PSTRING() << "AES IGE OpenSSL [" << SHORT_DATA_SIZE << "B]"; } void start_up() override { - for (int i = 0; i < DATA_SIZE; i++) { + for (int i = 0; i < SHORT_DATA_SIZE; i++) { data[i] = 123; } td::Random::secure_bytes(as_slice(key)); @@ -181,12 +209,15 @@ class AesIgeShortBench : public td::Benchmark { } void run(int n) override { - td::MutableSlice data_slice(data, DATA_SIZE); + td::MutableSlice data_slice(data, SHORT_DATA_SIZE); td::AesIgeState ige; for (int i = 0; i < n; i++) { - ige.init(as_slice(key), as_slice(iv), true); - ige.encrypt(data_slice, data_slice); - //td::aes_ige_encrypt(as_slice(key), as_slice(iv), data_slice, data_slice); + if (true) { + ige.init(as_slice(key), as_slice(iv), true); + ige.encrypt(data_slice, data_slice); + } else { + td::aes_ige_encrypt(as_slice(key), as_slice(iv), data_slice, data_slice); + } } } }; @@ -316,9 +347,10 @@ int main() { td::init_openssl_threads(); td::bench(AesIgeShortBench()); + td::bench(AesIgeEncryptBench()); + td::bench(AesIgeDecryptBench()); td::bench(AesCtrBench()); td::bench(AesEcbBench()); - td::bench(AesIgeBench()); td::bench(Pbkdf2Bench()); td::bench(RandBench()); diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index 7dd345cf4..fdcbba806 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -8,7 +8,6 @@ #include "td/utils/as.h" #include "td/utils/BigNum.h" -#include "td/utils/bits.h" #include "td/utils/common.h" #include "td/utils/logging.h" #include "td/utils/misc.h" @@ -49,7 +48,7 @@ namespace td { -struct alignas(8) AesBlock { +struct AesBlock { uint64 hi; uint64 lo; @@ -63,10 +62,9 @@ struct alignas(8) AesBlock { res.lo = lo ^ b.lo; return res; } - AesBlock &operator^=(const AesBlock &b) { + void operator^=(const AesBlock &b) { hi ^= b.hi; lo ^= b.lo; - return *this; } void load(const uint8 *from) { @@ -87,6 +85,8 @@ struct alignas(8) AesBlock { return res; } }; +static_assert(sizeof(AesBlock) == 16, ""); +static_assert(sizeof(AesBlock) == AES_BLOCK_SIZE, ""); class XorBytes { public: @@ -107,7 +107,7 @@ class XorBytes { size_t n; template - struct alignas(16) Block { + struct alignas(N) Block { uint8 data[N]; Block operator^(const Block &b) const & { Block res; @@ -134,8 +134,8 @@ class XorBytes { }; struct AesCtrCounterPack { - static constexpr size_t N = 32; - AesBlock blocks[N]; + static constexpr size_t BLOCK_COUNT = 32; + AesBlock blocks[BLOCK_COUNT]; uint8 *raw() { return reinterpret_cast(this); } @@ -144,7 +144,7 @@ struct AesCtrCounterPack { } size_t size() const { - return N * 16; + return sizeof(blocks); } Slice as_slice() const { @@ -156,13 +156,13 @@ struct AesCtrCounterPack { void init(AesBlock block) { blocks[0] = block; - for (size_t i = 1; i < N; i++) { + for (size_t i = 1; i < BLOCK_COUNT; i++) { blocks[i] = blocks[i - 1].inc(); } } void rotate() { - blocks[0] = blocks[N - 1].inc(); - for (size_t i = 1; i < N; i++) { + blocks[0] = blocks[BLOCK_COUNT - 1].inc(); + for (size_t i = 1; i < BLOCK_COUNT; i++) { blocks[i] = blocks[i - 1].inc(); } } @@ -373,7 +373,6 @@ int pq_factorize(Slice pq_str, string *p_str, string *q_str) { class AesState::Impl { public: EVP_CIPHER_CTX *ctx{nullptr}; - bool encrypt{false}; Impl() = default; Impl(const Impl &from) = delete; @@ -406,12 +405,10 @@ void AesState::init(Slice key, bool encrypt) { LOG_IF(FATAL, res != 1); } EVP_CIPHER_CTX_set_padding(impl_->ctx, 0); - impl_->encrypt = encrypt; } void AesState::encrypt(const uint8 *src, uint8 *dst, int size) { CHECK(impl_ != nullptr); - CHECK(impl_->encrypt); CHECK(impl_->ctx != nullptr); CHECK(size % 16 == 0); int len; @@ -422,7 +419,6 @@ void AesState::encrypt(const uint8 *src, uint8 *dst, int size) { void AesState::decrypt(const uint8 *src, uint8 *dst, int size) { CHECK(impl_ != nullptr); - CHECK(!impl_->encrypt); CHECK(impl_->ctx != nullptr); CHECK(size % 16 == 0); int len; @@ -469,6 +465,7 @@ class AesIgeState::Impl { AesState state; AesBlock iv; AesBlock iv2; + void encrypt(Slice from, MutableSlice to) { CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(to.size() >= from.size()); @@ -476,7 +473,8 @@ class AesIgeState::Impl { auto in = from.ubegin(); auto out = to.ubegin(); - AesBlock tmp, tmp2; + AesBlock tmp; + AesBlock tmp2; while (len) { tmp.load(in); @@ -492,6 +490,7 @@ class AesIgeState::Impl { out += AES_BLOCK_SIZE; } } + void decrypt(Slice from, MutableSlice to) { CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(to.size() >= from.size()); @@ -499,7 +498,8 @@ class AesIgeState::Impl { auto in = from.ubegin(); auto out = to.ubegin(); - AesBlock tmp, tmp2; + AesBlock tmp; + AesBlock tmp2; while (len) { tmp.load(in); @@ -597,10 +597,8 @@ class AesCtrState::Impl { fill(); } size_t min_n = td::min(n, current.size()); - auto curr = current.ubegin(); - XorBytes::run(src, curr, dst, min_n); + XorBytes::run(src, current.ubegin(), dst, min_n); src += min_n; - curr += min_n; dst += min_n; n -= min_n; current.remove_prefix(min_n); diff --git a/tdutils/td/utils/crypto.h b/tdutils/td/utils/crypto.h index de0c5f766..b646652eb 100644 --- a/tdutils/td/utils/crypto.h +++ b/tdutils/td/utils/crypto.h @@ -29,7 +29,9 @@ class AesState { ~AesState(); void init(Slice key, bool encrypt); + void encrypt(const uint8 *src, uint8 *dst, int size); + void decrypt(const uint8 *src, uint8 *dst, int size); private: @@ -50,8 +52,11 @@ struct AesIgeState { AesIgeState(AesIgeState &&from); AesIgeState &operator=(AesIgeState &&from); ~AesIgeState(); + void init(Slice key, Slice iv, bool encrypt); + void encrypt(Slice from, MutableSlice to); + void decrypt(Slice from, MutableSlice to); private: diff --git a/tdutils/test/crypto.cpp b/tdutils/test/crypto.cpp index 18fbf442f..5092c73b7 100644 --- a/tdutils/test/crypto.cpp +++ b/tdutils/test/crypto.cpp @@ -75,7 +75,7 @@ TEST(Crypto, AesCtrState) { ASSERT_EQ(answers1[i], td::crc32(t)); state.init(as_slice(key), as_slice(iv)); state.decrypt(t, t); - ASSERT_STREQ(s, t); + ASSERT_STREQ(td::base64_encode(s), td::base64_encode(t)); for (auto &c : iv.raw) { c = 0xFF; @@ -87,7 +87,6 @@ TEST(Crypto, AesCtrState) { i++; } } -#endif TEST(Crypto, AesIgeState) { td::vector answers1{0u, 2045698207u, 2423540300u, 525522475u, 1545267325u}; @@ -121,10 +120,11 @@ TEST(Crypto, AesIgeState) { state.init(as_slice(key), as_slice(iv), false); state.decrypt(t, t); - ASSERT_STREQ(s, t); + ASSERT_STREQ(td::base64_encode(s), td::base64_encode(t)); i++; } } +#endif TEST(Crypto, Sha256State) { for (auto length : {0, 1, 31, 32, 33, 9999, 10000, 10001, 999999, 1000001}) {