diff --git a/benchmark/hashset_memory.cpp b/benchmark/hashset_memory.cpp index 67e871d67..ab453faaa 100644 --- a/benchmark/hashset_memory.cpp +++ b/benchmark/hashset_memory.cpp @@ -33,7 +33,7 @@ static bool use_memprof() { #endif } -static auto get_memory() { +static td::uint64 get_memory() { #if USE_MEMPROF if (use_memprof()) { return get_used_memory_size(); diff --git a/tdutils/td/utils/FlatHashMap.h b/tdutils/td/utils/FlatHashMap.h index 618097b59..bb6d6f391 100644 --- a/tdutils/td/utils/FlatHashMap.h +++ b/tdutils/td/utils/FlatHashMap.h @@ -73,81 +73,142 @@ class fixed_vector { size_t size_{0}; }; -template > -class FlatHashMapImpl { - public: - struct Node { - using first_type = KeyT; - using second_type = ValueT; - KeyT first{}; - union { - ValueT second; - }; - const auto &key() const { - return first; - } - auto &value() { - return second; - } +// TODO: move +template +bool is_key_empty(const KeyT &key) { + return key == KeyT(); +} - Node() { - } - Node(KeyT key, ValueT value) : first(std::move(key)) { - new (&second) ValueT(std::move(value)); - DCHECK(!empty()); - } - ~Node() { - if (!empty()) { - second.~ValueT(); - } - } - Node(Node &&other) noexcept { - *this = std::move(other); - } - Node &operator=(Node &&other) noexcept { - DCHECK(empty()); - DCHECK(!other.empty()); - first = std::move(other.first); - other.first = KeyT{}; - new (&second) ValueT(std::move(other.second)); - other.second.~ValueT(); - return *this; - } - - bool empty() const { - return is_key_empty(key()); - } - - void clear() { - DCHECK(!empty()); - first = KeyT(); - second.~ValueT(); - DCHECK(empty()); - } - - template - void emplace(KeyT key, ArgsT &&...args) { - DCHECK(empty()); - first = std::move(key); - new (&second) ValueT(std::forward(args)...); - DCHECK(!empty()); - } +template +struct MapNode { + using first_type = KeyT; + using second_type = ValueT; + using key_type = KeyT; + using public_type = MapNode; + using value_type = ValueT; + KeyT first{}; + union { + ValueT second; }; - using Self = FlatHashMapImpl; + const auto &key() const { + return first; + } + auto &value() { + return second; + } + auto &get_public() { + return *this; + } + + MapNode() { + } + MapNode(KeyT key, ValueT value) : first(std::move(key)) { + new (&second) ValueT(std::move(value)); + DCHECK(!empty()); + } + ~MapNode() { + if (!empty()) { + second.~ValueT(); + } + } + MapNode(MapNode &&other) noexcept { + *this = std::move(other); + } + MapNode &operator=(MapNode &&other) noexcept { + DCHECK(empty()); + DCHECK(!other.empty()); + first = std::move(other.first); + other.first = KeyT{}; + new (&second) ValueT(std::move(other.second)); + other.second.~ValueT(); + return *this; + } + + bool empty() const { + return is_key_empty(key()); + } + + void clear() { + DCHECK(!empty()); + first = KeyT(); + second.~ValueT(); + DCHECK(empty()); + } + + template + void emplace(KeyT key, ArgsT &&...args) { + DCHECK(empty()); + first = std::move(key); + new (&second) ValueT(std::forward(args)...); + DCHECK(!empty()); + } +}; + +template +struct SetNode { + using first_type = KeyT; + using key_type = KeyT; + using public_type = KeyT; + using value_type = KeyT; + KeyT first{}; + const auto &key() const { + return first; + } + const auto &value() const { + return first; + } + + auto &get_public() { + return first; + } + SetNode() = default; + explicit SetNode(KeyT key) : first(std::move(key)) { + } + SetNode(SetNode &&other) noexcept { + *this = std::move(other); + } + SetNode &operator=(SetNode &&other) noexcept { + DCHECK(empty()); + DCHECK(!other.empty()); + first = std::move(other.first); + other.first = KeyT{}; + return *this; + } + + bool empty() const { + return is_key_empty(key()); + } + + void clear() { + first = KeyT(); + CHECK(empty()); + } + + void emplace(KeyT key) { + first = std::move(key); + } +}; + +template +class FlatHashTable { + public: + using Self = FlatHashTable; + using Node = NodeT; using NodeIterator = typename fixed_vector::iterator; using ConstNodeIterator = typename fixed_vector::const_iterator; - using key_type = KeyT; - using value_type = Node; + using KeyT = typename Node::key_type; + using public_type = typename Node::public_type; + using value_type = typename Node::value_type; struct Iterator { using iterator_category = std::bidirectional_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = Node; - using pointer = Node *; - using reference = Node &; + using value_type = public_type; + using pointer = public_type *; + using reference = public_type &; - friend class FlatHashMapImpl; + friend class FlatHashTable; Iterator &operator++() { do { ++it_; @@ -160,21 +221,22 @@ class FlatHashMapImpl { } while (it_->empty()); return *this; } - Node &operator*() { - return *it_; + reference operator*() { + return it_->get_public(); } - Node *operator->() { + pointer operator->() { return &*it_; } bool operator==(const Iterator &other) const { - DCHECK(map_ == other.map_); + DCHECK(map_ == other.map_); return it_ == other.it_; } bool operator!=(const Iterator &other) const { - DCHECK(map_ == other.map_); + DCHECK(map_ == other.map_); return it_ != other.it_; } + Iterator() = default; Iterator(NodeIterator it, Self *map) : it_(std::move(it)), map_(map) { } @@ -186,11 +248,11 @@ class FlatHashMapImpl { struct ConstIterator { using iterator_category = std::bidirectional_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = Node; - using pointer = const Node *; - using reference = const Node &; + using value_type = public_type; + using pointer = const value_type *; + using reference = const value_type &; - friend class FlatHashMapImpl; + friend class FlatHashTable; ConstIterator &operator++() { ++it_; return *this; @@ -199,10 +261,10 @@ class FlatHashMapImpl { --it_; return *this; } - const Node &operator*() { + reference operator*() { return *it_; } - const Node *operator->() { + pointer operator->() { return &*it_; } bool operator==(const ConstIterator &other) const { @@ -212,25 +274,28 @@ class FlatHashMapImpl { return it_ != other.it_; } - explicit ConstIterator(Iterator it) : it_(std::move(it)) { + ConstIterator() = default; + ConstIterator(Iterator it) : it_(std::move(it)) { } private: Iterator it_; }; + using iterator = Iterator; + using const_iterator = ConstIterator; - FlatHashMapImpl() = default; - FlatHashMapImpl(const FlatHashMapImpl &other) : FlatHashMapImpl(other.begin(), other.end()) { + FlatHashTable() = default; + FlatHashTable(const FlatHashTable &other) : FlatHashTable(other.begin(), other.end()) { } - FlatHashMapImpl &operator=(const FlatHashMapImpl &other) { + FlatHashTable &operator=(const FlatHashTable &other) { assign(other.begin(), other.end()); return *this; } - FlatHashMapImpl(std::initializer_list nodes) { + FlatHashTable(std::initializer_list nodes) { reserve(nodes.size()); for (auto &node : nodes) { - CHECK(!is_key_empty(node.first)); + CHECK(!node.empty()); auto bucket = calc_bucket(node.first); while (true) { if (nodes_[bucket].key() == node.first) { @@ -247,29 +312,38 @@ class FlatHashMapImpl { } } - FlatHashMapImpl(FlatHashMapImpl &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) { + FlatHashTable(FlatHashTable &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) { other.used_nodes_ = 0; } - FlatHashMapImpl &operator=(FlatHashMapImpl &&other) noexcept { + FlatHashTable &operator=(FlatHashTable &&other) noexcept { nodes_ = std::move(other.nodes_); used_nodes_ = other.used_nodes_; other.used_nodes_ = 0; return *this; } - ~FlatHashMapImpl() = default; + void swap(FlatHashTable &other) noexcept { + using std::swap; + swap(nodes_, other.nodes_); + swap(used_nodes_, other.used_nodes_); + } + ~FlatHashTable() = default; template - FlatHashMapImpl(ItT begin, ItT end) { + FlatHashTable(ItT begin, ItT end) { assign(begin, end); } + size_t bucket_count() const { + return nodes_.size(); + } + Iterator find(const KeyT &key) { if (empty() || is_key_empty(key)) { return end(); } auto bucket = calc_bucket(key); while (true) { - if (nodes_[bucket].key() == key) { + if (EqT()(nodes_[bucket].key(), key)) { return Iterator{nodes_.begin() + bucket, this}; } if (nodes_[bucket].empty()) { @@ -326,7 +400,7 @@ class FlatHashMapImpl { CHECK(!is_key_empty(key)); auto bucket = calc_bucket(key); while (true) { - if (nodes_[bucket].key() == key) { + if (EqT()(nodes_[bucket].key(), key)) { return {Iterator{nodes_.begin() + bucket, this}, false}; } if (nodes_[bucket].empty()) { @@ -338,8 +412,19 @@ class FlatHashMapImpl { } } - ValueT &operator[](const KeyT &key) { - return emplace(key).first->second; + std::pair insert(KeyT key) { + return emplace(std::move(key)); + } + + template + void insert(ItT begin, ItT end) { + for (; begin != end; ++begin) { + emplace(*begin); + } + } + + value_type &operator[](const KeyT &key) { + return emplace(key).first->value(); } size_t erase(const KeyT &key) { @@ -363,7 +448,7 @@ class FlatHashMapImpl { void erase(Iterator it) { DCHECK(it != end()); - DCHECK(!is_key_empty(it->key())); + DCHECK(!it.it_->empty()); erase_node(it.it_); } @@ -392,9 +477,6 @@ class FlatHashMapImpl { } private: - static bool is_key_empty(const KeyT &key) { - return key == KeyT(); - } fixed_vector nodes_; size_t used_nodes_{}; @@ -496,8 +578,18 @@ class FlatHashMapImpl { } }; -template > -using FlatHashMap = FlatHashMapImpl; -//using FlatHashMap = std::unordered_map; +template , class EqT = std::equal_to> +using FlatHashMapImpl = FlatHashTable, HashT, EqT>; +template , class EqT = std::equal_to> +using FlatHashSetImpl = FlatHashTable, HashT, EqT>; + +template , class EqT = std::equal_to> +using FlatHashMap = FlatHashMapImpl; +//using FlatHashMap = std::unordered_map; + +template , class EqT = std::equal_to> +using FlatHashSet = FlatHashSetImpl; +//using FlatHashSet = std::unordered_set; + } // namespace td diff --git a/tdutils/td/utils/algorithm.h b/tdutils/td/utils/algorithm.h index 1ebfd6f56..0c3702d8c 100644 --- a/tdutils/td/utils/algorithm.h +++ b/tdutils/td/utils/algorithm.h @@ -206,11 +206,11 @@ void table_remove_if(TableT &table, FuncT &&func) { } } -template -class FlatHashMapImpl; +template +class FlatHashTable; -template -void table_remove_if(FlatHashMapImpl &table, FuncT &&func) { +template +void table_remove_if(FlatHashTable &table, FuncT &&func) { table.remove_if(func); } diff --git a/tdutils/test/HashSet.cpp b/tdutils/test/HashSet.cpp index d0a1ad2e1..773944ac5 100644 --- a/tdutils/test/HashSet.cpp +++ b/tdutils/test/HashSet.cpp @@ -24,6 +24,17 @@ static auto extract_kv(const T &reference) { } TEST(FlatHashMap, basic) { + { + td::FlatHashSet s; + s.insert(5); + for (auto x : s) { + } + int N = 100000; + for (int i = 0; i < 10000000; i++) { + s.insert((i + N/2)%N); + s.erase(i%N); + } + } { td::FlatHashMap map; map[1] = 2; diff --git a/tdutils/test/hashset_benchmark.cpp b/tdutils/test/hashset_benchmark.cpp index 07e4b1d80..714726eb4 100644 --- a/tdutils/test/hashset_benchmark.cpp +++ b/tdutils/test/hashset_benchmark.cpp @@ -248,7 +248,7 @@ static void BM_emplace_same(benchmark::State &state) { while (state.KeepRunningBatch(BATCH_SIZE)) { for (size_t i = 0; i < BATCH_SIZE; i++) { - benchmark::DoNotOptimize(table.emplace(key, 43784932)); + benchmark::DoNotOptimize(table.emplace(key + (i & 15) * 100 , 43784932)); } } } @@ -375,6 +375,28 @@ static void BM_remove_if_slow(benchmark::State &state) { constexpr size_t N = 5000; constexpr size_t BATCH_SIZE = 500000; + TableT table; + td::Random::Xorshift128plus rnd(123); + for (size_t i = 0; i < N; i++) { + table.emplace(rnd() + 1, i); + } + auto first_key = table.begin()->first; + { + size_t cnt = 0; + td::table_remove_if(table, [&cnt](auto &) { cnt += 2; return cnt <= N; }); + } + while (state.KeepRunningBatch(BATCH_SIZE)) { + for (size_t i = 0; i < BATCH_SIZE; i++) { + table.emplace(first_key, i); + table.erase(first_key); + } + } +} +template +static void BM_remove_if_slow_old(benchmark::State &state) { + constexpr size_t N = 100000; + constexpr size_t BATCH_SIZE = 500000; + TableT table; while (state.KeepRunningBatch(BATCH_SIZE)) { td::Random::Xorshift128plus rnd(123); @@ -452,14 +474,13 @@ static void benchmark_create(td::Slice name) { #define REGISTER_ERASE_ALL_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_erase_all_with_begin, HT); #define REGISTER_REMOVE_IF_SLOW_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_remove_if_slow, HT); -FOR_EACH_TABLE(REGISTER_CACHE3_BENCHMARK) -FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK) -FOR_EACH_TABLE(REGISTER_CACHE2_BENCHMARK) FOR_EACH_TABLE(REGISTER_REMOVE_IF_SLOW_BENCHMARK) -FOR_EACH_TABLE(REGISTER_ERASE_ALL_BENCHMARK) +FOR_EACH_TABLE(REGISTER_CACHE3_BENCHMARK) +FOR_EACH_TABLE(REGISTER_CACHE2_BENCHMARK) FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK) FOR_EACH_TABLE(REGISTER_REMOVE_IF_BENCHMARK) FOR_EACH_TABLE(REGISTER_EMPLACE_BENCHMARK) +FOR_EACH_TABLE(REGISTER_ERASE_ALL_BENCHMARK) FOR_EACH_TABLE(REGISTER_GET_BENCHMARK) FOR_EACH_TABLE(REGISTER_FIND_BENCHMARK)