FlatHashMap: optimizations
This commit is contained in:
parent
b20a98036f
commit
a356cc7e3d
@ -19,20 +19,51 @@
|
||||
#include <new>
|
||||
#include <utility>
|
||||
|
||||
template <int shift>
|
||||
struct MaskIterator {
|
||||
uint64_t mask;
|
||||
explicit operator bool() const {
|
||||
return mask != 0;
|
||||
}
|
||||
int pos() const {
|
||||
return td::count_trailing_zeroes64(mask) / shift;
|
||||
}
|
||||
void next() {
|
||||
mask &= mask - 1;
|
||||
}
|
||||
|
||||
// For foreach
|
||||
bool operator!=(MaskIterator &other) const {
|
||||
return mask != other.mask;
|
||||
}
|
||||
auto operator*() const {
|
||||
return pos();
|
||||
}
|
||||
void operator++() {
|
||||
next();
|
||||
}
|
||||
auto begin() {
|
||||
return *this;
|
||||
}
|
||||
auto end() {
|
||||
return MaskIterator{0u};
|
||||
}
|
||||
};
|
||||
|
||||
// Scalar fallback that builds a match mask without SIMD: bit i of the result
// is set when bytes[i] equals needle, for the 16 bytes starting at `bytes`.
// NOTE(review): this text is a rendered diff. Where two variants of a line
// appear back to back (signature, return), the first appears to be the
// pre-change line and the second the post-change line - confirm against the
// repository.
struct MaskPortable {
|
||||
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||
// Variant returning MaskIterator<1>; it pairs with the braced return below.
static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||
uint64_t res = 0;
|
||||
// Reads bytes[0..15]; the caller must provide at least 16 readable bytes.
for (int i = 0; i < 16; i++) {
|
||||
res |= (bytes[i] == needle) << i;
|
||||
}
|
||||
// Variant returning the raw 16-bit-wide mask.
return res;
|
||||
// Variant keeping only the low 14 bits, so bytes 14 and 15 never produce a match.
return {res & ((1u << 14) - 1)};
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef __aarch64__
|
||||
#include <arm_neon.h>
|
||||
struct MaskNeonFolly {
|
||||
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||
static MaskIterator<4> equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||
uint8x16_t input_mask = vld1q_u8(bytes);
|
||||
auto needle_mask = vdupq_n_u8(needle);
|
||||
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
||||
@ -40,12 +71,12 @@ struct MaskNeonFolly {
|
||||
// by shifting right 4, then round to get it into a 64-bit vector
|
||||
uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4);
|
||||
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0);
|
||||
return mask & 0x1111111111111111;
|
||||
return {mask & 0x11111111111111};
|
||||
}
|
||||
};
|
||||
|
||||
struct MaskNeon {
|
||||
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||
static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||
uint8x16_t input_mask = vld1q_u8(bytes);
|
||||
auto needle_mask = vdupq_n_u8(needle);
|
||||
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
||||
@ -54,13 +85,13 @@ struct MaskNeon {
|
||||
const int16_t __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7};
|
||||
int16x8_t SHIFT = vld1q_s16(SHIFT_ARR);
|
||||
uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT);
|
||||
return vaddvq_u16(a_shifted);
|
||||
return {vaddvq_u16(a_shifted) & ((1u << 14) - 1)};
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Picks the equal_mask implementation used through the MaskHelper alias:
// a NEON-based helper when __aarch64__ is defined, MaskPortable otherwise.
// NOTE(review): the two aarch64 alias lines are the two sides of a rendered
// diff (MaskNeon appears to be the pre-change choice, MaskNeonFolly the
// post-change one); only one of them can exist in the real file - confirm.
#ifdef __aarch64__
|
||||
using MaskHelper = MaskNeon;
|
||||
using MaskHelper = MaskNeonFolly;
|
||||
#else
|
||||
using MaskHelper = MaskPortable;
|
||||
#endif
|
||||
@ -209,22 +240,24 @@ class FlatHashTableChunks {
|
||||
if (empty() || is_key_empty(key)) {
|
||||
return end();
|
||||
}
|
||||
auto hash = calc_hash(key);
|
||||
const auto hash = calc_hash(key);
|
||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||
while (true) {
|
||||
auto chunk_i = chunk_it.next();
|
||||
auto chunk_i = chunk_it.pos();
|
||||
auto chunk_begin = nodes_.begin() + chunk_i * Chunk::CHUNK_SIZE;
|
||||
//__builtin_prefetch(chunk_begin);
|
||||
auto &chunk = chunks_[chunk_i];
|
||||
auto mask = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK;
|
||||
while (mask != 0) {
|
||||
auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE;
|
||||
if (EqT()(it->first, key)) {
|
||||
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash);
|
||||
for (auto pos : mask_it) {
|
||||
auto it = chunk_begin + pos;
|
||||
if (likely(EqT()(it->first, key))) {
|
||||
return Iterator{it, this};
|
||||
}
|
||||
mask &= mask - 1;
|
||||
}
|
||||
if (chunk.skipped_cnt == 0) {
|
||||
break;
|
||||
}
|
||||
chunk_it.next();
|
||||
}
|
||||
return end();
|
||||
}
|
||||
@ -283,11 +316,11 @@ class FlatHashTableChunks {
|
||||
auto hash = calc_hash(key);
|
||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||
while (true) {
|
||||
auto chunk_i = chunk_it.next();
|
||||
auto chunk_i = chunk_it.pos();
|
||||
auto &chunk = chunks_[chunk_i];
|
||||
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
|
||||
if (mask != 0) {
|
||||
auto shift = td::count_trailing_zeroes64(mask);
|
||||
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
|
||||
if (mask_it) {
|
||||
auto shift = mask_it.pos();
|
||||
DCHECK(chunk.ctrl[shift] == 0);
|
||||
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
||||
DCHECK(node_it->empty());
|
||||
@ -299,6 +332,7 @@ class FlatHashTableChunks {
|
||||
}
|
||||
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
|
||||
chunk.skipped_cnt++;
|
||||
chunk_it.next();
|
||||
}
|
||||
}
|
||||
|
||||
@ -413,20 +447,21 @@ class FlatHashTableChunks {
|
||||
};
|
||||
// Cursor over chunk indices used while probing the table.
// NOTE(review): this text is a rendered diff, so two versions of the struct
// are superimposed: one with chunk_n and a value-returning next(), and one
// with chunk_mask, pos() and a void next(). The chunk_mask version appears to
// be the post-change one - confirm against the repository.
struct ChunkIt {
|
||||
// Index of the chunk the cursor currently points at.
size_t chunk_i;
|
||||
// Upper bound used by the subtract-based wrap in the first next() below.
size_t chunk_n;
|
||||
// Bit mask applied to chunk_i by the second next() below.
size_t chunk_mask;
|
||||
// Current step size; starts at zero and grows by one per next() call.
size_t shift{};
|
||||
// First variant: add the current step, then grow it, and wrap by a single
// subtraction of chunk_n.
size_t next() {
|
||||
chunk_i += shift;
|
||||
shift++;
|
||||
if (chunk_i >= chunk_n) {
|
||||
chunk_i -= chunk_n;
|
||||
}
|
||||
// Returns the current chunk index without advancing.
size_t pos() const {
|
||||
return chunk_i;
|
||||
}
|
||||
// Second variant: grow the step first, then add it, so the steps taken are
// 1, 2, 3, ... and the result is wrapped with a bitwise AND.
void next() {
|
||||
// Holds exactly when chunk_mask + 1 is a power of two (or chunk_mask is
// all ones); the AND-based wrap below relies on that form.
DCHECK((chunk_mask & (chunk_mask + 1)) == 0);
|
||||
shift++;
|
||||
chunk_i += shift;
|
||||
chunk_i &= chunk_mask;
|
||||
}
|
||||
};
|
||||
|
||||
// Creates a probing cursor that starts at chunk_i.
// NOTE(review): the two return lines are the two sides of a rendered diff.
// The first passes the chunk count as the second ChunkIt field, the second
// passes the chunk count minus one (used as a bit mask by ChunkIt::next).
// The second appears to be the post-change line - confirm.
ChunkIt get_chunk_it(size_t chunk_i) {
|
||||
return {chunk_i, chunks_.size()};
|
||||
// NOTE(review): chunks_.size() - 1 wraps around if chunks_ is empty; the
// visible callers appear to run only on a non-empty table - confirm.
return {chunk_i, chunks_.size() - 1};
|
||||
}
|
||||
|
||||
HashInfo calc_hash(const KeyT &key) {
|
||||
@ -456,11 +491,11 @@ class FlatHashTableChunks {
|
||||
auto hash = calc_hash(node.first);
|
||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||
while (true) {
|
||||
auto chunk_i = chunk_it.next();
|
||||
auto chunk_i = chunk_it.pos();
|
||||
auto &chunk = chunks_[chunk_i];
|
||||
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
|
||||
if (mask != 0) {
|
||||
auto shift = td::count_trailing_zeroes64(mask);
|
||||
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
|
||||
if (mask_it) {
|
||||
auto shift = mask_it.pos();
|
||||
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
||||
DCHECK(node_it->empty());
|
||||
*node_it = std::move(node);
|
||||
@ -472,6 +507,7 @@ class FlatHashTableChunks {
|
||||
}
|
||||
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
|
||||
chunk.skipped_cnt++;
|
||||
chunk_it.next();
|
||||
}
|
||||
}
|
||||
|
||||
@ -490,13 +526,14 @@ class FlatHashTableChunks {
|
||||
auto hash = calc_hash(it->first);
|
||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||
while (true) {
|
||||
auto chunk_i = chunk_it.next();
|
||||
auto chunk_i = chunk_it.pos();
|
||||
auto &chunk = chunks_[chunk_i];
|
||||
if (chunk_i == empty_chunk_i) {
|
||||
chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0;
|
||||
break;
|
||||
}
|
||||
chunk.skipped_cnt--;
|
||||
chunk_it.next();
|
||||
}
|
||||
it->clear();
|
||||
used_nodes_--;
|
||||
|
@ -328,7 +328,7 @@ TEST(FlatHashMap, stress_test) {
|
||||
});
|
||||
|
||||
td::RandomSteps runner(std::move(steps));
|
||||
for (size_t i = 0; i < 1000000000; i++) {
|
||||
for (size_t i = 0; i < 10000000; i++) {
|
||||
runner.step(rnd);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user