From 8845e18da9de009b749af717ccacafe60d0235a3 Mon Sep 17 00:00:00 2001 From: Arseny Smirnov Date: Mon, 15 Jun 2020 16:58:58 +0300 Subject: [PATCH] tdutils: simplify aes ctr GitOrigin-RevId: 557cc787f77e2f0af494e7dd46fa99e495a16925 --- tdutils/td/utils/crypto.cpp | 77 +++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 24 deletions(-) diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index 43e85007b..16bebefb3 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -74,6 +74,52 @@ struct alignas(8) AesBlock { void store(uint8 *to) { as(to) = *this; } + + AesBlock inc() const { + AesBlock res = *this; + auto ptr = res.raw(); + for (int j = 15; j >= 0; j--) { + if (++ptr[j] != 0) { + break; + } + } + return res; + } +}; + +struct AesCtrBlock { + static constexpr size_t N = 128; + AesBlock blocks[N]; + uint8 *raw() { + return reinterpret_cast(this); + } + const uint8 *raw() const { + return reinterpret_cast(this); + } + + size_t size() const { + return N * 16; + } + + Slice as_slice() const { + return Slice(raw(), size()); + } + MutableSlice as_mutable_slice() { + return MutableSlice(raw(), size()); + } + + void init(AesBlock block) { + blocks[0] = block; + for (size_t i = 1; i < N; i++) { + blocks[i] = blocks[i - 1].inc(); + } + } + void rotate() { + blocks[0] = blocks[N - 1].inc(); + for (size_t i = 1; i < N; i++) { + blocks[i] = blocks[i - 1].inc(); + } + } }; static uint64 gcd(uint64 a, uint64 b) { @@ -475,7 +521,9 @@ class AesCtrState::Impl { CHECK(iv.size() == 16); static_assert(AES_BLOCK_SIZE == 16, ""); aes_state.init(key, true); - counter.as_mutable_slice().copy_from(iv); + AesBlock block; + block.load(iv.ubegin()); + counter.init(block); fill(); } @@ -485,10 +533,7 @@ class AesCtrState::Impl { auto n = from.size(); while (n != 0) { if (current.empty()) { - if (BLOCK_COUNT != 1) { - counter.as_mutable_slice().copy_from(counter.as_slice().substr((BLOCK_COUNT - 1) * AES_BLOCK_SIZE)); - } - inc(counter.as_mutable_slice().ubegin()); + counter.rotate(); fill(); } size_t min_n = td::min(n, current.size()); @@ -507,28 +552,12 @@ class AesCtrState::Impl { AesState aes_state; static constexpr size_t BLOCK_COUNT = 32; - SecureString counter{AES_BLOCK_SIZE * BLOCK_COUNT}; - SecureString encrypted_counter{AES_BLOCK_SIZE * BLOCK_COUNT}; + + AesCtrBlock counter; + AesCtrBlock encrypted_counter; Slice current; - void inc(uint8 *ptr) { - for (int j = 15; j >= 0; j--) { - if (++ptr[j] != 0) { - break; - } - } - } - void fill() { - auto *src = counter.as_slice().ubegin(); - auto *dst = counter.as_mutable_slice().ubegin() + AES_BLOCK_SIZE; - for (size_t i = 0; i + 1 < BLOCK_COUNT; i++) { - std::memcpy(dst, src, AES_BLOCK_SIZE); - inc(dst); - src += AES_BLOCK_SIZE; - dst += AES_BLOCK_SIZE; - } - aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), static_cast(counter.size())); current = encrypted_counter.as_slice();