Return the same begin() if hashtable wasn't changed.

This commit is contained in:
levlam 2022-02-25 00:24:27 +03:00
parent a657cf6458
commit 01b884858d

View File

@ -163,10 +163,13 @@ class FlatHashTable {
struct FlatHashTableInner { struct FlatHashTableInner {
uint32 used_node_count_; uint32 used_node_count_;
uint32 bucket_count_mask_; uint32 bucket_count_mask_;
uint32 begin_bucket_;
uint32 padding_;
NodeT nodes_[1]; NodeT nodes_[1];
}; };
static constexpr size_t OFFSET = 2 * sizeof(uint32); static constexpr size_t OFFSET = 4 * sizeof(uint32);
static constexpr uint32 INVALID_BUCKET = 0xFFFFFFFF;
static NodeT *allocate_nodes(uint32 size) { static NodeT *allocate_nodes(uint32 size) {
DCHECK(size >= 8); DCHECK(size >= 8);
@ -179,6 +182,7 @@ class FlatHashTable {
} }
// inner->used_node_count_ = 0; // inner->used_node_count_ = 0;
inner->bucket_count_mask_ = size - 1; inner->bucket_count_mask_ = size - 1;
inner->begin_bucket_ = INVALID_BUCKET;
return nodes; return nodes;
} }
@ -387,11 +391,14 @@ class FlatHashTable {
if (empty()) { if (empty()) {
return end(); return end();
} }
auto bucket = Random::fast_uint32() & get_bucket_count_mask(); auto &begin_bucket = get_inner()->begin_bucket_;
while (nodes_[bucket].empty()) { if (begin_bucket == INVALID_BUCKET) {
next_bucket(bucket); begin_bucket = Random::fast_uint32() & get_bucket_count_mask();
while (nodes_[begin_bucket].empty()) {
next_bucket(begin_bucket);
}
} }
return Iterator(nodes_ + bucket, this); return Iterator(nodes_ + begin_bucket, this);
} }
Iterator end() { Iterator end() {
return Iterator(nullptr, this); return Iterator(nullptr, this);
@ -539,6 +546,7 @@ class FlatHashTable {
} else if (unlikely(get_used_node_count() * 5 > get_bucket_count_mask() * 3)) { } else if (unlikely(get_used_node_count() * 5 > get_bucket_count_mask() * 3)) {
resize(2 * get_bucket_count_mask() + 2); resize(2 * get_bucket_count_mask() + 2);
} }
invalidate_iterators();
} }
void try_shrink() { void try_shrink() {
@ -546,6 +554,7 @@ class FlatHashTable {
if (unlikely(get_used_node_count() * 10 < get_bucket_count_mask() && get_bucket_count_mask() > 7)) { if (unlikely(get_used_node_count() * 10 < get_bucket_count_mask() && get_bucket_count_mask() > 7)) {
resize(normalize((get_used_node_count() + 1) * 5 / 3 + 1)); resize(normalize((get_used_node_count() + 1) * 5 / 3 + 1));
} }
invalidate_iterators();
} }
static uint32 normalize(uint32 size) { static uint32 normalize(uint32 size) {
@ -615,6 +624,10 @@ class FlatHashTable {
} }
} }
} }
inline void invalidate_iterators() {
get_inner()->begin_bucket_ = INVALID_BUCKET;
}
}; };
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>> template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>