From 5d074a4b18d1c4bf7f1724176be59aa971f4709e Mon Sep 17 00:00:00 2001 From: Arseny Smirnov Date: Tue, 8 Feb 2022 19:11:14 +0100 Subject: [PATCH] FlatHashMap: some optimizations --- tdutils/td/utils/FlatHashMap.h | 175 ++++++++++++++++++++++++--------- 1 file changed, 129 insertions(+), 46 deletions(-) diff --git a/tdutils/td/utils/FlatHashMap.h b/tdutils/td/utils/FlatHashMap.h index fe331cf39..3f2ec8387 100644 --- a/tdutils/td/utils/FlatHashMap.h +++ b/tdutils/td/utils/FlatHashMap.h @@ -6,7 +6,9 @@ // #pragma once +#include "td/utils/bits.h" #include "td/utils/common.h" +#include "td/utils/logging.h" #include #include @@ -17,6 +19,57 @@ namespace td { +template +class fixed_vector { + public: + fixed_vector() = default; + explicit fixed_vector(size_t size) : ptr_(new T[size]), size_(size) { + } + fixed_vector(fixed_vector &&other) noexcept { + swap(other); + } + fixed_vector &operator=(fixed_vector &&other) noexcept { + swap(other); + return *this; + } + fixed_vector(const fixed_vector &) = delete; + fixed_vector &operator=(const fixed_vector &) = delete; + ~fixed_vector() { + delete[] ptr_; + } + T &operator[](size_t i) { + return ptr_[i]; + } + const T &operator[](size_t i) const { + return ptr_[i]; + } + T *begin() { + return ptr_; + } + const T *begin() const { + return ptr_; + } + T *end() { + return ptr_ + size_; + } + const T *end() const { + return ptr_ + size_; + } + size_t size() const { + return size_; + } + using iterator = T *; + using const_iterator = const T *; + void swap(fixed_vector &other) { + std::swap(ptr_, other.ptr_); + std::swap(size_, other.size_); + } + + private: + T *ptr_{}; + size_t size_{0}; +}; + template > class FlatHashMapImpl { struct Node { @@ -71,10 +124,14 @@ class FlatHashMapImpl { } }; using Self = FlatHashMapImpl; - using NodeIterator = typename std::vector::iterator; - using ConstNodeIterator = typename std::vector::const_iterator; + using NodeIterator = typename fixed_vector::iterator; + using ConstNodeIterator = typename fixed_vector::const_iterator; public: + // define key_type and value_type for benchmarks + using key_type = KeyT; + using value_type = std::pair; + struct Iterator { using iterator_category = std::bidirectional_iterator_tag; using difference_type = std::ptrdiff_t; @@ -177,19 +234,37 @@ class FlatHashMapImpl { assign(begin, end); } + template + auto with_node(const KeyT &key, SomeF &&some, NoneF &&none) { + size_t bucket = calc_bucket(key); + while (true) { + if (nodes_[bucket].key() == key) { + return some(nodes_.begin() + bucket); + } + if (nodes_[bucket].empty()) { + return none(nodes_.begin() + bucket); + } + bucket++; + if (bucket == nodes_.size()) { + bucket = 0; + } + } + } + Iterator find(const KeyT &key) { if (empty()) { return end(); } - auto it = find_bucket_for_insert(key); - if (it->empty()) { - return end(); - } - return Iterator(it, this); + return with_node( + key, + [&](auto it) { + return Iterator{it, this}; + }, + [&](auto it) { return end(); }); } ConstIterator find(const KeyT &key) const { - return ConstIterator(const_cast(this)->find(key)); + return ConstIterator(const_cast(this)->find(key)); } size_t size() const { @@ -221,33 +296,42 @@ class FlatHashMapImpl { return ConstIterator(const_cast(this)->end()); } + void reserve(size_t size) { + size_t want_size = normalize(size * 10 / 6 + 1); + // size_t want_size = size * 2; + if (want_size > nodes_.size()) { + resize(want_size); + } + } + template std::pair emplace(KeyT key, ArgsT &&...args) { - if (should_resize()) { - resize(used_nodes_ + 1); + if (unlikely(should_resize())) { + grow(); } - auto it = find_bucket_for_insert(key); - if (it->empty()) { - it->emplace(std::move(key), std::forward(args)...); - used_nodes_++; - return std::make_pair(Iterator(it, this), true); - } - return std::make_pair(Iterator(it, this), false); + return with_node( + key, [&](auto it) { return std::make_pair(Iterator(it, this), false); }, + [&](auto it) { + it->emplace(std::move(key), std::forward(args)...); + used_nodes_++; + return std::make_pair(Iterator(it, this), true); + }); } ValueT &operator[](const KeyT &key) { DCHECK(!is_key_empty(key)); if (should_resize()) { - resize(used_nodes_ + 1); + grow(); } - auto it = find_bucket_for_insert(key); - if (it->empty()) { - it->emplace(key); - used_nodes_++; - } - return it->second; + return *with_node( + key, [&](auto it) { return &it->second; }, + [&](auto it) { + it->emplace(key); + used_nodes_++; + return &it->second; + }); } size_t erase(const KeyT &key) { @@ -265,7 +349,7 @@ class FlatHashMapImpl { void clear() { used_nodes_ = 0; - nodes_.clear(); + nodes_ = {}; } void erase(Iterator it) { @@ -305,7 +389,7 @@ class FlatHashMapImpl { return key == KeyT(); } - vector nodes_; + fixed_vector nodes_; size_t used_nodes_{}; template @@ -317,43 +401,42 @@ class FlatHashMapImpl { } bool should_resize() const { - return (used_nodes_ + 1) * 10 > nodes_.size() * 6; + return should_resize(used_nodes_ + 1, nodes_.size()); + } + static bool should_resize(size_t used_count, size_t buckets_count) { + return used_count * 10 > buckets_count * 6; } size_t calc_bucket(const KeyT &key) const { return HashT()(key) * 2 % nodes_.size(); } - NodeIterator find_bucket_for_insert(const KeyT &key) { - size_t bucket = calc_bucket(key); - while (!(nodes_[bucket].key() == key) && !nodes_[bucket].empty()) { - bucket++; - if (bucket == nodes_.size()) { - bucket = 0; - } - } - return nodes_.begin() + bucket; + static size_t normalize(size_t size) { + return size_t(1) << (64 - count_leading_zeroes64(size)); } - - ConstNodeIterator find_bucket_for_insert(const KeyT &key) const { - return const_cast(find_bucket_for_insert(key)); + void grow() { + size_t want_size = normalize(td::max(nodes_.size() * 2 - 1, (used_nodes_ + 1) * 10 / 6 + 1)); + // size_t want_size = td::max(nodes_.size(), (used_nodes_ + 1)) * 2; + resize(want_size); } + void resize(size_t new_size) { + // LOG(ERROR) << new_size; + fixed_vector old_nodes(new_size); + std::swap(old_nodes, nodes_); - void resize(size_t size) { - auto old_nodes = std::move(nodes_); - nodes_.resize(td::max(old_nodes.size(), size) * 2 + 1); // TODO: some other logic for (auto &node : old_nodes) { if (node.empty()) { continue; } - auto new_node = find_bucket_for_insert(node.key()); - *new_node = std::move(node); + with_node( + node.key(), [](auto it) { UNREACHABLE(); }, [&](auto it) { *it = std::move(node); }); } } }; template > -//using FlatHashMap = FlatHashMapImpl; -using FlatHashMap = std::unordered_map; +using FlatHashMap = FlatHashMapImpl; +//using FlatHashMap = std::unordered_map; +//using FlatHashMap = absl::flat_hash_map; } // namespace td