FlatHashMap: optimizations

This commit is contained in:
Arseny Smirnov 2022-02-18 00:29:29 +01:00
parent b20a98036f
commit a356cc7e3d
2 changed files with 69 additions and 32 deletions

View File

@ -19,20 +19,51 @@
#include <new> #include <new>
#include <utility> #include <utility>
// Iterator over the set bits of a match mask, visiting the lowest set bit
// first.
//
// `shift` is the number of mask bits that correspond to one slot: the mask
// producers in this file return MaskIterator<1> when every slot contributes a
// single bit, and MaskIterator<4> when every slot contributes one bit per
// nibble (mask pre-filtered with 0x1111...).
//
// The type is intentionally an aggregate (no constructors) so that producers
// can write `return {mask};`.
//
// The same object serves both as a range and as its own iterator, which makes
// `for (auto pos : mask_it)` possible: begin() is a copy of the current state
// and end() is the empty mask.
template <int shift>
struct MaskIterator {
  uint64_t mask;

  // True while at least one match is left.
  explicit operator bool() const {
    return mask != 0;
  }

  // Slot index of the lowest remaining match.
  // NOTE(review): presumably must not be called when mask == 0 - confirm
  // against td::count_trailing_zeroes64's contract for a zero argument.
  int pos() const {
    return td::count_trailing_zeroes64(mask) / shift;
  }

  // Drops the lowest remaining match (clears the lowest set bit).
  void next() {
    mask &= mask - 1;
  }

  // For foreach.
  // Takes a const reference so that comparisons against temporaries
  // (e.g. `it != x.end()`) and const iterators compile as well.
  bool operator!=(const MaskIterator &other) const {
    return mask != other.mask;
  }
  auto operator*() const {
    return pos();
  }
  void operator++() {
    next();
  }
  // begin()/end() are const so that a const MaskIterator can be iterated too.
  auto begin() const {
    return *this;
  }
  auto end() const {
    return MaskIterator{0u};
  }
};
struct MaskPortable { struct MaskPortable {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
uint64_t res = 0; uint64_t res = 0;
for (int i = 0; i < 16; i++) { for (int i = 0; i < 16; i++) {
res |= (bytes[i] == needle) << i; res |= (bytes[i] == needle) << i;
} }
return res; return {res & ((1u << 14) - 1)};
} }
}; };
#ifdef __aarch64__ #ifdef __aarch64__
#include <arm_neon.h> #include <arm_neon.h>
struct MaskNeonFolly { struct MaskNeonFolly {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { static MaskIterator<4> equal_mask(uint8_t *bytes, uint8_t needle) {
uint8x16_t input_mask = vld1q_u8(bytes); uint8x16_t input_mask = vld1q_u8(bytes);
auto needle_mask = vdupq_n_u8(needle); auto needle_mask = vdupq_n_u8(needle);
auto eq_mask = vceqq_u8(input_mask, needle_mask); auto eq_mask = vceqq_u8(input_mask, needle_mask);
@ -40,12 +71,12 @@ struct MaskNeonFolly {
// by shifting right 4, then round to get it into a 64-bit vector // by shifting right 4, then round to get it into a 64-bit vector
uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4); uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4);
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0); uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0);
return mask & 0x1111111111111111; return {mask & 0x11111111111111};
} }
}; };
struct MaskNeon { struct MaskNeon {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) { static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
uint8x16_t input_mask = vld1q_u8(bytes); uint8x16_t input_mask = vld1q_u8(bytes);
auto needle_mask = vdupq_n_u8(needle); auto needle_mask = vdupq_n_u8(needle);
auto eq_mask = vceqq_u8(input_mask, needle_mask); auto eq_mask = vceqq_u8(input_mask, needle_mask);
@ -54,13 +85,13 @@ struct MaskNeon {
const int16_t __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7}; const int16_t __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7};
int16x8_t SHIFT = vld1q_s16(SHIFT_ARR); int16x8_t SHIFT = vld1q_s16(SHIFT_ARR);
uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT); uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT);
return vaddvq_u16(a_shifted); return {vaddvq_u16(a_shifted) & ((1u << 14) - 1)};
} }
}; };
#endif #endif
#ifdef __aarch64__ #ifdef __aarch64__
using MaskHelper = MaskNeon; using MaskHelper = MaskNeonFolly;
#else #else
using MaskHelper = MaskPortable; using MaskHelper = MaskPortable;
#endif #endif
@ -209,22 +240,24 @@ class FlatHashTableChunks {
if (empty() || is_key_empty(key)) { if (empty() || is_key_empty(key)) {
return end(); return end();
} }
auto hash = calc_hash(key); const auto hash = calc_hash(key);
auto chunk_it = get_chunk_it(hash.chunk_i); auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.pos();
auto chunk_begin = nodes_.begin() + chunk_i * Chunk::CHUNK_SIZE;
//__builtin_prefetch(chunk_begin);
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
auto mask = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK; auto mask_it = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash);
while (mask != 0) { for (auto pos : mask_it) {
auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE; auto it = chunk_begin + pos;
if (EqT()(it->first, key)) { if (likely(EqT()(it->first, key))) {
return Iterator{it, this}; return Iterator{it, this};
} }
mask &= mask - 1;
} }
if (chunk.skipped_cnt == 0) { if (chunk.skipped_cnt == 0) {
break; break;
} }
chunk_it.next();
} }
return end(); return end();
} }
@ -283,11 +316,11 @@ class FlatHashTableChunks {
auto hash = calc_hash(key); auto hash = calc_hash(key);
auto chunk_it = get_chunk_it(hash.chunk_i); auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.pos();
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK; auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
if (mask != 0) { if (mask_it) {
auto shift = td::count_trailing_zeroes64(mask); auto shift = mask_it.pos();
DCHECK(chunk.ctrl[shift] == 0); DCHECK(chunk.ctrl[shift] == 0);
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE; auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
DCHECK(node_it->empty()); DCHECK(node_it->empty());
@ -299,6 +332,7 @@ class FlatHashTableChunks {
} }
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max()); CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
chunk.skipped_cnt++; chunk.skipped_cnt++;
chunk_it.next();
} }
} }
@ -413,20 +447,21 @@ class FlatHashTableChunks {
}; };
struct ChunkIt { struct ChunkIt {
size_t chunk_i; size_t chunk_i;
size_t chunk_n; size_t chunk_mask;
size_t shift{}; size_t shift{};
size_t next() { size_t pos() const {
chunk_i += shift;
shift++;
if (chunk_i >= chunk_n) {
chunk_i -= chunk_n;
}
return chunk_i; return chunk_i;
} }
void next() {
DCHECK((chunk_mask & (chunk_mask + 1)) == 0);
shift++;
chunk_i += shift;
chunk_i &= chunk_mask;
}
}; };
ChunkIt get_chunk_it(size_t chunk_i) { ChunkIt get_chunk_it(size_t chunk_i) {
return {chunk_i, chunks_.size()}; return {chunk_i, chunks_.size() - 1};
} }
HashInfo calc_hash(const KeyT &key) { HashInfo calc_hash(const KeyT &key) {
@ -456,11 +491,11 @@ class FlatHashTableChunks {
auto hash = calc_hash(node.first); auto hash = calc_hash(node.first);
auto chunk_it = get_chunk_it(hash.chunk_i); auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.pos();
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK; auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
if (mask != 0) { if (mask_it) {
auto shift = td::count_trailing_zeroes64(mask); auto shift = mask_it.pos();
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE; auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
DCHECK(node_it->empty()); DCHECK(node_it->empty());
*node_it = std::move(node); *node_it = std::move(node);
@ -472,6 +507,7 @@ class FlatHashTableChunks {
} }
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max()); CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
chunk.skipped_cnt++; chunk.skipped_cnt++;
chunk_it.next();
} }
} }
@ -490,13 +526,14 @@ class FlatHashTableChunks {
auto hash = calc_hash(it->first); auto hash = calc_hash(it->first);
auto chunk_it = get_chunk_it(hash.chunk_i); auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) { while (true) {
auto chunk_i = chunk_it.next(); auto chunk_i = chunk_it.pos();
auto &chunk = chunks_[chunk_i]; auto &chunk = chunks_[chunk_i];
if (chunk_i == empty_chunk_i) { if (chunk_i == empty_chunk_i) {
chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0; chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0;
break; break;
} }
chunk.skipped_cnt--; chunk.skipped_cnt--;
chunk_it.next();
} }
it->clear(); it->clear();
used_nodes_--; used_nodes_--;

View File

@ -328,7 +328,7 @@ TEST(FlatHashMap, stress_test) {
}); });
td::RandomSteps runner(std::move(steps)); td::RandomSteps runner(std::move(steps));
for (size_t i = 0; i < 1000000000; i++) { for (size_t i = 0; i < 10000000; i++) {
runner.step(rnd); runner.step(rnd);
} }
} }