diff --git a/benchmark/bench_crypto.cpp b/benchmark/bench_crypto.cpp index e861914c5..dc4472d64 100644 --- a/benchmark/bench_crypto.cpp +++ b/benchmark/bench_crypto.cpp @@ -127,8 +127,10 @@ class AesIgeBench : public td::Benchmark { void run(int n) override { td::MutableSlice data_slice(data, DATA_SIZE); + td::AesIgeState state; + state.init(as_slice(key), as_slice(iv), true); for (int i = 0; i < n; i++) { - td::aes_ige_encrypt(as_slice(key), as_slice(iv), data_slice, data_slice); + state.encrypt(data_slice, data_slice); } } }; @@ -257,8 +259,8 @@ class Crc64Bench : public td::Benchmark { int main() { td::init_openssl_threads(); td::bench(AesEcbBench()); - td::bench(AesCtrBench()); td::bench(AesIgeBench()); + td::bench(AesCtrBench()); td::bench(Pbkdf2Bench()); td::bench(RandBench()); diff --git a/tdutils/td/utils/crypto.cpp b/tdutils/td/utils/crypto.cpp index e49e2eb98..e0067e81b 100644 --- a/tdutils/td/utils/crypto.cpp +++ b/tdutils/td/utils/crypto.cpp @@ -32,6 +32,16 @@ #include #endif +#if TD_HAVE_OPENSSL +#define N_WORDS (AES_BLOCK_SIZE / sizeof(unsigned long)) +typedef struct { + unsigned long data[N_WORDS]; +} aes_block_t; + +#define load_block(d, s) memcpy((d).data, (s), AES_BLOCK_SIZE) +#define store_block(d, s) memcpy((d), (s).data, AES_BLOCK_SIZE) +#endif + #if TD_HAVE_ZLIB #include #endif @@ -332,6 +342,87 @@ void aes_ige_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlic aes_ige_xcrypt(aes_key, aes_iv, from, to, false); } +class AesIgeState::Impl { + public: + AesState state; + aes_block_t iv; + aes_block_t iv2; + void encrypt(Slice from, MutableSlice to) { + CHECK(from.size() % AES_BLOCK_SIZE == 0); + CHECK(to.size() >= from.size()); + auto len = to.size() / AES_BLOCK_SIZE; + auto in = from.ubegin(); + auto out = to.begin(); + + aes_block_t tmp, tmp2; + + while (len) { + load_block(tmp, in); + for (size_t n = 0; n < N_WORDS; ++n) { + tmp2.data[n] = tmp.data[n] ^ iv.data[n]; + } + + state.encrypt((unsigned char *)tmp2.data, (unsigned char *)tmp2.data, AES_BLOCK_SIZE); + for (size_t n = 0; n < N_WORDS; ++n) { + tmp2.data[n] ^= iv2.data[n]; + } + store_block(out, tmp2); + iv = tmp2; + iv2 = tmp; + --len; + in += AES_BLOCK_SIZE; + out += AES_BLOCK_SIZE; + } + } + void decrypt(Slice from, MutableSlice to) { + CHECK(from.size() % AES_BLOCK_SIZE == 0); + CHECK(to.size() >= from.size()); + auto len = to.size() / AES_BLOCK_SIZE; + auto in = from.ubegin(); + auto out = to.begin(); + + aes_block_t tmp, tmp2; + + while (len) { + load_block(tmp, in); + tmp2 = tmp; + for (size_t n = 0; n < N_WORDS; ++n) { + tmp.data[n] ^= iv2.data[n]; + } + state.decrypt((unsigned char *)tmp.data, (unsigned char *)tmp.data, AES_BLOCK_SIZE); + for (size_t n = 0; n < N_WORDS; ++n) { + tmp.data[n] ^= iv.data[n]; + } + store_block(out, tmp); + iv = tmp2; + iv2 = tmp; + --len; + in += AES_BLOCK_SIZE; + out += AES_BLOCK_SIZE; + } + } +}; + +AesIgeState::AesIgeState() = default; +AesIgeState::~AesIgeState() = default; + +void AesIgeState::init(Slice key, Slice iv, bool encrypt) { + CHECK(key.size() == 32); + CHECK(iv.size() == 32); + impl_ = make_unique(); + impl_->state.init(key, encrypt); + load_block(impl_->iv, iv.ubegin()); + load_block(impl_->iv2, iv.ubegin() + AES_BLOCK_SIZE); +} + +void AesIgeState::encrypt(Slice from, MutableSlice to) { + impl_->encrypt(from, to); +} + +void AesIgeState::decrypt(Slice from, MutableSlice to) { + impl_->decrypt(from, to); +} + static void aes_cbc_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) { CHECK(aes_key.size() == 32); CHECK(aes_iv.size() == 16); diff --git a/tdutils/td/utils/crypto.h b/tdutils/td/utils/crypto.h index c97ad4c22..de0c5f766 100644 --- a/tdutils/td/utils/crypto.h +++ b/tdutils/td/utils/crypto.h @@ -42,6 +42,23 @@ int pq_factorize(Slice pq_str, string *p_str, string *q_str); void aes_ige_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to); void aes_ige_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to); +struct AesIgeState { + public: + AesIgeState(); + AesIgeState(const AesIgeState &from) = delete; + AesIgeState &operator=(const AesIgeState &from) = delete; + AesIgeState(AesIgeState &&from); + AesIgeState &operator=(AesIgeState &&from); + ~AesIgeState(); + void init(Slice key, Slice iv, bool encrypt); + void encrypt(Slice from, MutableSlice to); + void decrypt(Slice from, MutableSlice to); + + private: + class Impl; + unique_ptr impl_; +}; + void aes_cbc_encrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to); void aes_cbc_decrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to); diff --git a/tdutils/test/crypto.cpp b/tdutils/test/crypto.cpp index 306206d63..18fbf442f 100644 --- a/tdutils/test/crypto.cpp +++ b/tdutils/test/crypto.cpp @@ -89,6 +89,43 @@ TEST(Crypto, AesCtrState) { } #endif +TEST(Crypto, AesIgeState) { + td::vector answers1{0u, 2045698207u, 2423540300u, 525522475u, 1545267325u}; + + std::size_t i = 0; + for (auto length : {0, 16, 32, 256, 1024}) { + td::uint32 seed = length; + td::string s(length, '\0'); + for (auto &c : s) { + seed = seed * 123457567u + 987651241u; + c = static_cast((seed >> 23) & 255); + } + + td::UInt256 key; + for (auto &c : key.raw) { + seed = seed * 123457567u + 987651241u; + c = (seed >> 23) & 255; + } + td::UInt256 iv; + for (auto &c : iv.raw) { + seed = seed * 123457567u + 987651241u; + c = (seed >> 23) & 255; + } + + td::AesIgeState state; + state.init(as_slice(key), as_slice(iv), true); + td::string t(length, '\0'); + state.encrypt(s, t); + + ASSERT_EQ(answers1[i], td::crc32(t)); + + state.init(as_slice(key), as_slice(iv), false); + state.decrypt(t, t); + ASSERT_STREQ(s, t); + i++; + } +} + TEST(Crypto, Sha256State) { for (auto length : {0, 1, 31, 32, 33, 9999, 10000, 10001, 999999, 1000001}) { auto s = td::rand_string(std::numeric_limits::min(), std::numeric_limits::max(), length);