FlatHashMap: optimizations

This commit is contained in:
Arseny Smirnov 2022-02-18 00:29:29 +01:00
parent b20a98036f
commit a356cc7e3d
2 changed files with 69 additions and 32 deletions

View File

@ -19,20 +19,51 @@
#include <new>
#include <utility>
template <int shift>
struct MaskIterator {
uint64_t mask;
explicit operator bool() const {
return mask != 0;
}
int pos() const {
return td::count_trailing_zeroes64(mask) / shift;
}
void next() {
mask &= mask - 1;
}
// For foreach
bool operator!=(MaskIterator &other) const {
return mask != other.mask;
}
auto operator*() const {
return pos();
}
void operator++() {
next();
}
auto begin() {
return *this;
}
auto end() {
return MaskIterator{0u};
}
};
struct MaskPortable {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
uint64_t res = 0;
for (int i = 0; i < 16; i++) {
res |= (bytes[i] == needle) << i;
}
return res;
return {res & ((1u << 14) - 1)};
}
};
#ifdef __aarch64__
#include <arm_neon.h>
struct MaskNeonFolly {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
static MaskIterator<4> equal_mask(uint8_t *bytes, uint8_t needle) {
uint8x16_t input_mask = vld1q_u8(bytes);
auto needle_mask = vdupq_n_u8(needle);
auto eq_mask = vceqq_u8(input_mask, needle_mask);
@ -40,12 +71,12 @@ struct MaskNeonFolly {
// by shifting right 4, then round to get it into a 64-bit vector
uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4);
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0);
return mask & 0x1111111111111111;
return {mask & 0x11111111111111};
}
};
struct MaskNeon {
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
uint8x16_t input_mask = vld1q_u8(bytes);
auto needle_mask = vdupq_n_u8(needle);
auto eq_mask = vceqq_u8(input_mask, needle_mask);
@ -54,13 +85,13 @@ struct MaskNeon {
const int16_t __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7};
int16x8_t SHIFT = vld1q_s16(SHIFT_ARR);
uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT);
return vaddvq_u16(a_shifted);
return {vaddvq_u16(a_shifted) & ((1u << 14) - 1)};
}
};
#endif
#ifdef __aarch64__
using MaskHelper = MaskNeon;
using MaskHelper = MaskNeonFolly;
#else
using MaskHelper = MaskPortable;
#endif
@ -209,22 +240,24 @@ class FlatHashTableChunks {
if (empty() || is_key_empty(key)) {
return end();
}
auto hash = calc_hash(key);
const auto hash = calc_hash(key);
auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) {
auto chunk_i = chunk_it.next();
auto chunk_i = chunk_it.pos();
auto chunk_begin = nodes_.begin() + chunk_i * Chunk::CHUNK_SIZE;
//__builtin_prefetch(chunk_begin);
auto &chunk = chunks_[chunk_i];
auto mask = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK;
while (mask != 0) {
auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE;
if (EqT()(it->first, key)) {
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash);
for (auto pos : mask_it) {
auto it = chunk_begin + pos;
if (likely(EqT()(it->first, key))) {
return Iterator{it, this};
}
mask &= mask - 1;
}
if (chunk.skipped_cnt == 0) {
break;
}
chunk_it.next();
}
return end();
}
@ -283,11 +316,11 @@ class FlatHashTableChunks {
auto hash = calc_hash(key);
auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) {
auto chunk_i = chunk_it.next();
auto chunk_i = chunk_it.pos();
auto &chunk = chunks_[chunk_i];
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
if (mask != 0) {
auto shift = td::count_trailing_zeroes64(mask);
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
if (mask_it) {
auto shift = mask_it.pos();
DCHECK(chunk.ctrl[shift] == 0);
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
DCHECK(node_it->empty());
@ -299,6 +332,7 @@ class FlatHashTableChunks {
}
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
chunk.skipped_cnt++;
chunk_it.next();
}
}
@ -413,20 +447,21 @@ class FlatHashTableChunks {
};
struct ChunkIt {
size_t chunk_i;
size_t chunk_n;
size_t chunk_mask;
size_t shift{};
size_t next() {
chunk_i += shift;
shift++;
if (chunk_i >= chunk_n) {
chunk_i -= chunk_n;
}
size_t pos() const {
return chunk_i;
}
void next() {
DCHECK((chunk_mask & (chunk_mask + 1)) == 0);
shift++;
chunk_i += shift;
chunk_i &= chunk_mask;
}
};
ChunkIt get_chunk_it(size_t chunk_i) {
return {chunk_i, chunks_.size()};
return {chunk_i, chunks_.size() - 1};
}
HashInfo calc_hash(const KeyT &key) {
@ -456,11 +491,11 @@ class FlatHashTableChunks {
auto hash = calc_hash(node.first);
auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) {
auto chunk_i = chunk_it.next();
auto chunk_i = chunk_it.pos();
auto &chunk = chunks_[chunk_i];
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
if (mask != 0) {
auto shift = td::count_trailing_zeroes64(mask);
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
if (mask_it) {
auto shift = mask_it.pos();
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
DCHECK(node_it->empty());
*node_it = std::move(node);
@ -472,6 +507,7 @@ class FlatHashTableChunks {
}
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
chunk.skipped_cnt++;
chunk_it.next();
}
}
@ -490,13 +526,14 @@ class FlatHashTableChunks {
auto hash = calc_hash(it->first);
auto chunk_it = get_chunk_it(hash.chunk_i);
while (true) {
auto chunk_i = chunk_it.next();
auto chunk_i = chunk_it.pos();
auto &chunk = chunks_[chunk_i];
if (chunk_i == empty_chunk_i) {
chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0;
break;
}
chunk.skipped_cnt--;
chunk_it.next();
}
it->clear();
used_nodes_--;

View File

@ -328,7 +328,7 @@ TEST(FlatHashMap, stress_test) {
});
td::RandomSteps runner(std::move(steps));
for (size_t i = 0; i < 1000000000; i++) {
for (size_t i = 0; i < 10000000; i++) {
runner.step(rnd);
}
}