diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index 5b8a15633..5d86525d4 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -396,6 +396,10 @@ class Evp { init(Type::Cbc, false, EVP_aes_256_cbc(), key); } + void init_encrypt_ctr(Slice key) { + init(Type::Ctr, true, EVP_aes_256_ctr(), key); + } + void init_iv(Slice iv) { int res = EVP_CipherInit_ex(ctx_, nullptr, nullptr, nullptr, iv.ubegin(), -1); LOG_IF(FATAL, res != 1); @@ -403,7 +407,7 @@ class Evp { void encrypt(const uint8 *src, uint8 *dst, int size) { // CHECK(type_ != Type::Empty && is_encrypt_); - CHECK(size % AES_BLOCK_SIZE == 0); + // CHECK(size % AES_BLOCK_SIZE == 0); int len; int res = EVP_EncryptUpdate(ctx_, dst, &len, src, size); LOG_IF(FATAL, res != 1); @@ -421,7 +425,7 @@ class Evp { private: EVP_CIPHER_CTX *ctx_{nullptr}; - enum class Type : int8 { Empty, Ecb, Cbc }; + enum class Type : int8 { Empty, Ecb, Cbc, Ctr }; // Type type_{Type::Empty}; // bool is_encrypt_ = false; @@ -627,49 +631,8 @@ void AesCbcState::decrypt(Slice from, MutableSlice to) { ::td::aes_cbc_decrypt(raw_.key.as_slice(), raw_.iv.as_mutable_slice(), from, to); } -class AesCtrState::Impl { - public: - Impl(Slice key, Slice iv) { - CHECK(key.size() == 32); - CHECK(iv.size() == 16); - static_assert(AES_BLOCK_SIZE == 16, ""); - evp_.init_encrypt_ecb(key); - counter_.load(iv.ubegin()); - fill(); - } - - void encrypt(Slice from, MutableSlice to) { - auto *src = from.ubegin(); - auto *dst = to.ubegin(); - auto n = from.size(); - while (n != 0) { - size_t left = encrypted_counter_.raw() + AesCtrCounterPack::size() - current_; - if (left == 0) { - fill(); - left = AesCtrCounterPack::size(); - } - size_t min_n = td::min(n, left); - XorBytes::run(src, current_, dst, min_n); - src += min_n; - dst += min_n; - n -= min_n; - current_ += min_n; - } - } - - private: +struct AesCtrState::Impl { Evp evp_; - - uint8 *current_; - AesBlock counter_; - AesCtrCounterPack encrypted_counter_; - - void fill() { - encrypted_counter_.init(counter_); - counter_ = encrypted_counter_.blocks[AesCtrCounterPack::BLOCK_COUNT - 1].inc(); - current_ = encrypted_counter_.raw(); - evp_.encrypt(current_, current_, static_cast(AesCtrCounterPack::size())); - } }; AesCtrState::AesCtrState() = default; @@ -678,15 +641,20 @@ AesCtrState &AesCtrState::operator=(AesCtrState &&from) = default; AesCtrState::~AesCtrState() = default; void AesCtrState::init(Slice key, Slice iv) { - ctx_ = make_unique(key, iv); + CHECK(key.size() == 32); + CHECK(iv.size() == 16); + ctx_ = make_unique(); + ctx_->evp_.init_encrypt_ctr(key); + ctx_->evp_.init_iv(iv); } void AesCtrState::encrypt(Slice from, MutableSlice to) { - ctx_->encrypt(from, to); + CHECK(from.size() <= to.size()); + ctx_->evp_.encrypt(from.ubegin(), to.ubegin(), from.size()); } void AesCtrState::decrypt(Slice from, MutableSlice to) { - encrypt(from, to); // it is the same as decrypt + encrypt(from, to); } void sha1(Slice data, unsigned char output[20]) { diff --git a/tdutils/td/utils/crypto.h b/tdutils/td/utils/crypto.h index 0045f7e1a..7fa4d8046 100644 --- a/tdutils/td/utils/crypto.h +++ b/tdutils/td/utils/crypto.h @@ -84,7 +84,7 @@ class AesCtrState { void decrypt(Slice from, MutableSlice to); private: - class Impl; + struct Impl; unique_ptr ctx_; };