tdutils: simplify aes ctr

GitOrigin-RevId: 557cc787f77e2f0af494e7dd46fa99e495a16925
This commit is contained in:
Arseny Smirnov 2020-06-15 16:58:58 +03:00
parent 0c0f6a7b7b
commit 8845e18da9

View File

@ -74,6 +74,52 @@ struct alignas(8) AesBlock {
void store(uint8 *to) {
as<AesBlock>(to) = *this;
}
AesBlock inc() const {
AesBlock res = *this;
auto ptr = res.raw();
for (int j = 15; j >= 0; j--) {
if (++ptr[j] != 0) {
break;
}
}
return res;
}
};
struct AesCtrBlock {
static constexpr size_t N = 128;
AesBlock blocks[N];
uint8 *raw() {
return reinterpret_cast<uint8 *>(this);
}
const uint8 *raw() const {
return reinterpret_cast<const uint8 *>(this);
}
size_t size() const {
return N * 16;
}
Slice as_slice() const {
return Slice(raw(), size());
}
MutableSlice as_mutable_slice() {
return MutableSlice(raw(), size());
}
void init(AesBlock block) {
blocks[0] = block;
for (size_t i = 1; i < N; i++) {
blocks[i] = blocks[i - 1].inc();
}
}
void rotate() {
blocks[0] = blocks[N - 1].inc();
for (size_t i = 1; i < N; i++) {
blocks[i] = blocks[i - 1].inc();
}
}
};
static uint64 gcd(uint64 a, uint64 b) {
@ -475,7 +521,9 @@ class AesCtrState::Impl {
CHECK(iv.size() == 16);
static_assert(AES_BLOCK_SIZE == 16, "");
aes_state.init(key, true);
counter.as_mutable_slice().copy_from(iv);
AesBlock block;
block.load(iv.ubegin());
counter.init(block);
fill();
}
@ -485,10 +533,7 @@ class AesCtrState::Impl {
auto n = from.size();
while (n != 0) {
if (current.empty()) {
if (BLOCK_COUNT != 1) {
counter.as_mutable_slice().copy_from(counter.as_slice().substr((BLOCK_COUNT - 1) * AES_BLOCK_SIZE));
}
inc(counter.as_mutable_slice().ubegin());
counter.rotate();
fill();
}
size_t min_n = td::min(n, current.size());
@ -507,28 +552,12 @@ class AesCtrState::Impl {
AesState aes_state;
static constexpr size_t BLOCK_COUNT = 32;
SecureString counter{AES_BLOCK_SIZE * BLOCK_COUNT};
SecureString encrypted_counter{AES_BLOCK_SIZE * BLOCK_COUNT};
AesCtrBlock counter;
AesCtrBlock encrypted_counter;
Slice current;
void inc(uint8 *ptr) {
for (int j = 15; j >= 0; j--) {
if (++ptr[j] != 0) {
break;
}
}
}
void fill() {
auto *src = counter.as_slice().ubegin();
auto *dst = counter.as_mutable_slice().ubegin() + AES_BLOCK_SIZE;
for (size_t i = 0; i + 1 < BLOCK_COUNT; i++) {
std::memcpy(dst, src, AES_BLOCK_SIZE);
inc(dst);
src += AES_BLOCK_SIZE;
dst += AES_BLOCK_SIZE;
}
aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(),
static_cast<int>(counter.size()));
current = encrypted_counter.as_slice();