diff --git a/benchmark/bench_crypto.cpp b/benchmark/bench_crypto.cpp
index 148b50ee6..38ef179fe 100644
--- a/benchmark/bench_crypto.cpp
+++ b/benchmark/bench_crypto.cpp
@@ -285,9 +285,9 @@ class Crc64Bench : public td::Benchmark {
 int main() {
   td::init_openssl_threads();
 
+  td::bench(AesCtrBench());
   td::bench(AesEcbBench());
   td::bench(AesIgeBench());
-  td::bench(AesCtrBench());
 
   td::bench(Pbkdf2Bench());
   td::bench(RandBench());
diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp
index 16bebefb3..d9edb37bd 100644
--- a/tdutils/td/utils/crypto.cpp
+++ b/tdutils/td/utils/crypto.cpp
@@ -8,6 +8,7 @@
 
 #include "td/utils/as.h"
 #include "td/utils/BigNum.h"
+#include "td/utils/bits.h"
 #include "td/utils/common.h"
 #include "td/utils/logging.h"
 #include "td/utils/misc.h"
@@ -76,19 +77,71 @@ struct alignas(8) AesBlock {
   }
 
   AesBlock inc() const {
-    AesBlock res = *this;
-    auto ptr = res.raw();
-    for (int j = 15; j >= 0; j--) {
-      if (++ptr[j] != 0) {
-        break;
-      }
+    // Add 1 to lo through a byte swap and carry into hi when lo wraps around to zero.
+    // NOTE(review): bswap64 presumably swaps unconditionally, i.e. this assumes a little-endian host - confirm.
+    AesBlock res;
+    res.lo = bswap64(bswap64(lo) + 1);
+    if (res.lo == 0) {
+      res.hi = bswap64(bswap64(hi) + 1);
+    } else {
+      res.hi = hi;
     }
+
     return res;
   }
 };
 
-struct AesCtrBlock {
-  static constexpr size_t N = 128;
+// Writes a[i] ^ b[i] to c[i] for every i in [0, n): 16-byte blocks first, then the tail byte by byte.
+class XorBytes {
+ public:
+  static void run(const uint8 *a, const uint8 *b, uint8 *c, size_t n) {
+    XorBytes xorer;
+    xorer.a = a;
+    xorer.b = b;
+    xorer.c = c;
+    xorer.n = n;
+    xorer.step<16>();
+    xorer.step<1>();
+  }
+
+ private:
+  const uint8 *a;
+  const uint8 *b;
+  uint8 *c;
+  size_t n;
+
+  // The alignment is N rather than a fixed 16: sizeof is rounded up to the alignment, and a padded Block<1>
+  // would make every 1-byte step move 16 bytes, reading and writing past the end of the buffers.
+  template <size_t N>
+  struct alignas(N) Block {
+    uint8 data[N];
+    Block operator^(const Block &b) const & {
+      Block res;
+      for (size_t i = 0; i < N; i++) {
+        res.data[i] = data[i] ^ b.data[i];
+      }
+      return res;
+    }
+  };
+
+  template <size_t N>
+  void step() {
+    static_assert(sizeof(Block<N>) == N, "Block must not be padded");
+    auto cnt = n / N;
+    n -= cnt * N;
+    for (size_t i = 0; i < cnt; i++) {
+      Block<N> a_big = as<Block<N>>(a);
+      Block<N> b_big = as<Block<N>>(b);
+      as<Block<N>>(c) = a_big ^ b_big;
+      a += N;
+      b += N;
+      c += N;
+    }
+  }
+};
+
+struct AesCtrCounterPack {
+  static constexpr size_t N = 32;
   AesBlock blocks[N];
   uint8 *raw() {
     return reinterpret_cast<uint8 *>(this);
@@ -538,12 +591,10 @@ class AesCtrState::Impl {
       }
       size_t min_n = td::min(n, current.size());
       auto curr = current.ubegin();
-      for (size_t i = 0; i < min_n; i++) {
-        dst[i] = src[i] ^ curr[i];
-      }
-      n -= min_n;
+      XorBytes::run(src, curr, dst, min_n);
       src += min_n;
       dst += min_n;
+      n -= min_n;
       current.remove_prefix(min_n);
     }
   }
@@ -551,15 +602,12 @@ class AesCtrState::Impl {
  private:
   AesState aes_state;
 
-  static constexpr size_t BLOCK_COUNT = 32;
-
-  AesCtrBlock counter;
-  AesCtrBlock encrypted_counter;
+  AesCtrCounterPack counter;
+  AesCtrCounterPack encrypted_counter;
   Slice current;
 
   void fill() {
-    aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(),
-                      static_cast<int>(counter.size()));
+    aes_state.encrypt(counter.raw(), encrypted_counter.raw(), static_cast<int>(counter.size()));
     current = encrypted_counter.as_slice();
   }
 };