diff --git a/benchmark/bench_crypto.cpp b/benchmark/bench_crypto.cpp index fb957ec8c..253127dc2 100644 --- a/benchmark/bench_crypto.cpp +++ b/benchmark/bench_crypto.cpp @@ -346,11 +346,11 @@ class Crc64Bench : public td::Benchmark { int main() { td::init_openssl_threads(); + td::bench(AesIgeShortBench()); td::bench(AesIgeEncryptBench()); td::bench(AesIgeDecryptBench()); td::bench(AesEcbBench()); td::bench(AesCtrBench()); - td::bench(AesIgeShortBench()); td::bench(Pbkdf2Bench()); td::bench(RandBench()); diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index 036f4ab2f..ca18a4fea 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -433,9 +433,6 @@ static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, Mutab } void aes_ige_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to) { - if (from.size() <= 128) { - return aes_ige_xcrypt(aes_key, aes_iv, from, to, true); - } AesIgeState state; state.init(aes_key, aes_iv, true); state.encrypt(from, to); @@ -452,7 +449,6 @@ void aes_ige_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlic class AesIgeState::Impl { public: - AesState cbc_state; AesState state; AesBlock encrypted_iv; AesBlock plaintext_iv; @@ -464,51 +460,38 @@ class AesIgeState::Impl { auto in = from.ubegin(); auto out = to.ubegin(); - static constexpr size_t BLOCK_COUNT = 32; - while (len >= BLOCK_COUNT) { + static constexpr size_t BLOCK_COUNT = 31; + while (len != 0) { AesBlock data[BLOCK_COUNT]; AesBlock data_xored[BLOCK_COUNT]; - std::memcpy(data, in, sizeof(data)); + + auto count = td::min(BLOCK_COUNT, len); + std::memcpy(data, in, AES_BLOCK_SIZE * count); data_xored[0] = data[0]; - data_xored[1] = plaintext_iv ^ data[1]; - for (size_t i = 2; i < BLOCK_COUNT; i++) { - data_xored[i] = data[i - 2] ^ data[i]; + if (count > 1) { + data_xored[1] = plaintext_iv ^ data[1]; + for (size_t i = 2; i < count; i++) { + data_xored[i] = data[i - 2] ^ data[i]; + } } - EVP_EncryptInit(cbc_state.impl_->ctx, nullptr, nullptr, encrypted_iv.raw()); + EVP_EncryptInit(state.impl_->ctx, nullptr, nullptr, encrypted_iv.raw()); int outlen = 0; - int inlen = static_cast(sizeof(data_xored)); - EVP_EncryptUpdate(cbc_state.impl_->ctx, data_xored[0].raw(), &outlen, data_xored[0].raw(), inlen); + int inlen = static_cast(AES_BLOCK_SIZE * count); + EVP_EncryptUpdate(state.impl_->ctx, data_xored[0].raw(), &outlen, data_xored[0].raw(), inlen); CHECK(outlen == inlen); data_xored[0] ^= plaintext_iv; - for (size_t i = 1; i < BLOCK_COUNT; i++) { + for (size_t i = 1; i < count; i++) { data_xored[i] ^= data[i - 1]; } - plaintext_iv = data[BLOCK_COUNT - 1]; - encrypted_iv = data_xored[BLOCK_COUNT - 1]; + plaintext_iv = data[count - 1]; + encrypted_iv = data_xored[count - 1]; - std::memcpy(out, data_xored, sizeof(data_xored)); - len -= BLOCK_COUNT; - in += AES_BLOCK_SIZE * BLOCK_COUNT; - out += AES_BLOCK_SIZE * BLOCK_COUNT; - } - - AesBlock plaintext; - - while (len) { - plaintext.load(in); - - encrypted_iv ^= plaintext; - state.encrypt(encrypted_iv.raw(), encrypted_iv.raw(), AES_BLOCK_SIZE); - encrypted_iv ^= plaintext_iv; - - encrypted_iv.store(out); - plaintext_iv = plaintext; - - --len; - in += AES_BLOCK_SIZE; - out += AES_BLOCK_SIZE; + std::memcpy(out, data_xored, AES_BLOCK_SIZE * count); + len -= count; + in += AES_BLOCK_SIZE * count; + out += AES_BLOCK_SIZE * count; } } @@ -547,9 +530,8 @@ void AesIgeState::init(Slice key, Slice iv, bool encrypt) { if (!impl_) { impl_ = make_unique(); } - impl_->state.init(key, encrypt); if (encrypt) { - auto &impl = impl_->cbc_state.impl_; + auto &impl = impl_->state.impl_; if (!impl) { impl = make_unique(); impl->ctx = EVP_CIPHER_CTX_new(); @@ -558,6 +540,8 @@ void AesIgeState::init(Slice key, Slice iv, bool encrypt) { int res = EVP_EncryptInit_ex(impl->ctx, EVP_aes_256_cbc(), nullptr, key.ubegin(), nullptr); LOG_IF(FATAL, res != 1); EVP_CIPHER_CTX_set_padding(impl->ctx, 0); + } else { + impl_->state.init(key, encrypt); } impl_->encrypted_iv.load(iv.ubegin()); impl_->plaintext_iv.load(iv.ubegin() + AES_BLOCK_SIZE);