FlatHashMap: fixes for portability

This commit is contained in:
Arseny Smirnov 2022-02-17 20:22:46 +01:00
parent 34a69e3133
commit 5ff92065bf
2 changed files with 14 additions and 5 deletions

View File

@ -19,8 +19,6 @@
#include <new> #include <new>
#include <utility> #include <utility>
#include <arm_neon.h>
struct MaskPortable { struct MaskPortable {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
uint64_t res = 0; uint64_t res = 0;
@ -31,6 +29,8 @@ struct MaskPortable {
} }
}; };
#ifdef __aarch64__
#include <arm_neon.h>
struct MaskNeonFolly { struct MaskNeonFolly {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
uint8x16_t input_mask = vld1q_u8(bytes); uint8x16_t input_mask = vld1q_u8(bytes);
@ -57,6 +57,13 @@ struct MaskNeon {
return vaddvq_u16(a_shifted); return vaddvq_u16(a_shifted);
} }
}; };
#endif
#ifdef __aarch64__
using MaskHelper = MaskNeon;
#else
using MaskHelper = MaskPortable;
#endif
namespace td { namespace td {
template <class NodeT, class HashT, class EqT> template <class NodeT, class HashT, class EqT>
@ -207,7 +214,7 @@ class FlatHashTableChunks {
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.next();
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
auto mask = MaskNeon::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK; auto mask = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK;
while (mask != 0) { while (mask != 0) {
auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE; auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE;
if (EqT()(it->first, key)) { if (EqT()(it->first, key)) {
@ -278,7 +285,7 @@ class FlatHashTableChunks {
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.next();
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
auto mask = MaskPortable::equal_mask(chunk.ctrl, 0) & Chunk::MASK; auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
if (mask != 0) { if (mask != 0) {
auto shift = td::count_trailing_zeroes64(mask); auto shift = td::count_trailing_zeroes64(mask);
DCHECK(chunk.ctrl[shift] == 0); DCHECK(chunk.ctrl[shift] == 0);
@ -451,7 +458,7 @@ class FlatHashTableChunks {
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.next();
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
auto mask = MaskPortable::equal_mask(chunk.ctrl, 0) & Chunk::MASK; auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
if (mask != 0) { if (mask != 0) {
auto shift = td::count_trailing_zeroes64(mask); auto shift = td::count_trailing_zeroes64(mask);
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE; auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;

View File

@ -543,8 +543,10 @@ void BM_mask(benchmark::State &state) {
} }
} }
BENCHMARK_TEMPLATE(BM_mask, MaskPortable); BENCHMARK_TEMPLATE(BM_mask, MaskPortable);
#ifdef __aarch64__
BENCHMARK_TEMPLATE(BM_mask, MaskNeonFolly); BENCHMARK_TEMPLATE(BM_mask, MaskNeonFolly);
BENCHMARK_TEMPLATE(BM_mask, MaskNeon); BENCHMARK_TEMPLATE(BM_mask, MaskNeon);
#endif
#define FOR_EACH_TABLE(F) \ #define FOR_EACH_TABLE(F) \
F(td::FlatHashMapChunks) \ F(td::FlatHashMapChunks) \