diff --git a/tdutils/td/utils/FlatHashMapLinear.h b/tdutils/td/utils/FlatHashMapLinear.h index 05e51679e..2bf1a6ab4 100644 --- a/tdutils/td/utils/FlatHashMapLinear.h +++ b/tdutils/td/utils/FlatHashMapLinear.h @@ -167,24 +167,24 @@ class FlatHashTable { static constexpr size_t OFFSET = 2 * sizeof(uint32); - static NodeT *allocate_nodes(size_t size) { + static NodeT *allocate_nodes(uint32 size) { DCHECK(size >= 8); DCHECK((size & (size - 1)) == 0); CHECK(size <= (1 << 29)); auto inner = static_cast(std::malloc(OFFSET + sizeof(NodeT) * size)); NodeT *nodes = &inner->nodes_[0]; - for (size_t i = 0; i < size; i++) { + for (uint32 i = 0; i < size; i++) { new (nodes + i) NodeT(); } // inner->used_node_count_ = 0; - inner->bucket_count_mask_ = static_cast(size - 1); + inner->bucket_count_mask_ = size - 1; return nodes; } static void clear_inner(FlatHashTableInner *inner) { auto size = inner->bucket_count_mask_ + 1; NodeT *nodes = &inner->nodes_[0]; - for (size_t i = 0; i < size; i++) { + for (uint32 i = 0; i < size; i++) { nodes[i].~NodeT(); } std::free(inner); @@ -347,7 +347,7 @@ class FlatHashTable { } ~FlatHashTable() = default; - size_t bucket_count() const { + uint32 bucket_count() const { return unlikely(nodes_ == nullptr) ? 0 : get_bucket_count_mask() + 1; } @@ -405,7 +405,8 @@ class FlatHashTable { if (size == 0) { return; } - size_t want_size = normalize(size * 5 / 3 + 1); + CHECK(size <= (1u << 29)); + uint32 want_size = normalize(static_cast(size) * 5 / 3 + 1); if (want_size > bucket_count()) { resize(want_size); } @@ -539,8 +540,8 @@ class FlatHashTable { } } - static size_t normalize(size_t size) { - return td::max(static_cast(1) << (64 - count_leading_zeroes64(size)), static_cast(8)); + static uint32 normalize(uint32 size) { + return td::max(static_cast(1) << (32 - count_leading_zeroes32(size)), static_cast(8)); } uint32 calc_bucket(const KeyT &key) const { @@ -551,14 +552,20 @@ class FlatHashTable { bucket = (bucket + 1) & get_bucket_count_mask(); } - void resize(size_t new_size) { - auto old_nodes = nodes_; - auto old_size = size(); - size_t old_bucket_count = bucket_count(); - nodes_ = allocate_nodes(new_size); - used_node_count() = static_cast(old_size); + void resize(uint32 new_size) { + if (unlikely(nodes_ == nullptr)) { + nodes_ = allocate_nodes(new_size); + return; + } - for (size_t i = 0; i < old_bucket_count; i++) { + auto old_nodes = nodes_; + uint32 old_size = get_used_node_count(); + uint32 old_bucket_count = get_bucket_count_mask() + 1; + ; + nodes_ = allocate_nodes(new_size); + used_node_count() = old_size; + + for (uint32 i = 0; i < old_bucket_count; i++) { auto &old_node = old_nodes[i]; if (old_node.empty()) { continue; @@ -572,13 +579,13 @@ class FlatHashTable { } void erase_node(NodeT *it) { - size_t empty_i = it - nodes_; + DCHECK(nodes_ <= it && it - nodes_ < bucket_count()); + uint32 empty_i = static_cast(it - nodes_); auto empty_bucket = empty_i; - DCHECK(0 <= empty_i && empty_i < bucket_count()); nodes_[empty_bucket].clear(); used_node_count()--; - for (size_t test_i = empty_i + 1;; test_i++) { + for (uint32 test_i = empty_i + 1;; test_i++) { auto test_bucket = test_i; if (test_bucket >= bucket_count()) { test_bucket -= bucket_count();