diff --git a/tdutils/td/utils/FlatHashMapLinear.h b/tdutils/td/utils/FlatHashMapLinear.h index 79f08d7d9..4d07751c6 100644 --- a/tdutils/td/utils/FlatHashMapLinear.h +++ b/tdutils/td/utils/FlatHashMapLinear.h @@ -163,10 +163,13 @@ class FlatHashTable { struct FlatHashTableInner { uint32 used_node_count_; uint32 bucket_count_mask_; + uint32 begin_bucket_; + uint32 padding_; NodeT nodes_[1]; }; - static constexpr size_t OFFSET = 2 * sizeof(uint32); + static constexpr size_t OFFSET = 4 * sizeof(uint32); + static constexpr uint32 INVALID_BUCKET = 0xFFFFFFFF; static NodeT *allocate_nodes(uint32 size) { DCHECK(size >= 8); @@ -179,6 +182,7 @@ class FlatHashTable { } // inner->used_node_count_ = 0; inner->bucket_count_mask_ = size - 1; + inner->begin_bucket_ = INVALID_BUCKET; return nodes; } @@ -387,11 +391,14 @@ class FlatHashTable { if (empty()) { return end(); } - auto bucket = Random::fast_uint32() & get_bucket_count_mask(); - while (nodes_[bucket].empty()) { - next_bucket(bucket); + auto &begin_bucket = get_inner()->begin_bucket_; + if (begin_bucket == INVALID_BUCKET) { + begin_bucket = Random::fast_uint32() & get_bucket_count_mask(); + while (nodes_[begin_bucket].empty()) { + next_bucket(begin_bucket); + } } - return Iterator(nodes_ + bucket, this); + return Iterator(nodes_ + begin_bucket, this); } Iterator end() { return Iterator(nullptr, this); @@ -539,6 +546,7 @@ class FlatHashTable { } else if (unlikely(get_used_node_count() * 5 > get_bucket_count_mask() * 3)) { resize(2 * get_bucket_count_mask() + 2); } + invalidate_iterators(); } void try_shrink() { @@ -546,6 +554,7 @@ class FlatHashTable { if (unlikely(get_used_node_count() * 10 < get_bucket_count_mask() && get_bucket_count_mask() > 7)) { resize(normalize((get_used_node_count() + 1) * 5 / 3 + 1)); } + invalidate_iterators(); } static uint32 normalize(uint32 size) { @@ -615,6 +624,10 @@ class FlatHashTable { } } } + + inline void invalidate_iterators() { + get_inner()->begin_bucket_ = INVALID_BUCKET; + } }; template , class EqT = std::equal_to>