diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index 30f672e3f..f25350530 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -364,17 +364,17 @@ int pq_factorize(Slice pq_str, string *p_str, string *q_str) { return 0; } -class AesState::Impl { +class Evp { public: - Impl() { + Evp() { ctx_ = EVP_CIPHER_CTX_new(); LOG_IF(FATAL, !ctx_); } - Impl(const Impl &from) = delete; - Impl &operator=(const Impl &from) = delete; - Impl(Impl &&from) = delete; - Impl &operator=(Impl &&from) = delete; - ~Impl() { + Evp(const Evp &from) = delete; + Evp &operator=(const Evp &from) = delete; + Evp(Evp &&from) = delete; + Evp &operator=(Evp &&from) = delete; + ~Evp() { if (ctx_ != nullptr) { EVP_CIPHER_CTX_free(ctx_); } @@ -442,6 +442,10 @@ class AesState::Impl { enum Type { Empty, EncryptEcb, DecryptEcb, EncryptCbc, DecryptCbc } type_{Empty}; }; +struct AesState::Impl { + Evp evp; +}; + AesState::AesState() = default; AesState::~AesState() = default; @@ -451,20 +455,20 @@ void AesState::init(Slice key, bool encrypt) { impl_ = make_unique(); } if (encrypt) { - impl_->init_encrypt_ecb(key); + impl_->evp.init_encrypt_ecb(key); } else { - impl_->init_decrypt_ecb(key); + impl_->evp.init_decrypt_ecb(key); } } void AesState::encrypt(const uint8 *src, uint8 *dst, int size) { - CHECK(impl_ != nullptr); - impl_->encrypt(src, dst, size); + CHECK(impl_); + impl_->evp.encrypt(src, dst, size); } void AesState::decrypt(const uint8 *src, uint8 *dst, int size) { - CHECK(impl_ != nullptr); - impl_->decrypt(src, dst, size); + CHECK(impl_); + impl_->evp.decrypt(src, dst, size); } static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) { @@ -502,10 +506,18 @@ void aes_ige_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlic class AesIgeState::Impl { public: - AesState state; - AesBlock encrypted_iv; - AesBlock plaintext_iv; + void init(Slice key, Slice iv, bool encrypt) { + CHECK(key.size() == 32); + CHECK(iv.size() == 32); + if (encrypt) { + evp_.init_encrypt_cbc(key); + } else { + evp_.init_decrypt_ecb(key); + } + encrypted_iv_.load(iv.ubegin()); + plaintext_iv_.load(iv.ubegin() + AES_BLOCK_SIZE); + } void encrypt(Slice from, MutableSlice to) { CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(to.size() >= from.size()); @@ -522,23 +534,22 @@ class AesIgeState::Impl { std::memcpy(data, in, AES_BLOCK_SIZE * count); data_xored[0] = data[0]; if (count > 1) { - data_xored[1] = plaintext_iv ^ data[1]; + data_xored[1] = plaintext_iv_ ^ data[1]; for (size_t i = 2; i < count; i++) { data_xored[i] = data[i - 2] ^ data[i]; } } - state.impl_->init_encrypt_cbc_iv(encrypted_iv.as_mutable_slice()); - int outlen = 0; + evp_.init_encrypt_cbc_iv(encrypted_iv_.as_mutable_slice()); int inlen = static_cast(AES_BLOCK_SIZE * count); - state.impl_->encrypt(data_xored[0].raw(), data_xored[0].raw(), inlen); + evp_.encrypt(data_xored[0].raw(), data_xored[0].raw(), inlen); - data_xored[0] ^= plaintext_iv; + data_xored[0] ^= plaintext_iv_; for (size_t i = 1; i < count; i++) { data_xored[i] ^= data[i - 1]; } - plaintext_iv = data[count - 1]; - encrypted_iv = data_xored[count - 1]; + plaintext_iv_ = data[count - 1]; + encrypted_iv_ = data_xored[count - 1]; std::memcpy(out, data_xored, AES_BLOCK_SIZE * count); len -= count; @@ -559,42 +570,34 @@ class AesIgeState::Impl { while (len) { encrypted.load(in); - plaintext_iv ^= encrypted; - state.decrypt(plaintext_iv.raw(), plaintext_iv.raw(), AES_BLOCK_SIZE); - plaintext_iv ^= encrypted_iv; + plaintext_iv_ ^= encrypted; + evp_.decrypt(plaintext_iv_.raw(), plaintext_iv_.raw(), AES_BLOCK_SIZE); + plaintext_iv_ ^= encrypted_iv_; - plaintext_iv.store(out); - encrypted_iv = encrypted; + plaintext_iv_.store(out); + encrypted_iv_ = encrypted; --len; in += AES_BLOCK_SIZE; out += AES_BLOCK_SIZE; } } + + private: + Evp evp_; + AesBlock encrypted_iv_; + AesBlock plaintext_iv_; }; AesIgeState::AesIgeState() = default; AesIgeState::~AesIgeState() = default; void AesIgeState::init(Slice key, Slice iv, bool encrypt) { - CHECK(key.size() == 32); - CHECK(iv.size() == 32); if (!impl_) { impl_ = make_unique(); } - auto &impl = impl_->state.impl_; - if (!impl) { - impl = make_unique(); - } - if (encrypt) { - impl->init_encrypt_cbc(key); - } else { - impl->init_decrypt_ecb(key); - } - - impl_->encrypted_iv.load(iv.ubegin()); - impl_->plaintext_iv.load(iv.ubegin() + AES_BLOCK_SIZE); + impl_->init(key, iv, encrypt); } void AesIgeState::encrypt(Slice from, MutableSlice to) { @@ -646,10 +649,10 @@ class AesCtrState::Impl { CHECK(key.size() == 32); CHECK(iv.size() == 16); static_assert(AES_BLOCK_SIZE == 16, ""); - aes_state.init(key, true); + evp_.init_encrypt_ecb(key); AesBlock block; block.load(iv.ubegin()); - counter.init(block); + counter_.init(block); fill(); } @@ -658,29 +661,29 @@ class AesCtrState::Impl { auto *dst = to.ubegin(); auto n = from.size(); while (n != 0) { - if (current.empty()) { - counter.rotate(); + if (current_.empty()) { + counter_.rotate(); fill(); } - size_t min_n = td::min(n, current.size()); - XorBytes::run(src, current.ubegin(), dst, min_n); + size_t min_n = td::min(n, current_.size()); + XorBytes::run(src, current_.ubegin(), dst, min_n); src += min_n; dst += min_n; n -= min_n; - current.remove_prefix(min_n); + current_.remove_prefix(min_n); } } private: - AesState aes_state; + Evp evp_; - AesCtrCounterPack counter; - AesCtrCounterPack encrypted_counter; - Slice current; + AesCtrCounterPack counter_; + AesCtrCounterPack encrypted_counter_; + Slice current_; void fill() { - aes_state.encrypt(counter.raw(), encrypted_counter.raw(), static_cast(counter.size())); - current = encrypted_counter.as_slice(); + evp_.encrypt(counter_.raw(), encrypted_counter_.raw(), static_cast(counter_.size())); + current_ = encrypted_counter_.as_slice(); } }; diff --git a/tdutils/td/utils/crypto.h b/tdutils/td/utils/crypto.h index e34a25803..a94970e1e 100644 --- a/tdutils/td/utils/crypto.h +++ b/tdutils/td/utils/crypto.h @@ -35,7 +35,7 @@ class AesState { void decrypt(const uint8 *src, uint8 *dst, int size); private: - class Impl; + struct Impl; unique_ptr impl_; friend class AesIgeState;