tdutils: optimize aes ctr

GitOrigin-RevId: 09c6df45c0bf2683507a0f279769471efc859ecb
This commit is contained in:
Arseny Smirnov 2020-06-12 19:40:17 +03:00
parent 55ca575af5
commit e913c3126b
4 changed files with 91 additions and 35 deletions

View File

@ -71,13 +71,42 @@ class AesBench : public td::Benchmark {
state.init(td::as_slice(key), true);
td::MutableSlice data_slice(data, DATA_SIZE);
for (int i = 0; i <= n; i++) {
for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += 16) {
state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset);
size_t step = 16;
for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += step) {
state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset, (int)step);
}
}
}
};
class AesCtrBench : public td::Benchmark {
public:
alignas(64) unsigned char data[DATA_SIZE];
td::UInt256 key;
td::UInt128 iv;
std::string get_description() const override {
return PSTRING() << "AES CTR OpenSSL [" << (DATA_SIZE >> 10) << "KB]";
}
void start_up() override {
for (int i = 0; i < DATA_SIZE; i++) {
data[i] = 123;
}
td::Random::secure_bytes(key.raw, sizeof(key));
td::Random::secure_bytes(iv.raw, sizeof(iv));
}
void run(int n) override {
td::AesCtrState state;
state.init(as_slice(key), as_slice(iv));
td::MutableSlice data_slice(data, DATA_SIZE);
for (int i = 0; i < n; i++) {
state.encrypt(data_slice, data_slice);
}
}
};
class AESBench : public td::Benchmark {
public:
alignas(64) unsigned char data[DATA_SIZE];
@ -227,6 +256,7 @@ class Crc64Bench : public td::Benchmark {
int main() {
td::init_openssl_threads();
td::bench(AesCtrBench());
td::bench(AesBench());
td::bench(AESBench());

View File

@ -273,30 +273,28 @@ void AesState::init(Slice key, bool encrypt) {
if (encrypt) {
CHECK(1 == EVP_EncryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr));
AES_set_encrypt_key(key.ubegin(), 256, &impl_->key);
} else {
CHECK(1 == EVP_DecryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr));
AES_set_decrypt_key(key.ubegin(), 256, &impl_->key);
}
EVP_CIPHER_CTX_set_padding(impl_->ctx, 0);
impl_->encrypt = encrypt;
}
void AesState::encrypt(const uint8 *src, uint8 *dst) {
void AesState::encrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(impl_->encrypt);
CHECK(impl_->ctx);
CHECK(size % 16 == 0);
int len;
CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, 16));
CHECK(len == 16);
//AES_encrypt(src, dst, &impl_->key);
CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, size));
CHECK(len == size);
}
void AesState::decrypt(const uint8 *src, uint8 *dst) {
void AesState::decrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(!impl_->encrypt);
CHECK(impl_->ctx);
CHECK(size % 16 == 0);
int len;
CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, 16));
LOG_CHECK(len == 16) << len;
//AES_decrypt(src, dst, &impl_->key);
CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, size));
CHECK(len == size);
}
static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) {
@ -363,35 +361,63 @@ class AesCtrState::Impl {
CHECK(key.size() == 32);
CHECK(iv.size() == 16);
static_assert(AES_BLOCK_SIZE == 16, "");
if (AES_set_encrypt_key(key.ubegin(), 256, &aes_key) < 0) {
LOG(FATAL) << "Failed to set encrypt key";
}
aes_state.init(key, true);
counter.as_mutable_slice().copy_from(iv);
current_pos = 0;
fill();
}
void encrypt(Slice from, MutableSlice to) {
CHECK(to.size() >= from.size());
for (size_t i = 0; i < from.size(); i++) {
if (current_pos == 0) {
AES_encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), &aes_key);
uint8 *ptr = counter.as_mutable_slice().ubegin();
for (int j = 15; j >= 0; j--) {
if (++ptr[j] != 0) {
break;
}
auto *src = from.ubegin();
auto *dst = to.ubegin();
auto n = from.size();
while (n != 0) {
if (current.empty()) {
if (N != 1) {
counter.as_mutable_slice().copy_from(counter.as_slice().substr((N - 1) * AES_BLOCK_SIZE));
}
inc(counter.as_mutable_slice().ubegin());
fill();
}
to[i] = static_cast<char>(from[i] ^ encrypted_counter[current_pos]);
current_pos = (current_pos + 1) & 15;
size_t min_n = td::min(n, current.size());
auto curr = current.ubegin();
for (size_t i = 0; i < min_n; i++) {
dst[i] = src[i] ^ curr[i];
}
n -= min_n;
src += min_n;
dst += min_n;
current.remove_prefix(min_n);
}
}
private:
AES_KEY aes_key;
SecureString counter{AES_BLOCK_SIZE};
SecureString encrypted_counter{AES_BLOCK_SIZE};
uint8 current_pos;
AesState aes_state;
static constexpr size_t N = 32;
SecureString counter{AES_BLOCK_SIZE * N};
SecureString encrypted_counter{AES_BLOCK_SIZE * N};
td::Slice current;
void inc(uint8 *ptr) {
for (int j = 15; j >= 0; j--) {
if (++ptr[j] != 0) {
break;
}
}
}
void fill() {
auto *src = counter.as_slice().ubegin();
auto *dst = counter.as_mutable_slice().ubegin() + AES_BLOCK_SIZE;
for (size_t i = 0; i + 1 < N; i++) {
memcpy(dst, src, AES_BLOCK_SIZE);
inc(dst);
src += AES_BLOCK_SIZE;
dst += AES_BLOCK_SIZE;
}
aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), (int)counter.size());
current = encrypted_counter.as_slice();
}
};
AesCtrState::AesCtrState() = default;

View File

@ -28,8 +28,8 @@ struct AesState {
AesState &operator=(AesState &&from);
~AesState();
void init(Slice key, bool encrypt);
void encrypt(const uint8 *src, uint8 *dst);
void decrypt(const uint8 *src, uint8 *dst);
void encrypt(const uint8 *src, uint8 *dst, int size);
void decrypt(const uint8 *src, uint8 *dst, int size);
private:
class Impl;

View File

@ -32,8 +32,8 @@ TEST(Crypto, Aes) {
td::AesState decryptor;
decryptor.init(as_slice(key), false);
encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin());
decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin());
encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin(), 16);
decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin(), 16);
CHECK(decrypted == plaintext);
CHECK(decrypted != encrypted);