diff --git a/tdutils/td/utils/FlatHashMapChunks.h b/tdutils/td/utils/FlatHashMapChunks.h index 0ff67a476..dabeba9a0 100644 --- a/tdutils/td/utils/FlatHashMapChunks.h +++ b/tdutils/td/utils/FlatHashMapChunks.h @@ -19,8 +19,6 @@ #include #include -#include - struct MaskPortable { static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { uint64_t res = 0; @@ -31,6 +29,8 @@ struct MaskPortable { } }; +#ifdef __aarch64__ +#include struct MaskNeonFolly { static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { uint8x16_t input_mask = vld1q_u8(bytes); @@ -57,6 +57,13 @@ struct MaskNeon { return vaddvq_u16(a_shifted); } }; +#endif + +#ifdef __aarch64__ +using MaskHelper = MaskNeon; +#else +using MaskHelper = MaskPortable; +#endif namespace td { template @@ -207,7 +214,7 @@ class FlatHashTableChunks { while (true) { auto chunk_i = chunk_it.next(); auto &chunk = chunks_[chunk_i]; - auto mask = MaskNeon::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK; + auto mask = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK; while (mask != 0) { auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE; if (EqT()(it->first, key)) { @@ -278,7 +285,7 @@ class FlatHashTableChunks { while (true) { auto chunk_i = chunk_it.next(); auto &chunk = chunks_[chunk_i]; - auto mask = MaskPortable::equal_mask(chunk.ctrl, 0) & Chunk::MASK; + auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK; if (mask != 0) { auto shift = td::count_trailing_zeroes64(mask); DCHECK(chunk.ctrl[shift] == 0); @@ -451,7 +458,7 @@ class FlatHashTableChunks { while (true) { auto chunk_i = chunk_it.next(); auto &chunk = chunks_[chunk_i]; - auto mask = MaskPortable::equal_mask(chunk.ctrl, 0) & Chunk::MASK; + auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK; if (mask != 0) { auto shift = td::count_trailing_zeroes64(mask); auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE; diff --git a/tdutils/test/hashset_benchmark.cpp b/tdutils/test/hashset_benchmark.cpp index d937d7c45..d4240ce81 100644 --- a/tdutils/test/hashset_benchmark.cpp +++ b/tdutils/test/hashset_benchmark.cpp @@ -543,8 +543,10 @@ void BM_mask(benchmark::State &state) { } } BENCHMARK_TEMPLATE(BM_mask, MaskPortable); +#ifdef __aarch64__ BENCHMARK_TEMPLATE(BM_mask, MaskNeonFolly); BENCHMARK_TEMPLATE(BM_mask, MaskNeon); +#endif #define FOR_EACH_TABLE(F) \ F(td::FlatHashMapChunks) \