Minor AES improvements.

GitOrigin-RevId: 138384ad375735b9e889172cae919368c9976456
This commit is contained in:
levlam 2020-06-15 23:20:44 +03:00
parent 53b0a74f87
commit 6a7dfc4f01
4 changed files with 69 additions and 34 deletions

View File

@ -79,14 +79,14 @@ class AesEcbBench : public td::Benchmark {
} }
}; };
class AesIgeBench : public td::Benchmark { class AesIgeEncryptBench : public td::Benchmark {
public: public:
alignas(64) unsigned char data[DATA_SIZE]; alignas(64) unsigned char data[DATA_SIZE];
td::UInt256 key; td::UInt256 key;
td::UInt256 iv; td::UInt256 iv;
std::string get_description() const override { std::string get_description() const override {
return PSTRING() << "AES IGE OpenSSL [" << (DATA_SIZE >> 10) << "KB]"; return PSTRING() << "AES IGE OpenSSL encrypt [" << (DATA_SIZE >> 10) << "KB]";
} }
void start_up() override { void start_up() override {
@ -107,6 +107,34 @@ class AesIgeBench : public td::Benchmark {
} }
}; };
class AesIgeDecryptBench : public td::Benchmark {
public:
alignas(64) unsigned char data[DATA_SIZE];
td::UInt256 key;
td::UInt256 iv;
std::string get_description() const override {
return PSTRING() << "AES IGE OpenSSL decrypt [" << (DATA_SIZE >> 10) << "KB]";
}
void start_up() override {
for (int i = 0; i < DATA_SIZE; i++) {
data[i] = 123;
}
td::Random::secure_bytes(key.raw, sizeof(key));
td::Random::secure_bytes(iv.raw, sizeof(iv));
}
void run(int n) override {
td::MutableSlice data_slice(data, DATA_SIZE);
td::AesIgeState state;
state.init(as_slice(key), as_slice(iv), false);
for (int i = 0; i < n; i++) {
state.decrypt(data_slice, data_slice);
}
}
};
class AesCtrBench : public td::Benchmark { class AesCtrBench : public td::Benchmark {
public: public:
alignas(64) unsigned char data[DATA_SIZE]; alignas(64) unsigned char data[DATA_SIZE];
@ -163,17 +191,17 @@ class AesCbcBench : public td::Benchmark {
class AesIgeShortBench : public td::Benchmark { class AesIgeShortBench : public td::Benchmark {
public: public:
static constexpr int DATA_SIZE = 16; static constexpr int SHORT_DATA_SIZE = 16;
alignas(64) unsigned char data[DATA_SIZE]; alignas(64) unsigned char data[SHORT_DATA_SIZE];
td::UInt256 key; td::UInt256 key;
td::UInt256 iv; td::UInt256 iv;
std::string get_description() const override { std::string get_description() const override {
return PSTRING() << "AES IGE OpenSSL [" << (DATA_SIZE) << "B]"; return PSTRING() << "AES IGE OpenSSL [" << SHORT_DATA_SIZE << "B]";
} }
void start_up() override { void start_up() override {
for (int i = 0; i < DATA_SIZE; i++) { for (int i = 0; i < SHORT_DATA_SIZE; i++) {
data[i] = 123; data[i] = 123;
} }
td::Random::secure_bytes(as_slice(key)); td::Random::secure_bytes(as_slice(key));
@ -181,12 +209,15 @@ class AesIgeShortBench : public td::Benchmark {
} }
void run(int n) override { void run(int n) override {
td::MutableSlice data_slice(data, DATA_SIZE); td::MutableSlice data_slice(data, SHORT_DATA_SIZE);
td::AesIgeState ige; td::AesIgeState ige;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
ige.init(as_slice(key), as_slice(iv), true); if (true) {
ige.encrypt(data_slice, data_slice); ige.init(as_slice(key), as_slice(iv), true);
//td::aes_ige_encrypt(as_slice(key), as_slice(iv), data_slice, data_slice); ige.encrypt(data_slice, data_slice);
} else {
td::aes_ige_encrypt(as_slice(key), as_slice(iv), data_slice, data_slice);
}
} }
} }
}; };
@ -316,9 +347,10 @@ int main() {
td::init_openssl_threads(); td::init_openssl_threads();
td::bench(AesIgeShortBench()); td::bench(AesIgeShortBench());
td::bench(AesIgeEncryptBench());
td::bench(AesIgeDecryptBench());
td::bench(AesCtrBench()); td::bench(AesCtrBench());
td::bench(AesEcbBench()); td::bench(AesEcbBench());
td::bench(AesIgeBench());
td::bench(Pbkdf2Bench()); td::bench(Pbkdf2Bench());
td::bench(RandBench()); td::bench(RandBench());

View File

@ -8,7 +8,6 @@
#include "td/utils/as.h" #include "td/utils/as.h"
#include "td/utils/BigNum.h" #include "td/utils/BigNum.h"
#include "td/utils/bits.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/misc.h" #include "td/utils/misc.h"
@ -49,7 +48,7 @@
namespace td { namespace td {
struct alignas(8) AesBlock { struct AesBlock {
uint64 hi; uint64 hi;
uint64 lo; uint64 lo;
@ -63,10 +62,9 @@ struct alignas(8) AesBlock {
res.lo = lo ^ b.lo; res.lo = lo ^ b.lo;
return res; return res;
} }
AesBlock &operator^=(const AesBlock &b) { void operator^=(const AesBlock &b) {
hi ^= b.hi; hi ^= b.hi;
lo ^= b.lo; lo ^= b.lo;
return *this;
} }
void load(const uint8 *from) { void load(const uint8 *from) {
@ -87,6 +85,8 @@ struct alignas(8) AesBlock {
return res; return res;
} }
}; };
static_assert(sizeof(AesBlock) == 16, "");
static_assert(sizeof(AesBlock) == AES_BLOCK_SIZE, "");
class XorBytes { class XorBytes {
public: public:
@ -107,7 +107,7 @@ class XorBytes {
size_t n; size_t n;
template <size_t N> template <size_t N>
struct alignas(16) Block { struct alignas(N) Block {
uint8 data[N]; uint8 data[N];
Block operator^(const Block &b) const & { Block operator^(const Block &b) const & {
Block res; Block res;
@ -134,8 +134,8 @@ class XorBytes {
}; };
struct AesCtrCounterPack { struct AesCtrCounterPack {
static constexpr size_t N = 32; static constexpr size_t BLOCK_COUNT = 32;
AesBlock blocks[N]; AesBlock blocks[BLOCK_COUNT];
uint8 *raw() { uint8 *raw() {
return reinterpret_cast<uint8 *>(this); return reinterpret_cast<uint8 *>(this);
} }
@ -144,7 +144,7 @@ struct AesCtrCounterPack {
} }
size_t size() const { size_t size() const {
return N * 16; return sizeof(blocks);
} }
Slice as_slice() const { Slice as_slice() const {
@ -156,13 +156,13 @@ struct AesCtrCounterPack {
void init(AesBlock block) { void init(AesBlock block) {
blocks[0] = block; blocks[0] = block;
for (size_t i = 1; i < N; i++) { for (size_t i = 1; i < BLOCK_COUNT; i++) {
blocks[i] = blocks[i - 1].inc(); blocks[i] = blocks[i - 1].inc();
} }
} }
void rotate() { void rotate() {
blocks[0] = blocks[N - 1].inc(); blocks[0] = blocks[BLOCK_COUNT - 1].inc();
for (size_t i = 1; i < N; i++) { for (size_t i = 1; i < BLOCK_COUNT; i++) {
blocks[i] = blocks[i - 1].inc(); blocks[i] = blocks[i - 1].inc();
} }
} }
@ -373,7 +373,6 @@ int pq_factorize(Slice pq_str, string *p_str, string *q_str) {
class AesState::Impl { class AesState::Impl {
public: public:
EVP_CIPHER_CTX *ctx{nullptr}; EVP_CIPHER_CTX *ctx{nullptr};
bool encrypt{false};
Impl() = default; Impl() = default;
Impl(const Impl &from) = delete; Impl(const Impl &from) = delete;
@ -406,12 +405,10 @@ void AesState::init(Slice key, bool encrypt) {
LOG_IF(FATAL, res != 1); LOG_IF(FATAL, res != 1);
} }
EVP_CIPHER_CTX_set_padding(impl_->ctx, 0); EVP_CIPHER_CTX_set_padding(impl_->ctx, 0);
impl_->encrypt = encrypt;
} }
void AesState::encrypt(const uint8 *src, uint8 *dst, int size) { void AesState::encrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(impl_ != nullptr); CHECK(impl_ != nullptr);
CHECK(impl_->encrypt);
CHECK(impl_->ctx != nullptr); CHECK(impl_->ctx != nullptr);
CHECK(size % 16 == 0); CHECK(size % 16 == 0);
int len; int len;
@ -422,7 +419,6 @@ void AesState::encrypt(const uint8 *src, uint8 *dst, int size) {
void AesState::decrypt(const uint8 *src, uint8 *dst, int size) { void AesState::decrypt(const uint8 *src, uint8 *dst, int size) {
CHECK(impl_ != nullptr); CHECK(impl_ != nullptr);
CHECK(!impl_->encrypt);
CHECK(impl_->ctx != nullptr); CHECK(impl_->ctx != nullptr);
CHECK(size % 16 == 0); CHECK(size % 16 == 0);
int len; int len;
@ -469,6 +465,7 @@ class AesIgeState::Impl {
AesState state; AesState state;
AesBlock iv; AesBlock iv;
AesBlock iv2; AesBlock iv2;
void encrypt(Slice from, MutableSlice to) { void encrypt(Slice from, MutableSlice to) {
CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(from.size() % AES_BLOCK_SIZE == 0);
CHECK(to.size() >= from.size()); CHECK(to.size() >= from.size());
@ -476,7 +473,8 @@ class AesIgeState::Impl {
auto in = from.ubegin(); auto in = from.ubegin();
auto out = to.ubegin(); auto out = to.ubegin();
AesBlock tmp, tmp2; AesBlock tmp;
AesBlock tmp2;
while (len) { while (len) {
tmp.load(in); tmp.load(in);
@ -492,6 +490,7 @@ class AesIgeState::Impl {
out += AES_BLOCK_SIZE; out += AES_BLOCK_SIZE;
} }
} }
void decrypt(Slice from, MutableSlice to) { void decrypt(Slice from, MutableSlice to) {
CHECK(from.size() % AES_BLOCK_SIZE == 0); CHECK(from.size() % AES_BLOCK_SIZE == 0);
CHECK(to.size() >= from.size()); CHECK(to.size() >= from.size());
@ -499,7 +498,8 @@ class AesIgeState::Impl {
auto in = from.ubegin(); auto in = from.ubegin();
auto out = to.ubegin(); auto out = to.ubegin();
AesBlock tmp, tmp2; AesBlock tmp;
AesBlock tmp2;
while (len) { while (len) {
tmp.load(in); tmp.load(in);
@ -597,10 +597,8 @@ class AesCtrState::Impl {
fill(); fill();
} }
size_t min_n = td::min(n, current.size()); size_t min_n = td::min(n, current.size());
auto curr = current.ubegin(); XorBytes::run(src, current.ubegin(), dst, min_n);
XorBytes::run(src, curr, dst, min_n);
src += min_n; src += min_n;
curr += min_n;
dst += min_n; dst += min_n;
n -= min_n; n -= min_n;
current.remove_prefix(min_n); current.remove_prefix(min_n);

View File

@ -29,7 +29,9 @@ class AesState {
~AesState(); ~AesState();
void init(Slice key, bool encrypt); void init(Slice key, bool encrypt);
void encrypt(const uint8 *src, uint8 *dst, int size); void encrypt(const uint8 *src, uint8 *dst, int size);
void decrypt(const uint8 *src, uint8 *dst, int size); void decrypt(const uint8 *src, uint8 *dst, int size);
private: private:
@ -50,8 +52,11 @@ struct AesIgeState {
AesIgeState(AesIgeState &&from); AesIgeState(AesIgeState &&from);
AesIgeState &operator=(AesIgeState &&from); AesIgeState &operator=(AesIgeState &&from);
~AesIgeState(); ~AesIgeState();
void init(Slice key, Slice iv, bool encrypt); void init(Slice key, Slice iv, bool encrypt);
void encrypt(Slice from, MutableSlice to); void encrypt(Slice from, MutableSlice to);
void decrypt(Slice from, MutableSlice to); void decrypt(Slice from, MutableSlice to);
private: private:

View File

@ -75,7 +75,7 @@ TEST(Crypto, AesCtrState) {
ASSERT_EQ(answers1[i], td::crc32(t)); ASSERT_EQ(answers1[i], td::crc32(t));
state.init(as_slice(key), as_slice(iv)); state.init(as_slice(key), as_slice(iv));
state.decrypt(t, t); state.decrypt(t, t);
ASSERT_STREQ(s, t); ASSERT_STREQ(td::base64_encode(s), td::base64_encode(t));
for (auto &c : iv.raw) { for (auto &c : iv.raw) {
c = 0xFF; c = 0xFF;
@ -87,7 +87,6 @@ TEST(Crypto, AesCtrState) {
i++; i++;
} }
} }
#endif
TEST(Crypto, AesIgeState) { TEST(Crypto, AesIgeState) {
td::vector<td::uint32> answers1{0u, 2045698207u, 2423540300u, 525522475u, 1545267325u}; td::vector<td::uint32> answers1{0u, 2045698207u, 2423540300u, 525522475u, 1545267325u};
@ -121,10 +120,11 @@ TEST(Crypto, AesIgeState) {
state.init(as_slice(key), as_slice(iv), false); state.init(as_slice(key), as_slice(iv), false);
state.decrypt(t, t); state.decrypt(t, t);
ASSERT_STREQ(s, t); ASSERT_STREQ(td::base64_encode(s), td::base64_encode(t));
i++; i++;
} }
} }
#endif
TEST(Crypto, Sha256State) { TEST(Crypto, Sha256State) {
for (auto length : {0, 1, 31, 32, 33, 9999, 10000, 10001, 999999, 1000001}) { for (auto length : {0, 1, 31, 32, 33, 9999, 10000, 10001, 999999, 1000001}) {