tdutils: optimize aes ctr

GitOrigin-RevId: 09c6df45c0bf2683507a0f279769471efc859ecb
This commit is contained in:
Arseny Smirnov 2020-06-12 19:40:17 +03:00
parent 55ca575af5
commit e913c3126b
4 changed files with 91 additions and 35 deletions

View File

@ -71,13 +71,42 @@ class AesBench : public td::Benchmark {
state.init(td::as_slice(key), true); state.init(td::as_slice(key), true);
td::MutableSlice data_slice(data, DATA_SIZE); td::MutableSlice data_slice(data, DATA_SIZE);
for (int i = 0; i <= n; i++) { for (int i = 0; i <= n; i++) {
for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += 16) { size_t step = 16;
state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset); for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += step) {
state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset, (int)step);
} }
} }
} }
}; };
class AesCtrBench : public td::Benchmark {
public:
alignas(64) unsigned char data[DATA_SIZE];
td::UInt256 key;
td::UInt128 iv;
std::string get_description() const override {
return PSTRING() << "AES CTR OpenSSL [" << (DATA_SIZE >> 10) << "KB]";
}
void start_up() override {
for (int i = 0; i < DATA_SIZE; i++) {
data[i] = 123;
}
td::Random::secure_bytes(key.raw, sizeof(key));
td::Random::secure_bytes(iv.raw, sizeof(iv));
}
void run(int n) override {
td::AesCtrState state;
state.init(as_slice(key), as_slice(iv));
td::MutableSlice data_slice(data, DATA_SIZE);
for (int i = 0; i < n; i++) {
state.encrypt(data_slice, data_slice);
}
}
};
class AESBench : public td::Benchmark { class AESBench : public td::Benchmark {
public: public:
alignas(64) unsigned char data[DATA_SIZE]; alignas(64) unsigned char data[DATA_SIZE];
@ -227,6 +256,7 @@ class Crc64Bench : public td::Benchmark {
int main() { int main() {
td::init_openssl_threads(); td::init_openssl_threads();
td::bench(AesCtrBench());
td::bench(AesBench()); td::bench(AesBench());
td::bench(AESBench()); td::bench(AESBench());

View File

@ -273,30 +273,28 @@ void AesState::init(Slice key, bool encrypt) {
if (encrypt) { if (encrypt) {
CHECK(1 == EVP_EncryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr)); CHECK(1 == EVP_EncryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr));
AES_set_encrypt_key(key.ubegin(), 256, &impl_->key);
} else { } else {
CHECK(1 == EVP_DecryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr)); CHECK(1 == EVP_DecryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr));
AES_set_decrypt_key(key.ubegin(), 256, &impl_->key);
} }
EVP_CIPHER_CTX_set_padding(impl_->ctx, 0); EVP_CIPHER_CTX_set_padding(impl_->ctx, 0);
impl_->encrypt = encrypt; impl_->encrypt = encrypt;
} }
void AesState::encrypt(const uint8 *src, uint8 *dst) { void AesState::encrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(impl_->encrypt); CHECK(impl_->encrypt);
CHECK(impl_->ctx); CHECK(impl_->ctx);
CHECK(size % 16 == 0);
int len; int len;
CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, 16)); CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, size));
CHECK(len == 16); CHECK(len == size);
//AES_encrypt(src, dst, &impl_->key);
} }
void AesState::decrypt(const uint8 *src, uint8 *dst) { void AesState::decrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(!impl_->encrypt); CHECK(!impl_->encrypt);
CHECK(impl_->ctx); CHECK(impl_->ctx);
CHECK(size % 16 == 0);
int len; int len;
CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, 16)); CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, size));
LOG_CHECK(len == 16) << len; CHECK(len == size);
//AES_decrypt(src, dst, &impl_->key);
} }
static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) { static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) {
@ -363,35 +361,63 @@ class AesCtrState::Impl {
CHECK(key.size() == 32); CHECK(key.size() == 32);
CHECK(iv.size() == 16); CHECK(iv.size() == 16);
static_assert(AES_BLOCK_SIZE == 16, ""); static_assert(AES_BLOCK_SIZE == 16, "");
if (AES_set_encrypt_key(key.ubegin(), 256, &aes_key) < 0) { aes_state.init(key, true);
LOG(FATAL) << "Failed to set encrypt key";
}
counter.as_mutable_slice().copy_from(iv); counter.as_mutable_slice().copy_from(iv);
current_pos = 0; fill();
} }
void encrypt(Slice from, MutableSlice to) { void encrypt(Slice from, MutableSlice to) {
CHECK(to.size() >= from.size()); auto *src = from.ubegin();
for (size_t i = 0; i < from.size(); i++) { auto *dst = to.ubegin();
if (current_pos == 0) { auto n = from.size();
AES_encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), &aes_key); while (n != 0) {
uint8 *ptr = counter.as_mutable_slice().ubegin(); if (current.empty()) {
for (int j = 15; j >= 0; j--) { if (N != 1) {
if (++ptr[j] != 0) { counter.as_mutable_slice().copy_from(counter.as_slice().substr((N - 1) * AES_BLOCK_SIZE));
break;
}
} }
inc(counter.as_mutable_slice().ubegin());
fill();
} }
to[i] = static_cast<char>(from[i] ^ encrypted_counter[current_pos]); size_t min_n = td::min(n, current.size());
current_pos = (current_pos + 1) & 15; auto curr = current.ubegin();
for (size_t i = 0; i < min_n; i++) {
dst[i] = src[i] ^ curr[i];
}
n -= min_n;
src += min_n;
dst += min_n;
current.remove_prefix(min_n);
} }
} }
private: private:
AES_KEY aes_key; AesState aes_state;
SecureString counter{AES_BLOCK_SIZE};
SecureString encrypted_counter{AES_BLOCK_SIZE}; static constexpr size_t N = 32;
uint8 current_pos; SecureString counter{AES_BLOCK_SIZE * N};
SecureString encrypted_counter{AES_BLOCK_SIZE * N};
td::Slice current;
void inc(uint8 *ptr) {
for (int j = 15; j >= 0; j--) {
if (++ptr[j] != 0) {
break;
}
}
}
void fill() {
auto *src = counter.as_slice().ubegin();
auto *dst = counter.as_mutable_slice().ubegin() + AES_BLOCK_SIZE;
for (size_t i = 0; i + 1 < N; i++) {
memcpy(dst, src, AES_BLOCK_SIZE);
inc(dst);
src += AES_BLOCK_SIZE;
dst += AES_BLOCK_SIZE;
}
aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), (int)counter.size());
current = encrypted_counter.as_slice();
}
}; };
AesCtrState::AesCtrState() = default; AesCtrState::AesCtrState() = default;

View File

@ -28,8 +28,8 @@ struct AesState {
AesState &operator=(AesState &&from); AesState &operator=(AesState &&from);
~AesState(); ~AesState();
void init(Slice key, bool encrypt); void init(Slice key, bool encrypt);
void encrypt(const uint8 *src, uint8 *dst); void encrypt(const uint8 *src, uint8 *dst, int size);
void decrypt(const uint8 *src, uint8 *dst); void decrypt(const uint8 *src, uint8 *dst, int size);
private: private:
class Impl; class Impl;

View File

@ -32,8 +32,8 @@ TEST(Crypto, Aes) {
td::AesState decryptor; td::AesState decryptor;
decryptor.init(as_slice(key), false); decryptor.init(as_slice(key), false);
encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin()); encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin(), 16);
decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin()); decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin(), 16);
CHECK(decrypted == plaintext); CHECK(decrypted == plaintext);
CHECK(decrypted != encrypted); CHECK(decrypted != encrypted);