FlatHashMap: optimizations
This commit is contained in:
parent
b20a98036f
commit
a356cc7e3d
@ -19,20 +19,51 @@
|
|||||||
#include <new>
|
#include <new>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
template <int shift>
|
||||||
|
struct MaskIterator {
|
||||||
|
uint64_t mask;
|
||||||
|
explicit operator bool() const {
|
||||||
|
return mask != 0;
|
||||||
|
}
|
||||||
|
int pos() const {
|
||||||
|
return td::count_trailing_zeroes64(mask) / shift;
|
||||||
|
}
|
||||||
|
void next() {
|
||||||
|
mask &= mask - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For foreach
|
||||||
|
bool operator!=(MaskIterator &other) const {
|
||||||
|
return mask != other.mask;
|
||||||
|
}
|
||||||
|
auto operator*() const {
|
||||||
|
return pos();
|
||||||
|
}
|
||||||
|
void operator++() {
|
||||||
|
next();
|
||||||
|
}
|
||||||
|
auto begin() {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
auto end() {
|
||||||
|
return MaskIterator{0u};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct MaskPortable {
|
struct MaskPortable {
|
||||||
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
|
static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||||
uint64_t res = 0;
|
uint64_t res = 0;
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
res |= (bytes[i] == needle) << i;
|
res |= (bytes[i] == needle) << i;
|
||||||
}
|
}
|
||||||
return res;
|
return {res & ((1u << 14) - 1)};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
struct MaskNeonFolly {
|
struct MaskNeonFolly {
|
||||||
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
|
static MaskIterator<4> equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||||
uint8x16_t input_mask = vld1q_u8(bytes);
|
uint8x16_t input_mask = vld1q_u8(bytes);
|
||||||
auto needle_mask = vdupq_n_u8(needle);
|
auto needle_mask = vdupq_n_u8(needle);
|
||||||
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
||||||
@ -40,12 +71,12 @@ struct MaskNeonFolly {
|
|||||||
// by shifting right 4, then round to get it into a 64-bit vector
|
// by shifting right 4, then round to get it into a 64-bit vector
|
||||||
uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4);
|
uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4);
|
||||||
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0);
|
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0);
|
||||||
return mask & 0x1111111111111111;
|
return {mask & 0x11111111111111};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MaskNeon {
|
struct MaskNeon {
|
||||||
static uint64_t equal_mask(uint8_t *bytes, uint8_t needle) {
|
static MaskIterator<1> equal_mask(uint8_t *bytes, uint8_t needle) {
|
||||||
uint8x16_t input_mask = vld1q_u8(bytes);
|
uint8x16_t input_mask = vld1q_u8(bytes);
|
||||||
auto needle_mask = vdupq_n_u8(needle);
|
auto needle_mask = vdupq_n_u8(needle);
|
||||||
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
||||||
@ -54,13 +85,13 @@ struct MaskNeon {
|
|||||||
const int16_t __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7};
|
const int16_t __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7};
|
||||||
int16x8_t SHIFT = vld1q_s16(SHIFT_ARR);
|
int16x8_t SHIFT = vld1q_s16(SHIFT_ARR);
|
||||||
uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT);
|
uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT);
|
||||||
return vaddvq_u16(a_shifted);
|
return {vaddvq_u16(a_shifted) & ((1u << 14) - 1)};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
using MaskHelper = MaskNeon;
|
using MaskHelper = MaskNeonFolly;
|
||||||
#else
|
#else
|
||||||
using MaskHelper = MaskPortable;
|
using MaskHelper = MaskPortable;
|
||||||
#endif
|
#endif
|
||||||
@ -209,22 +240,24 @@ class FlatHashTableChunks {
|
|||||||
if (empty() || is_key_empty(key)) {
|
if (empty() || is_key_empty(key)) {
|
||||||
return end();
|
return end();
|
||||||
}
|
}
|
||||||
auto hash = calc_hash(key);
|
const auto hash = calc_hash(key);
|
||||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||||
while (true) {
|
while (true) {
|
||||||
auto chunk_i = chunk_it.next();
|
auto chunk_i = chunk_it.pos();
|
||||||
|
auto chunk_begin = nodes_.begin() + chunk_i * Chunk::CHUNK_SIZE;
|
||||||
|
//__builtin_prefetch(chunk_begin);
|
||||||
auto &chunk = chunks_[chunk_i];
|
auto &chunk = chunks_[chunk_i];
|
||||||
auto mask = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash) & Chunk::MASK;
|
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash);
|
||||||
while (mask != 0) {
|
for (auto pos : mask_it) {
|
||||||
auto it = nodes_.begin() + td::count_trailing_zeroes64(mask) + chunk_i * Chunk::CHUNK_SIZE;
|
auto it = chunk_begin + pos;
|
||||||
if (EqT()(it->first, key)) {
|
if (likely(EqT()(it->first, key))) {
|
||||||
return Iterator{it, this};
|
return Iterator{it, this};
|
||||||
}
|
}
|
||||||
mask &= mask - 1;
|
|
||||||
}
|
}
|
||||||
if (chunk.skipped_cnt == 0) {
|
if (chunk.skipped_cnt == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
chunk_it.next();
|
||||||
}
|
}
|
||||||
return end();
|
return end();
|
||||||
}
|
}
|
||||||
@ -283,11 +316,11 @@ class FlatHashTableChunks {
|
|||||||
auto hash = calc_hash(key);
|
auto hash = calc_hash(key);
|
||||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||||
while (true) {
|
while (true) {
|
||||||
auto chunk_i = chunk_it.next();
|
auto chunk_i = chunk_it.pos();
|
||||||
auto &chunk = chunks_[chunk_i];
|
auto &chunk = chunks_[chunk_i];
|
||||||
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
|
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
|
||||||
if (mask != 0) {
|
if (mask_it) {
|
||||||
auto shift = td::count_trailing_zeroes64(mask);
|
auto shift = mask_it.pos();
|
||||||
DCHECK(chunk.ctrl[shift] == 0);
|
DCHECK(chunk.ctrl[shift] == 0);
|
||||||
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
||||||
DCHECK(node_it->empty());
|
DCHECK(node_it->empty());
|
||||||
@ -299,6 +332,7 @@ class FlatHashTableChunks {
|
|||||||
}
|
}
|
||||||
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
|
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
|
||||||
chunk.skipped_cnt++;
|
chunk.skipped_cnt++;
|
||||||
|
chunk_it.next();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -413,20 +447,21 @@ class FlatHashTableChunks {
|
|||||||
};
|
};
|
||||||
struct ChunkIt {
|
struct ChunkIt {
|
||||||
size_t chunk_i;
|
size_t chunk_i;
|
||||||
size_t chunk_n;
|
size_t chunk_mask;
|
||||||
size_t shift{};
|
size_t shift{};
|
||||||
size_t next() {
|
size_t pos() const {
|
||||||
chunk_i += shift;
|
|
||||||
shift++;
|
|
||||||
if (chunk_i >= chunk_n) {
|
|
||||||
chunk_i -= chunk_n;
|
|
||||||
}
|
|
||||||
return chunk_i;
|
return chunk_i;
|
||||||
}
|
}
|
||||||
|
void next() {
|
||||||
|
DCHECK((chunk_mask & (chunk_mask + 1)) == 0);
|
||||||
|
shift++;
|
||||||
|
chunk_i += shift;
|
||||||
|
chunk_i &= chunk_mask;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
ChunkIt get_chunk_it(size_t chunk_i) {
|
ChunkIt get_chunk_it(size_t chunk_i) {
|
||||||
return {chunk_i, chunks_.size()};
|
return {chunk_i, chunks_.size() - 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
HashInfo calc_hash(const KeyT &key) {
|
HashInfo calc_hash(const KeyT &key) {
|
||||||
@ -456,11 +491,11 @@ class FlatHashTableChunks {
|
|||||||
auto hash = calc_hash(node.first);
|
auto hash = calc_hash(node.first);
|
||||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||||
while (true) {
|
while (true) {
|
||||||
auto chunk_i = chunk_it.next();
|
auto chunk_i = chunk_it.pos();
|
||||||
auto &chunk = chunks_[chunk_i];
|
auto &chunk = chunks_[chunk_i];
|
||||||
auto mask = MaskHelper::equal_mask(chunk.ctrl, 0) & Chunk::MASK;
|
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
|
||||||
if (mask != 0) {
|
if (mask_it) {
|
||||||
auto shift = td::count_trailing_zeroes64(mask);
|
auto shift = mask_it.pos();
|
||||||
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
||||||
DCHECK(node_it->empty());
|
DCHECK(node_it->empty());
|
||||||
*node_it = std::move(node);
|
*node_it = std::move(node);
|
||||||
@ -472,6 +507,7 @@ class FlatHashTableChunks {
|
|||||||
}
|
}
|
||||||
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
|
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16_t>::max());
|
||||||
chunk.skipped_cnt++;
|
chunk.skipped_cnt++;
|
||||||
|
chunk_it.next();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,13 +526,14 @@ class FlatHashTableChunks {
|
|||||||
auto hash = calc_hash(it->first);
|
auto hash = calc_hash(it->first);
|
||||||
auto chunk_it = get_chunk_it(hash.chunk_i);
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
||||||
while (true) {
|
while (true) {
|
||||||
auto chunk_i = chunk_it.next();
|
auto chunk_i = chunk_it.pos();
|
||||||
auto &chunk = chunks_[chunk_i];
|
auto &chunk = chunks_[chunk_i];
|
||||||
if (chunk_i == empty_chunk_i) {
|
if (chunk_i == empty_chunk_i) {
|
||||||
chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0;
|
chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
chunk.skipped_cnt--;
|
chunk.skipped_cnt--;
|
||||||
|
chunk_it.next();
|
||||||
}
|
}
|
||||||
it->clear();
|
it->clear();
|
||||||
used_nodes_--;
|
used_nodes_--;
|
||||||
|
@ -328,7 +328,7 @@ TEST(FlatHashMap, stress_test) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
td::RandomSteps runner(std::move(steps));
|
td::RandomSteps runner(std::move(steps));
|
||||||
for (size_t i = 0; i < 1000000000; i++) {
|
for (size_t i = 0; i < 10000000; i++) {
|
||||||
runner.step(rnd);
|
runner.step(rnd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user