Refactor td::AesState::Impl

GitOrigin-RevId: 5731ddc8d85c912cbfb141bd4e5eefea22d8ab21
This commit is contained in:
Arseny Smirnov 2020-06-16 17:57:19 +03:00
parent 86ca096840
commit 005611e924
2 changed files with 90 additions and 40 deletions

View File

@ -211,9 +211,9 @@ class AesIgeShortBench : public td::Benchmark {
void run(int n) override { void run(int n) override {
td::MutableSlice data_slice(data, SHORT_DATA_SIZE); td::MutableSlice data_slice(data, SHORT_DATA_SIZE);
td::AesIgeState ige;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
if (use_state) { if (use_state) {
td::AesIgeState ige;
ige.init(as_slice(key), as_slice(iv), false); ige.init(as_slice(key), as_slice(iv), false);
ige.decrypt(data_slice, data_slice); ige.decrypt(data_slice, data_slice);
} else { } else {

View File

@ -6,6 +6,7 @@
// //
#include "td/utils/crypto.h" #include "td/utils/crypto.h"
#include "td/utils/Slice-decl.h"
#include "td/utils/as.h" #include "td/utils/as.h"
#include "td/utils/BigNum.h" #include "td/utils/BigNum.h"
#include "td/utils/common.h" #include "td/utils/common.h"
@ -55,6 +56,9 @@ struct AesBlock {
uint8 *raw() { uint8 *raw() {
return reinterpret_cast<uint8 *>(this); return reinterpret_cast<uint8 *>(this);
} }
Slice as_mutable_slice() {
return td::MutableSlice(raw(), 16);
}
AesBlock operator^(const AesBlock &b) const { AesBlock operator^(const AesBlock &b) const {
AesBlock res; AesBlock res;
@ -362,18 +366,80 @@ int pq_factorize(Slice pq_str, string *p_str, string *q_str) {
class AesState::Impl { class AesState::Impl {
public: public:
EVP_CIPHER_CTX *ctx{nullptr}; Impl() {
ctx_ = EVP_CIPHER_CTX_new();
Impl() = default; LOG_IF(FATAL, !ctx_);
}
Impl(const Impl &from) = delete; Impl(const Impl &from) = delete;
Impl &operator=(const Impl &from) = delete; Impl &operator=(const Impl &from) = delete;
Impl(Impl &&from) = delete; Impl(Impl &&from) = delete;
Impl &operator=(Impl &&from) = delete; Impl &operator=(Impl &&from) = delete;
~Impl() { ~Impl() {
if (ctx != nullptr) { if (ctx_ != nullptr) {
EVP_CIPHER_CTX_free(ctx); EVP_CIPHER_CTX_free(ctx_);
} }
} }
void init_encrypt_ecb(Slice key) {
type_ = EncryptEcb;
int res = EVP_EncryptInit_ex(ctx_, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr);
LOG_IF(FATAL, res != 1);
EVP_CIPHER_CTX_set_padding(ctx_, 0);
}
void init_decrypt_ecb(Slice key) {
type_ = DecryptEcb;
int res = EVP_DecryptInit_ex(ctx_, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr);
LOG_IF(FATAL, res != 1);
EVP_CIPHER_CTX_set_padding(ctx_, 0);
}
void init_encrypt_cbc(Slice key) {
type_ = EncryptCbc;
int res = EVP_EncryptInit_ex(ctx_, EVP_aes_256_cbc(), nullptr, key.ubegin(), nullptr);
LOG_IF(FATAL, res != 1);
EVP_CIPHER_CTX_set_padding(ctx_, 0);
}
void init_decrypt_cbc(Slice key) {
type_ = DecryptCbc;
int res = EVP_DecryptInit_ex(ctx_, EVP_aes_256_cbc(), nullptr, key.ubegin(), nullptr);
LOG_IF(FATAL, res != 1);
EVP_CIPHER_CTX_set_padding(ctx_, 0);
}
void init_encrypt_cbc_iv(Slice iv) {
CHECK(type_ == EncryptCbc);
int res = EVP_EncryptInit(ctx_, nullptr, nullptr, iv.ubegin());
LOG_IF(FATAL, res != 1);
}
void init_decrypt_cbc_iv(Slice iv) {
CHECK(type_ == DecryptCbc);
int res = EVP_DecryptInit(ctx_, nullptr, nullptr, iv.ubegin());
LOG_IF(FATAL, res != 1);
}
void encrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(type_ == EncryptCbc || type_ == EncryptEcb);
CHECK(size % 16 == 0);
int len;
int res = EVP_EncryptUpdate(ctx_, dst, &len, src, size);
LOG_IF(FATAL, res != 1);
CHECK(len == size);
}
void decrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(type_ == DecryptCbc || type_ == DecryptEcb);
CHECK(size % 16 == 0);
int len;
int res = EVP_DecryptUpdate(ctx_, dst, &len, src, size);
LOG_IF(FATAL, res != 1);
CHECK(len == size);
}
private:
EVP_CIPHER_CTX *ctx_{nullptr};
enum Type { Empty, EncryptEcb, DecryptEcb, EncryptCbc, DecryptCbc } type_{Empty};
}; };
AesState::AesState() = default; AesState::AesState() = default;
@ -383,38 +449,22 @@ void AesState::init(Slice key, bool encrypt) {
CHECK(key.size() == 32); CHECK(key.size() == 32);
if (!impl_) { if (!impl_) {
impl_ = make_unique<Impl>(); impl_ = make_unique<Impl>();
impl_->ctx = EVP_CIPHER_CTX_new();
} }
CHECK(impl_->ctx);
if (encrypt) { if (encrypt) {
int res = EVP_EncryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr); impl_->init_encrypt_ecb(key);
LOG_IF(FATAL, res != 1);
} else { } else {
int res = EVP_DecryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr); impl_->init_decrypt_ecb(key);
LOG_IF(FATAL, res != 1);
} }
EVP_CIPHER_CTX_set_padding(impl_->ctx, 0);
} }
void AesState::encrypt(const uint8 *src, uint8 *dst, int size) { void AesState::encrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(impl_ != nullptr); CHECK(impl_ != nullptr);
CHECK(impl_->ctx != nullptr); impl_->encrypt(src, dst, size);
CHECK(size % 16 == 0);
int len;
int res = EVP_EncryptUpdate(impl_->ctx, dst, &len, src, size);
LOG_IF(FATAL, res != 1);
CHECK(len == size);
} }
void AesState::decrypt(const uint8 *src, uint8 *dst, int size) { void AesState::decrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(impl_ != nullptr); CHECK(impl_ != nullptr);
CHECK(impl_->ctx != nullptr); impl_->decrypt(src, dst, size);
CHECK(size % 16 == 0);
int len;
int res = EVP_DecryptUpdate(impl_->ctx, dst, &len, src, size);
LOG_IF(FATAL, res != 1);
CHECK(len == size);
} }
static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) { static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) {
@ -433,6 +483,9 @@ static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, Mutab
} }
void aes_ige_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to) { void aes_ige_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to) {
if (from.size() <= 128) {
return aes_ige_xcrypt(aes_key, aes_iv, from, to, true);
}
AesIgeState state; AesIgeState state;
state.init(aes_key, aes_iv, true); state.init(aes_key, aes_iv, true);
state.encrypt(from, to); state.encrypt(from, to);
@ -475,11 +528,10 @@ class AesIgeState::Impl {
} }
} }
EVP_EncryptInit(state.impl_->ctx, nullptr, nullptr, encrypted_iv.raw()); state.impl_->init_encrypt_cbc_iv(encrypted_iv.as_mutable_slice());
int outlen = 0; int outlen = 0;
int inlen = static_cast<int>(AES_BLOCK_SIZE * count); int inlen = static_cast<int>(AES_BLOCK_SIZE * count);
EVP_EncryptUpdate(state.impl_->ctx, data_xored[0].raw(), &outlen, data_xored[0].raw(), inlen); state.impl_->encrypt(data_xored[0].raw(), data_xored[0].raw(), inlen);
CHECK(outlen == inlen);
data_xored[0] ^= plaintext_iv; data_xored[0] ^= plaintext_iv;
for (size_t i = 1; i < count; i++) { for (size_t i = 1; i < count; i++) {
@ -530,19 +582,17 @@ void AesIgeState::init(Slice key, Slice iv, bool encrypt) {
if (!impl_) { if (!impl_) {
impl_ = make_unique<Impl>(); impl_ = make_unique<Impl>();
} }
if (encrypt) {
auto &impl = impl_->state.impl_;
if (!impl) {
impl = make_unique<AesState::Impl>();
impl->ctx = EVP_CIPHER_CTX_new();
}
int res = EVP_EncryptInit_ex(impl->ctx, EVP_aes_256_cbc(), nullptr, key.ubegin(), nullptr); auto &impl = impl_->state.impl_;
LOG_IF(FATAL, res != 1); if (!impl) {
EVP_CIPHER_CTX_set_padding(impl->ctx, 0); impl = make_unique<AesState::Impl>();
} else {
impl_->state.init(key, encrypt);
} }
if (encrypt) {
impl->init_encrypt_cbc(key);
} else {
impl->init_decrypt_ecb(key);
}
impl_->encrypted_iv.load(iv.ubegin()); impl_->encrypted_iv.load(iv.ubegin());
impl_->plaintext_iv.load(iv.ubegin() + AES_BLOCK_SIZE); impl_->plaintext_iv.load(iv.ubegin() + AES_BLOCK_SIZE);
} }