tdutils: optimize aes ctr
GitOrigin-RevId: 09c6df45c0bf2683507a0f279769471efc859ecb
This commit is contained in:
parent
55ca575af5
commit
e913c3126b
@ -71,13 +71,42 @@ class AesBench : public td::Benchmark {
|
||||
state.init(td::as_slice(key), true);
|
||||
td::MutableSlice data_slice(data, DATA_SIZE);
|
||||
for (int i = 0; i <= n; i++) {
|
||||
for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += 16) {
|
||||
state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset);
|
||||
size_t step = 16;
|
||||
for (size_t offset = 0; offset + 16 <= data_slice.size(); offset += step) {
|
||||
state.encrypt(data_slice.ubegin() + offset, data_slice.ubegin() + offset, (int)step);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AesCtrBench : public td::Benchmark {
|
||||
public:
|
||||
alignas(64) unsigned char data[DATA_SIZE];
|
||||
td::UInt256 key;
|
||||
td::UInt128 iv;
|
||||
|
||||
std::string get_description() const override {
|
||||
return PSTRING() << "AES CTR OpenSSL [" << (DATA_SIZE >> 10) << "KB]";
|
||||
}
|
||||
|
||||
void start_up() override {
|
||||
for (int i = 0; i < DATA_SIZE; i++) {
|
||||
data[i] = 123;
|
||||
}
|
||||
td::Random::secure_bytes(key.raw, sizeof(key));
|
||||
td::Random::secure_bytes(iv.raw, sizeof(iv));
|
||||
}
|
||||
|
||||
void run(int n) override {
|
||||
td::AesCtrState state;
|
||||
state.init(as_slice(key), as_slice(iv));
|
||||
td::MutableSlice data_slice(data, DATA_SIZE);
|
||||
for (int i = 0; i < n; i++) {
|
||||
state.encrypt(data_slice, data_slice);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AESBench : public td::Benchmark {
|
||||
public:
|
||||
alignas(64) unsigned char data[DATA_SIZE];
|
||||
@ -227,6 +256,7 @@ class Crc64Bench : public td::Benchmark {
|
||||
|
||||
int main() {
|
||||
td::init_openssl_threads();
|
||||
td::bench(AesCtrBench());
|
||||
td::bench(AesBench());
|
||||
td::bench(AESBench());
|
||||
|
||||
|
@ -273,30 +273,28 @@ void AesState::init(Slice key, bool encrypt) {
|
||||
|
||||
if (encrypt) {
|
||||
CHECK(1 == EVP_EncryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr));
|
||||
AES_set_encrypt_key(key.ubegin(), 256, &impl_->key);
|
||||
} else {
|
||||
CHECK(1 == EVP_DecryptInit_ex(impl_->ctx, EVP_aes_256_ecb(), nullptr, key.ubegin(), nullptr));
|
||||
AES_set_decrypt_key(key.ubegin(), 256, &impl_->key);
|
||||
}
|
||||
EVP_CIPHER_CTX_set_padding(impl_->ctx, 0);
|
||||
impl_->encrypt = encrypt;
|
||||
}
|
||||
|
||||
void AesState::encrypt(const uint8 *src, uint8 *dst) {
|
||||
void AesState::encrypt(const uint8 *src, uint8 *dst, int size) {
|
||||
CHECK(impl_->encrypt);
|
||||
CHECK(impl_->ctx);
|
||||
CHECK(size % 16 == 0);
|
||||
int len;
|
||||
CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, 16));
|
||||
CHECK(len == 16);
|
||||
//AES_encrypt(src, dst, &impl_->key);
|
||||
CHECK(1 == EVP_EncryptUpdate(impl_->ctx, dst, &len, src, size));
|
||||
CHECK(len == size);
|
||||
}
|
||||
void AesState::decrypt(const uint8 *src, uint8 *dst) {
|
||||
void AesState::decrypt(const uint8 *src, uint8 *dst, int size) {
|
||||
CHECK(!impl_->encrypt);
|
||||
CHECK(impl_->ctx);
|
||||
CHECK(size % 16 == 0);
|
||||
int len;
|
||||
CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, 16));
|
||||
LOG_CHECK(len == 16) << len;
|
||||
//AES_decrypt(src, dst, &impl_->key);
|
||||
CHECK(1 == EVP_DecryptUpdate(impl_->ctx, dst, &len, src, size));
|
||||
CHECK(len == size);
|
||||
}
|
||||
|
||||
static void aes_ige_xcrypt(Slice aes_key, MutableSlice aes_iv, Slice from, MutableSlice to, bool encrypt_flag) {
|
||||
@ -363,35 +361,63 @@ class AesCtrState::Impl {
|
||||
CHECK(key.size() == 32);
|
||||
CHECK(iv.size() == 16);
|
||||
static_assert(AES_BLOCK_SIZE == 16, "");
|
||||
if (AES_set_encrypt_key(key.ubegin(), 256, &aes_key) < 0) {
|
||||
LOG(FATAL) << "Failed to set encrypt key";
|
||||
}
|
||||
aes_state.init(key, true);
|
||||
counter.as_mutable_slice().copy_from(iv);
|
||||
current_pos = 0;
|
||||
fill();
|
||||
}
|
||||
|
||||
void encrypt(Slice from, MutableSlice to) {
|
||||
CHECK(to.size() >= from.size());
|
||||
for (size_t i = 0; i < from.size(); i++) {
|
||||
if (current_pos == 0) {
|
||||
AES_encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), &aes_key);
|
||||
uint8 *ptr = counter.as_mutable_slice().ubegin();
|
||||
for (int j = 15; j >= 0; j--) {
|
||||
if (++ptr[j] != 0) {
|
||||
break;
|
||||
}
|
||||
auto *src = from.ubegin();
|
||||
auto *dst = to.ubegin();
|
||||
auto n = from.size();
|
||||
while (n != 0) {
|
||||
if (current.empty()) {
|
||||
if (N != 1) {
|
||||
counter.as_mutable_slice().copy_from(counter.as_slice().substr((N - 1) * AES_BLOCK_SIZE));
|
||||
}
|
||||
inc(counter.as_mutable_slice().ubegin());
|
||||
fill();
|
||||
}
|
||||
to[i] = static_cast<char>(from[i] ^ encrypted_counter[current_pos]);
|
||||
current_pos = (current_pos + 1) & 15;
|
||||
size_t min_n = td::min(n, current.size());
|
||||
auto curr = current.ubegin();
|
||||
for (size_t i = 0; i < min_n; i++) {
|
||||
dst[i] = src[i] ^ curr[i];
|
||||
}
|
||||
n -= min_n;
|
||||
src += min_n;
|
||||
dst += min_n;
|
||||
current.remove_prefix(min_n);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
AES_KEY aes_key;
|
||||
SecureString counter{AES_BLOCK_SIZE};
|
||||
SecureString encrypted_counter{AES_BLOCK_SIZE};
|
||||
uint8 current_pos;
|
||||
AesState aes_state;
|
||||
|
||||
static constexpr size_t N = 32;
|
||||
SecureString counter{AES_BLOCK_SIZE * N};
|
||||
SecureString encrypted_counter{AES_BLOCK_SIZE * N};
|
||||
td::Slice current;
|
||||
|
||||
void inc(uint8 *ptr) {
|
||||
for (int j = 15; j >= 0; j--) {
|
||||
if (++ptr[j] != 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
void fill() {
|
||||
auto *src = counter.as_slice().ubegin();
|
||||
auto *dst = counter.as_mutable_slice().ubegin() + AES_BLOCK_SIZE;
|
||||
for (size_t i = 0; i + 1 < N; i++) {
|
||||
memcpy(dst, src, AES_BLOCK_SIZE);
|
||||
inc(dst);
|
||||
src += AES_BLOCK_SIZE;
|
||||
dst += AES_BLOCK_SIZE;
|
||||
}
|
||||
|
||||
aes_state.encrypt(counter.as_slice().ubegin(), encrypted_counter.as_mutable_slice().ubegin(), (int)counter.size());
|
||||
current = encrypted_counter.as_slice();
|
||||
}
|
||||
};
|
||||
|
||||
AesCtrState::AesCtrState() = default;
|
||||
|
@ -28,8 +28,8 @@ struct AesState {
|
||||
AesState &operator=(AesState &&from);
|
||||
~AesState();
|
||||
void init(Slice key, bool encrypt);
|
||||
void encrypt(const uint8 *src, uint8 *dst);
|
||||
void decrypt(const uint8 *src, uint8 *dst);
|
||||
void encrypt(const uint8 *src, uint8 *dst, int size);
|
||||
void decrypt(const uint8 *src, uint8 *dst, int size);
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
|
@ -32,8 +32,8 @@ TEST(Crypto, Aes) {
|
||||
td::AesState decryptor;
|
||||
decryptor.init(as_slice(key), false);
|
||||
|
||||
encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin());
|
||||
decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin());
|
||||
encryptor.encrypt(td::as_slice(plaintext).ubegin(), td::as_slice(encrypted).ubegin(), 16);
|
||||
decryptor.decrypt(td::as_slice(encrypted).ubegin(), td::as_slice(decrypted).ubegin(), 16);
|
||||
|
||||
CHECK(decrypted == plaintext);
|
||||
CHECK(decrypted != encrypted);
|
||||
|
Loading…
Reference in New Issue
Block a user