tdutils: optimize aes ctr

GitOrigin-RevId: b24920ac38bb3b8e94ece87e7438a8b8b1b370c4
This commit is contained in:
Arseny Smirnov 2020-06-15 18:59:56 +03:00
parent 8845e18da9
commit 7e06d91739
2 changed files with 62 additions and 19 deletions

View File

@ -285,9 +285,9 @@ class Crc64Bench : public td::Benchmark {
int main() { int main() {
td::init_openssl_threads(); td::init_openssl_threads();
td::bench(AesCtrBench());
td::bench(AesEcbBench()); td::bench(AesEcbBench());
td::bench(AesIgeBench()); td::bench(AesIgeBench());
td::bench(AesCtrBench());
td::bench(Pbkdf2Bench()); td::bench(Pbkdf2Bench());
td::bench(RandBench()); td::bench(RandBench());

View File

@ -8,6 +8,7 @@
#include "td/utils/as.h" #include "td/utils/as.h"
#include "td/utils/BigNum.h" #include "td/utils/BigNum.h"
#include "td/utils/bits.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/misc.h" #include "td/utils/misc.h"
@ -76,19 +77,65 @@ struct alignas(8) AesBlock {
} }
AesBlock inc() const { AesBlock inc() const {
AesBlock res = *this; AesBlock res;
auto ptr = res.raw(); res.lo = bswap64(bswap64(lo) + 1);
for (int j = 15; j >= 0; j--) { if (res.lo == 0) {
if (++ptr[j] != 0) { res.hi = bswap64(bswap64(hi) + 1);
break; } else {
} res.hi = hi;
} }
return res; return res;
} }
}; };
struct AesCtrBlock { class XorBytes {
static constexpr size_t N = 128; public:
static void run(const uint8 *a, const uint8 *b, uint8 *c, size_t n) {
XorBytes xorer;
xorer.a = a;
xorer.b = b;
xorer.c = c;
xorer.n = n;
xorer.step<16>();
xorer.step<1>();
}
private:
// Cursor state shared by successive step<N>() passes; each pass advances the
// three pointers and decreases n by the number of bytes it consumed.
// First input; only read from.
const uint8 *a;
// Second input; only read from.
const uint8 *b;
// Output; receives the XOR of the two inputs.
uint8 *c;
// Number of bytes not yet processed.
size_t n;
// Fixed-size chunk of N bytes that is XORed as a single unit.
//
// The alignment is tied to N (which therefore has to be a power of two) rather
// than being a fixed 16: a fixed alignas(16) pads Block<1> up to 16 bytes, so
// any whole-object load/store of a Block<1> would touch 15 bytes beyond the
// one it is supposed to cover. With alignas(N) the object is exactly N bytes.
template <size_t N>
struct alignas(N) Block {
  uint8 data[N];

  // Returns a block whose i-th byte is data[i] ^ b.data[i].
  Block operator^(const Block &b) const & {
    // Checked here because Block is only a complete type inside member bodies.
    static_assert(sizeof(Block) == N, "Block<N> must occupy exactly N bytes");
    Block res;
    for (size_t i = 0; i < N; i++) {
      res.data[i] = static_cast<uint8>(data[i] ^ b.data[i]);
    }
    return res;
  }
};
template <size_t N>
void step() {
auto cnt = n / N;
n -= cnt * N;
for (size_t i = 0; i < cnt; i++) {
Block<N> a_big = as<Block<N>>(a);
Block<N> b_big = as<Block<N>>(b);
as<Block<N>>(c) = a_big ^ b_big;
a += N;
b += N;
c += N;
}
}
};
struct AesCtrCounterPack {
static constexpr size_t N = 32;
AesBlock blocks[N]; AesBlock blocks[N];
uint8 *raw() { uint8 *raw() {
return reinterpret_cast<uint8 *>(this); return reinterpret_cast<uint8 *>(this);
@ -538,12 +585,11 @@ class AesCtrState::Impl {
} }
size_t min_n = td::min(n, current.size()); size_t min_n = td::min(n, current.size());
auto curr = current.ubegin(); auto curr = current.ubegin();
for (size_t i = 0; i < min_n; i++) { XorBytes::run(src, curr, dst, min_n);
dst[i] = src[i] ^ curr[i];
}
n -= min_n;
src += min_n; src += min_n;
curr += min_n;
dst += min_n; dst += min_n;
n -= min_n;
current.remove_prefix(min_n); current.remove_prefix(min_n);
} }
} }
@ -551,15 +597,12 @@ class AesCtrState::Impl {
private: private:
AesState aes_state; AesState aes_state;
static constexpr size_t BLOCK_COUNT = 32; AesCtrCounterPack counter;
AesCtrCounterPack encrypted_counter;
AesCtrBlock counter;
AesCtrBlock encrypted_counter;
Slice current; Slice current;
void fill() { void fill() {
aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), aes_state.encrypt(counter.raw(), encrypted_counter.raw(), static_cast<int>(counter.size()));
static_cast<int>(counter.size()));
current = encrypted_counter.as_slice(); current = encrypted_counter.as_slice();
} }
}; };