use void* instead of pimpl idiom in AesCtrState and AesIgeState

GitOrigin-RevId: cc1c175d078b179e7af730b8617165c6ea6193f3
This commit is contained in:
Arseny Smirnov 2020-06-17 19:05:43 +03:00
parent 50ce05a34f
commit 95af3e74bd
2 changed files with 288 additions and 286 deletions

View File

@ -251,39 +251,35 @@ int pq_factorize(Slice pq_str, string *p_str, string *q_str) {
return 0; return 0;
} }
struct AesBlock { uint8 *AesBlock::raw() {
uint64 hi;
uint64 lo;
uint8 *raw() {
return reinterpret_cast<uint8 *>(this); return reinterpret_cast<uint8 *>(this);
} }
const uint8 *raw() const { const uint8 *AesBlock::raw() const {
return reinterpret_cast<const uint8 *>(this); return reinterpret_cast<const uint8 *>(this);
} }
Slice as_slice() const { Slice AesBlock::as_slice() const {
return Slice(raw(), AES_BLOCK_SIZE); return Slice(raw(), AES_BLOCK_SIZE);
} }
AesBlock operator^(const AesBlock &b) const { AesBlock AesBlock::operator^(const AesBlock &b) const {
AesBlock res; AesBlock res;
res.hi = hi ^ b.hi; res.hi = hi ^ b.hi;
res.lo = lo ^ b.lo; res.lo = lo ^ b.lo;
return res; return res;
} }
void operator^=(const AesBlock &b) { void AesBlock::operator^=(const AesBlock &b) {
hi ^= b.hi; hi ^= b.hi;
lo ^= b.lo; lo ^= b.lo;
} }
void load(const uint8 *from) { void AesBlock::load(const uint8 *from) {
*this = as<AesBlock>(from); *this = as<AesBlock>(from);
} }
void store(uint8 *to) { void AesBlock::store(uint8 *to) {
as<AesBlock>(to) = *this; as<AesBlock>(to) = *this;
} }
AesBlock inc() const { AesBlock AesBlock::inc() const {
#if SIZE_MAX == UINT64_MAX #if SIZE_MAX == UINT64_MAX
AesBlock res; AesBlock res;
res.lo = host_to_big_endian64(big_endian_to_host64(lo) + 1); res.lo = host_to_big_endian64(big_endian_to_host64(lo) + 1);
@ -305,8 +301,7 @@ struct AesBlock {
} }
return res; return res;
#endif #endif
} }
};
static_assert(sizeof(AesBlock) == 16, ""); static_assert(sizeof(AesBlock) == 16, "");
static_assert(sizeof(AesBlock) == AES_BLOCK_SIZE, ""); static_assert(sizeof(AesBlock) == AES_BLOCK_SIZE, "");
@ -343,109 +338,101 @@ class XorBytes {
}; };
}; };
struct AesCtrCounterPack { uint8 *AesCtrCounterPack::raw() {
static constexpr size_t BLOCK_COUNT = 32;
AesBlock blocks[BLOCK_COUNT];
uint8 *raw() {
return reinterpret_cast<uint8 *>(this); return reinterpret_cast<uint8 *>(this);
} }
const uint8 *raw() const { const uint8 *AesCtrCounterPack::raw() const {
return reinterpret_cast<const uint8 *>(this); return reinterpret_cast<const uint8 *>(this);
} }
size_t size() const { size_t AesCtrCounterPack::size() const {
return sizeof(blocks); return sizeof(blocks);
} }
Slice as_slice() const { Slice AesCtrCounterPack::as_slice() const {
return Slice(raw(), size()); return Slice(raw(), size());
} }
MutableSlice as_mutable_slice() { MutableSlice AesCtrCounterPack::as_mutable_slice() {
return MutableSlice(raw(), size()); return MutableSlice(raw(), size());
} }
void init(AesBlock block) { void AesCtrCounterPack::init(AesBlock block) {
blocks[0] = block; blocks[0] = block;
for (size_t i = 1; i < BLOCK_COUNT; i++) { for (size_t i = 1; i < BLOCK_COUNT; i++) {
blocks[i] = blocks[i - 1].inc(); blocks[i] = blocks[i - 1].inc();
} }
} }
void rotate() {
void AesCtrCounterPack::rotate() {
blocks[0] = blocks[BLOCK_COUNT - 1].inc(); blocks[0] = blocks[BLOCK_COUNT - 1].inc();
for (size_t i = 1; i < BLOCK_COUNT; i++) { for (size_t i = 1; i < BLOCK_COUNT; i++) {
blocks[i] = blocks[i - 1].inc(); blocks[i] = blocks[i - 1].inc();
} }
}
static EVP_CIPHER_CTX *evp_load(std::unique_ptr<void, Evp::EvpDeleter> &ptr) {
if (!ptr) {
ptr.reset(EVP_CIPHER_CTX_new());
LOG_IF(FATAL, !ptr);
} }
return reinterpret_cast<EVP_CIPHER_CTX *>(ptr.get());
}; };
class Evp { static void evp_init(EVP_CIPHER_CTX *ctx, Evp::Type type, bool is_encrypt, const EVP_CIPHER *cipher, Slice key) {
public: // type_ = type;
Evp() { // is_encrypt_ = is_encrypt;
ctx_ = EVP_CIPHER_CTX_new(); int res = EVP_CipherInit_ex(ctx, cipher, nullptr, key.ubegin(), nullptr, is_encrypt ? 1 : 0);
LOG_IF(FATAL, ctx_ == nullptr);
}
Evp(const Evp &from) = delete;
Evp &operator=(const Evp &from) = delete;
Evp(Evp &&from) = delete;
Evp &operator=(Evp &&from) = delete;
~Evp() {
CHECK(ctx_ != nullptr);
EVP_CIPHER_CTX_free(ctx_);
}
void init_encrypt_ecb(Slice key) {
init(Type::Ecb, true, EVP_aes_256_ecb(), key);
}
void init_decrypt_ecb(Slice key) {
init(Type::Ecb, false, EVP_aes_256_ecb(), key);
}
void init_encrypt_cbc(Slice key) {
init(Type::Cbc, true, EVP_aes_256_cbc(), key);
}
void init_decrypt_cbc(Slice key) {
init(Type::Cbc, false, EVP_aes_256_cbc(), key);
}
void init_iv(Slice iv) {
int res = EVP_CipherInit_ex(ctx_, nullptr, nullptr, nullptr, iv.ubegin(), -1);
LOG_IF(FATAL, res != 1); LOG_IF(FATAL, res != 1);
} EVP_CIPHER_CTX_set_padding(ctx, 0);
}
void encrypt(const uint8 *src, uint8 *dst, int size) { void Evp::EvpDeleter::operator()(void *ptr) {
CHECK(ptr != nullptr);
EVP_CIPHER_CTX_free(reinterpret_cast<EVP_CIPHER_CTX *>(ptr));
}
Evp::~Evp() = default;
Evp &Evp::operator=(Evp &&from) = default;
Evp::Evp(Evp &&from) = default;
void Evp::init_encrypt_ecb(Slice key) {
evp_init(evp_load(ctx_), Type::Ecb, true, EVP_aes_256_ecb(), key);
}
void Evp::init_decrypt_ecb(Slice key) {
evp_init(evp_load(ctx_), Type::Ecb, false, EVP_aes_256_ecb(), key);
}
void Evp::init_encrypt_cbc(Slice key) {
evp_init(evp_load(ctx_), Type::Cbc, true, EVP_aes_256_cbc(), key);
}
void Evp::init_decrypt_cbc(Slice key) {
evp_init(evp_load(ctx_), Type::Cbc, false, EVP_aes_256_cbc(), key);
}
void Evp::init_iv(Slice iv) {
int res = EVP_CipherInit_ex(evp_load(ctx_), nullptr, nullptr, nullptr, iv.ubegin(), -1);
LOG_IF(FATAL, res != 1);
}
void Evp::encrypt(const uint8 *src, uint8 *dst, int size) {
// CHECK(type_ != Type::Empty && is_encrypt_); // CHECK(type_ != Type::Empty && is_encrypt_);
CHECK(size % AES_BLOCK_SIZE == 0); CHECK(size % AES_BLOCK_SIZE == 0);
int len; int len;
int res = EVP_EncryptUpdate(ctx_, dst, &len, src, size); int res = EVP_EncryptUpdate(evp_load(ctx_), dst, &len, src, size);
LOG_IF(FATAL, res != 1); LOG_IF(FATAL, res != 1);
CHECK(len == size); CHECK(len == size);
} }
void decrypt(const uint8 *src, uint8 *dst, int size) { void Evp::decrypt(const uint8 *src, uint8 *dst, int size) {
// CHECK(type_ != Type::Empty && !is_encrypt_); // CHECK(type_ != Type::Empty && !is_encrypt_);
CHECK(size % AES_BLOCK_SIZE == 0); CHECK(size % AES_BLOCK_SIZE == 0);
int len; int len;
int res = EVP_DecryptUpdate(ctx_, dst, &len, src, size); int res = EVP_DecryptUpdate(evp_load(ctx_), dst, &len, src, size);
LOG_IF(FATAL, res != 1); LOG_IF(FATAL, res != 1);
CHECK(len == size); CHECK(len == size);
} }
private:
EVP_CIPHER_CTX *ctx_{nullptr};
enum class Type : int8 { Empty, Ecb, Cbc };
// Type type_{Type::Empty};
// bool is_encrypt_ = false;
void init(Type type, bool is_encrypt, const EVP_CIPHER *cipher, Slice key) {
// type_ = type;
// is_encrypt_ = is_encrypt;
int res = EVP_CipherInit_ex(ctx_, cipher, nullptr, key.ubegin(), nullptr, is_encrypt ? 1 : 0);
LOG_IF(FATAL, res != 1);
EVP_CIPHER_CTX_set_padding(ctx_, 0);
}
};
struct AesState::Impl { struct AesState::Impl {
Evp evp; Evp evp;
@ -511,9 +498,7 @@ void aes_ige_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlic
state.decrypt(from, to); state.decrypt(from, to);
} }
class AesIgeState::Impl { void AesIgeState::init(Slice key, Slice iv, bool encrypt) {
public:
void init(Slice key, Slice iv, bool encrypt) {
CHECK(key.size() == 32); CHECK(key.size() == 32);
CHECK(iv.size() == 32); CHECK(iv.size() == 32);
if (encrypt) { if (encrypt) {
@ -524,8 +509,8 @@ class AesIgeState::Impl {
encrypted_iv_.load(iv.ubegin()); encrypted_iv_.load(iv.ubegin());
plaintext_iv_.load(iv.ubegin() + AES_BLOCK_SIZE); plaintext_iv_.load(iv.ubegin() + AES_BLOCK_SIZE);
} }
void encrypt(Slice from, MutableSlice to) { void AesIgeState::encrypt(Slice from, MutableSlice to) {
CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(from.size() % AES_BLOCK_SIZE == 0);
CHECK(to.size() >= from.size()); CHECK(to.size() >= from.size());
auto len = to.size() / AES_BLOCK_SIZE; auto len = to.size() / AES_BLOCK_SIZE;
@ -563,9 +548,9 @@ class AesIgeState::Impl {
in += AES_BLOCK_SIZE * count; in += AES_BLOCK_SIZE * count;
out += AES_BLOCK_SIZE * count; out += AES_BLOCK_SIZE * count;
} }
} }
void decrypt(Slice from, MutableSlice to) { void AesIgeState::decrypt(Slice from, MutableSlice to) {
CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(from.size() % AES_BLOCK_SIZE == 0);
CHECK(to.size() >= from.size()); CHECK(to.size() >= from.size());
auto len = to.size() / AES_BLOCK_SIZE; auto len = to.size() / AES_BLOCK_SIZE;
@ -588,35 +573,13 @@ class AesIgeState::Impl {
in += AES_BLOCK_SIZE; in += AES_BLOCK_SIZE;
out += AES_BLOCK_SIZE; out += AES_BLOCK_SIZE;
} }
} }
private:
Evp evp_;
AesBlock encrypted_iv_;
AesBlock plaintext_iv_;
};
AesIgeState::AesIgeState() = default; AesIgeState::AesIgeState() = default;
AesIgeState::AesIgeState(AesIgeState &&from) = default; AesIgeState::AesIgeState(AesIgeState &&from) = default;
AesIgeState &AesIgeState::operator=(AesIgeState &&from) = default; AesIgeState &AesIgeState::operator=(AesIgeState &&from) = default;
AesIgeState::~AesIgeState() = default; AesIgeState::~AesIgeState() = default;
void AesIgeState::init(Slice key, Slice iv, bool encrypt) {
if (!impl_) {
impl_ = make_unique<Impl>();
}
impl_->init(key, iv, encrypt);
}
void AesIgeState::encrypt(Slice from, MutableSlice to) {
impl_->encrypt(from, to);
}
void AesIgeState::decrypt(Slice from, MutableSlice to) {
impl_->decrypt(from, to);
}
static void aes_cbc_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) { static void aes_cbc_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) {
CHECK(aes_key.size() == 32); CHECK(aes_key.size() == 32);
CHECK(aes_iv.size() == 16); CHECK(aes_iv.size() == 16);
@ -652,9 +615,7 @@ void AesCbcState::decrypt(Slice from, MutableSlice to) {
::td::aes_cbc_decrypt(key_.as_slice(), iv_.as_mutable_slice(), from, to); ::td::aes_cbc_decrypt(key_.as_slice(), iv_.as_mutable_slice(), from, to);
} }
class AesCtrState::Impl { void AesCtrState::init(Slice key, Slice iv) {
public:
Impl(Slice key, Slice iv) {
CHECK(key.size() == 32); CHECK(key.size() == 32);
CHECK(iv.size() == 16); CHECK(iv.size() == 16);
static_assert(AES_BLOCK_SIZE == 16, ""); static_assert(AES_BLOCK_SIZE == 16, "");
@ -663,9 +624,9 @@ class AesCtrState::Impl {
block.load(iv.ubegin()); block.load(iv.ubegin());
counter_.init(block); counter_.init(block);
fill(); fill();
} }
void encrypt(Slice from, MutableSlice to) { void AesCtrState::encrypt(Slice from, MutableSlice to) {
auto *src = from.ubegin(); auto *src = from.ubegin();
auto *dst = to.ubegin(); auto *dst = to.ubegin();
auto n = from.size(); auto n = from.size();
@ -681,37 +642,15 @@ class AesCtrState::Impl {
n -= min_n; n -= min_n;
current_.remove_prefix(min_n); current_.remove_prefix(min_n);
} }
}
private:
Evp evp_;
AesCtrCounterPack counter_;
AesCtrCounterPack encrypted_counter_;
Slice current_;
void fill() {
evp_.encrypt(counter_.raw(), encrypted_counter_.raw(), static_cast<int>(counter_.size()));
current_ = encrypted_counter_.as_slice();
}
};
AesCtrState::AesCtrState() = default;
AesCtrState::AesCtrState(AesCtrState &&from) = default;
AesCtrState &AesCtrState::operator=(AesCtrState &&from) = default;
AesCtrState::~AesCtrState() = default;
void AesCtrState::init(Slice key, Slice iv) {
ctx_ = make_unique<AesCtrState::Impl>(key, iv);
}
void AesCtrState::encrypt(Slice from, MutableSlice to) {
ctx_->encrypt(from, to);
} }
void AesCtrState::decrypt(Slice from, MutableSlice to) { void AesCtrState::decrypt(Slice from, MutableSlice to) {
encrypt(from, to); // it is the same as decrypt encrypt(from, to); // it is the same as decrypt
} }
void AesCtrState::fill() {
evp_.encrypt(counter_.raw(), encrypted_counter_.raw(), static_cast<int>(counter_.size()));
current_ = encrypted_counter_.as_slice();
}
void sha1(Slice data, unsigned char output[20]) { void sha1(Slice data, unsigned char output[20]) {
auto result = SHA1(data.ubegin(), data.size(), output); auto result = SHA1(data.ubegin(), data.size(), output);
@ -1077,7 +1016,6 @@ uint32 crc32c_extend(uint32 old_crc, Slice data) {
} }
namespace { namespace {
uint32 gf32_matrix_times(const uint32 *matrix, uint32 vector) { uint32 gf32_matrix_times(const uint32 *matrix, uint32 vector) {
uint32 sum = 0; uint32 sum = 0;
while (vector) { while (vector) {

View File

@ -20,6 +20,64 @@ uint64 pq_factorize(uint64 pq);
void init_crypto(); void init_crypto();
int pq_factorize(Slice pq_str, string *p_str, string *q_str); int pq_factorize(Slice pq_str, string *p_str, string *q_str);
class Evp {
public:
Evp() = default;
Evp(const Evp &from) = delete;
Evp &operator=(const Evp &from) = delete;
Evp(Evp &&from);
Evp &operator=(Evp &&from);
~Evp();
void init_encrypt_ecb(Slice key);
void init_decrypt_ecb(Slice key);
void init_encrypt_cbc(Slice key);
void init_decrypt_cbc(Slice key);
void init_iv(Slice iv);
void encrypt(const uint8 *src, uint8 *dst, int size);
void decrypt(const uint8 *src, uint8 *dst, int size);
struct EvpDeleter {
public:
void operator()(void *ptr);
};
enum class Type : int8 { Empty, Ecb, Cbc };
private:
std::unique_ptr<void, EvpDeleter> ctx_;
};
struct AesBlock {
uint64 hi;
uint64 lo;
uint8 *raw();
const uint8 *raw() const;
Slice as_slice() const;
AesBlock operator^(const AesBlock &b) const;
void operator^=(const AesBlock &b);
void load(const uint8 *from);
void store(uint8 *to);
AesBlock inc() const;
};
struct AesCtrCounterPack {
static constexpr size_t BLOCK_COUNT = 32;
AesBlock blocks[BLOCK_COUNT];
uint8 *raw();
const uint8 *raw() const;
size_t size() const;
Slice as_slice() const;
MutableSlice as_mutable_slice();
void init(AesBlock block);
void rotate();
};
class AesState { class AesState {
public: public:
@ -60,8 +118,9 @@ class AesIgeState {
void decrypt(Slice from, MutableSlice to); void decrypt(Slice from, MutableSlice to);
private: private:
class Impl; AesBlock encrypted_iv_;
unique_ptr<Impl> impl_; AesBlock plaintext_iv_;
Evp evp_;
}; };
void aes_cbc_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to); void aes_cbc_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to);
@ -69,12 +128,12 @@ void aes_cbc_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlic
class AesCtrState { class AesCtrState {
public: public:
AesCtrState(); AesCtrState() = default;
AesCtrState(const AesCtrState &from) = delete; AesCtrState(const AesCtrState &from) = delete;
AesCtrState &operator=(const AesCtrState &from) = delete; AesCtrState &operator=(const AesCtrState &from) = delete;
AesCtrState(AesCtrState &&from); AesCtrState(AesCtrState &&from) = default;
AesCtrState &operator=(AesCtrState &&from); AesCtrState &operator=(AesCtrState &&from) = default;
~AesCtrState(); ~AesCtrState() = default;
void init(Slice key, Slice iv); void init(Slice key, Slice iv);
@ -83,8 +142,13 @@ class AesCtrState {
void decrypt(Slice from, MutableSlice to); void decrypt(Slice from, MutableSlice to);
private: private:
class Impl; Evp evp_;
unique_ptr<Impl> ctx_;
AesCtrCounterPack counter_;
AesCtrCounterPack encrypted_counter_;
Slice current_;
void fill();
}; };
class AesCbcState { class AesCbcState {