tdutils: optimize aes ctr

GitOrigin-RevId: b24920ac38bb3b8e94ece87e7438a8b8b1b370c4
This commit is contained in:
Arseny Smirnov 2020-06-15 18:59:56 +03:00
parent 8845e18da9
commit 7e06d91739
2 changed files with 62 additions and 19 deletions

View File

@ -285,9 +285,9 @@ class Crc64Bench : public td::Benchmark {
int main() {
td::init_openssl_threads();
td::bench(AesCtrBench());
td::bench(AesEcbBench());
td::bench(AesIgeBench());
td::bench(AesCtrBench());
td::bench(Pbkdf2Bench());
td::bench(RandBench());

View File

@ -8,6 +8,7 @@
#include "td/utils/as.h"
#include "td/utils/BigNum.h"
#include "td/utils/bits.h"
#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
@ -76,19 +77,65 @@ struct alignas(8) AesBlock {
}
AesBlock inc() const {
AesBlock res = *this;
auto ptr = res.raw();
for (int j = 15; j >= 0; j--) {
if (++ptr[j] != 0) {
break;
}
AesBlock res;
res.lo = bswap64(bswap64(lo) + 1);
if (res.lo == 0) {
res.hi = bswap64(bswap64(hi) + 1);
} else {
res.hi = hi;
}
return res;
}
};
struct AesCtrBlock {
static constexpr size_t N = 128;
class XorBytes {
public:
static void run(const uint8 *a, const uint8 *b, uint8 *c, size_t n) {
XorBytes xorer;
xorer.a = a;
xorer.b = b;
xorer.c = c;
xorer.n = n;
xorer.step<16>();
xorer.step<1>();
}
private:
const uint8 *a;
const uint8 *b;
uint8 *c;
size_t n;
template <size_t N>
struct alignas(16) Block {
uint8 data[N];
Block operator^(const Block &b) const & {
Block res;
for (size_t i = 0; i < N; i++) {
res.data[i] = data[i] ^ b.data[i];
}
return res;
}
};
template <size_t N>
void step() {
auto cnt = n / N;
n -= cnt * N;
for (size_t i = 0; i < cnt; i++) {
Block<N> a_big = as<Block<N>>(a);
Block<N> b_big = as<Block<N>>(b);
as<Block<N>>(c) = a_big ^ b_big;
a += N;
b += N;
c += N;
}
}
};
struct AesCtrCounterPack {
static constexpr size_t N = 32;
AesBlock blocks[N];
uint8 *raw() {
return reinterpret_cast<uint8 *>(this);
@ -538,12 +585,11 @@ class AesCtrState::Impl {
}
size_t min_n = td::min(n, current.size());
auto curr = current.ubegin();
for (size_t i = 0; i < min_n; i++) {
dst[i] = src[i] ^ curr[i];
}
n -= min_n;
XorBytes::run(src, curr, dst, min_n);
src += min_n;
curr += min_n;
dst += min_n;
n -= min_n;
current.remove_prefix(min_n);
}
}
@ -551,15 +597,12 @@ class AesCtrState::Impl {
private:
AesState aes_state;
static constexpr size_t BLOCK_COUNT = 32;
AesCtrBlock counter;
AesCtrBlock encrypted_counter;
AesCtrCounterPack counter;
AesCtrCounterPack encrypted_counter;
Slice current;
void fill() {
aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(),
static_cast<int>(counter.size()));
aes_state.encrypt(counter.raw(), encrypted_counter.raw(), static_cast<int>(counter.size()));
current = encrypted_counter.as_slice();
}
};