diff --git a/tdutils/td/utils/FlatHashMapLinear.h b/tdutils/td/utils/FlatHashMapLinear.h index 8b905d06b..fde6f745a 100644 --- a/tdutils/td/utils/FlatHashMapLinear.h +++ b/tdutils/td/utils/FlatHashMapLinear.h @@ -11,7 +11,6 @@ #include "td/utils/Random.h" #include -#include #include #include #include @@ -161,70 +160,37 @@ struct SetNode { template class FlatHashTable { - struct FlatHashTableInner { - uint32 used_node_count_; - uint32 bucket_count_mask_; - uint32 bucket_count_; - uint32 begin_bucket_; - NodeT nodes_[1]; - }; - - static constexpr size_t OFFSET = 4 * sizeof(uint32); static constexpr uint32 INVALID_BUCKET = 0xFFFFFFFF; - static inline FlatHashTableInner *get_inner(NodeT *nodes) { - DCHECK(nodes != nullptr); - return reinterpret_cast(reinterpret_cast(nodes) - OFFSET); - } - - static NodeT *allocate_nodes(uint32 size) { + void allocate_nodes(uint32 size) { DCHECK(size >= 8); DCHECK((size & (size - 1)) == 0); - CHECK(size <= min(static_cast(1) << 29, static_cast((0x7FFFFFFF - OFFSET) / sizeof(NodeT)))); - auto inner = static_cast(std::malloc(OFFSET + sizeof(NodeT) * size)); - NodeT *nodes = &inner->nodes_[0]; - for (uint32 i = 0; i < size; i++) { - new (nodes + i) NodeT(); - } - // inner->used_node_count_ = 0; - inner->bucket_count_mask_ = size - 1; - inner->bucket_count_ = size; - inner->begin_bucket_ = INVALID_BUCKET; - return nodes; + CHECK(size <= min(static_cast(1) << 29, static_cast(0x7FFFFFFF / sizeof(NodeT)))); + nodes_ = new NodeT[size]; + // used_node_count_ = 0; + bucket_count_mask_ = size - 1; + bucket_count_ = size; + begin_bucket_ = INVALID_BUCKET; } static void clear_nodes(NodeT *nodes) { - auto inner = get_inner(nodes); - auto size = inner->bucket_count_; - for (uint32 i = 0; i < size; i++) { - nodes[i].~NodeT(); - } - std::free(inner); - } - - inline FlatHashTableInner *get_inner() { - return get_inner(nodes_); - } - - inline const FlatHashTableInner *get_inner() const { - DCHECK(nodes_ != nullptr); - return get_inner(const_cast(nodes_)); + delete[] nodes; } inline uint32 &used_node_count() { - return get_inner()->used_node_count_; + return used_node_count_; } inline uint32 get_used_node_count() const { - return get_inner()->used_node_count_; + return used_node_count_; } inline uint32 get_bucket_count_mask() const { - return get_inner()->bucket_count_mask_; + return bucket_count_mask_; } inline uint32 get_bucket_count() const { - return get_inner()->bucket_count_; + return bucket_count_; } public: @@ -351,21 +317,34 @@ class FlatHashTable { used_node_count() = used_nodes; } - FlatHashTable(FlatHashTable &&other) noexcept : nodes_(other.nodes_) { - other.nodes_ = nullptr; + FlatHashTable(FlatHashTable &&other) noexcept + : nodes_(other.nodes_) + , used_node_count_(other.used_node_count_) + , bucket_count_mask_(other.bucket_count_mask_) + , bucket_count_(other.bucket_count_) + , begin_bucket_(other.begin_bucket_) { + other.drop(); } void operator=(FlatHashTable &&other) noexcept { clear(); nodes_ = other.nodes_; - other.nodes_ = nullptr; + used_node_count_ = other.used_node_count_; + bucket_count_mask_ = other.bucket_count_mask_; + bucket_count_ = other.bucket_count_; + begin_bucket_ = other.begin_bucket_; + other.drop(); } void swap(FlatHashTable &other) noexcept { std::swap(nodes_, other.nodes_); + std::swap(used_node_count_, other.used_node_count_); + std::swap(bucket_count_mask_, other.bucket_count_mask_); + std::swap(bucket_count_, other.bucket_count_); + std::swap(begin_bucket_, other.begin_bucket_); } ~FlatHashTable() = default; uint32 bucket_count() const { - return unlikely(nodes_ == nullptr) ? 0 : get_bucket_count(); + return get_bucket_count(); } Iterator find(const KeyT &key) { @@ -390,18 +369,18 @@ class FlatHashTable { } size_t size() const { - return unlikely(nodes_ == nullptr) ? 0 : get_used_node_count(); + return get_used_node_count(); } bool empty() const { - return unlikely(nodes_ == nullptr) || get_used_node_count() == 0; + return get_used_node_count() == 0; } Iterator begin() { if (empty()) { return end(); } - auto &begin_bucket = get_inner()->begin_bucket_; + auto &begin_bucket = begin_bucket_; if (begin_bucket == INVALID_BUCKET) { begin_bucket = Random::fast_uint32() & get_bucket_count_mask(); while (nodes_[begin_bucket].empty()) { @@ -483,7 +462,7 @@ class FlatHashTable { void clear() { if (nodes_ != nullptr) { clear_nodes(nodes_); - nodes_ = nullptr; + drop(); } } @@ -530,6 +509,18 @@ class FlatHashTable { private: NodeT *nodes_ = nullptr; + uint32 used_node_count_ = 0; + uint32 bucket_count_mask_ = 0; + uint32 bucket_count_ = 0; + uint32 begin_bucket_ = 0; + + void drop() { + nodes_ = nullptr; + used_node_count_ = 0; + bucket_count_mask_ = 0; + bucket_count_ = 0; + begin_bucket_ = 0; + } void assign(const FlatHashTable &other) { if (other.size() == 0) { @@ -581,7 +572,7 @@ class FlatHashTable { void resize(uint32 new_size) { if (unlikely(nodes_ == nullptr)) { - nodes_ = allocate_nodes(new_size); + allocate_nodes(new_size); used_node_count() = 0; return; } @@ -589,7 +580,7 @@ class FlatHashTable { auto old_nodes = nodes_; uint32 old_size = get_used_node_count(); uint32 old_bucket_count = get_bucket_count(); - nodes_ = allocate_nodes(new_size); + allocate_nodes(new_size); used_node_count() = old_size; auto old_nodes_end = old_nodes + old_bucket_count; @@ -647,12 +638,13 @@ class FlatHashTable { } inline void invalidate_iterators() { - get_inner()->begin_bucket_ = INVALID_BUCKET; + begin_bucket_ = INVALID_BUCKET; } }; template , class EqT = std::equal_to> using FlatHashMapImpl = FlatHashTable, HashT, EqT>; + template , class EqT = std::equal_to> using FlatHashSetImpl = FlatHashTable, HashT, EqT>;