From 8d8896d7f0530539127710691acccd2500ff1fa7 Mon Sep 17 00:00:00 2001 From: Arseny Smirnov Date: Wed, 9 Feb 2022 20:59:08 +0100 Subject: [PATCH] FlatHashMap: remove_if; generic td::table_remove_if --- tdutils/td/utils/FlatHashMap.h | 81 +++++++++++++++++++++--------- tdutils/td/utils/algorithm.h | 11 ++++ tdutils/test/HashSet.cpp | 35 +++++++++++++ tdutils/test/hashset_benchmark.cpp | 42 +++++++++++++++- 4 files changed, 144 insertions(+), 25 deletions(-) diff --git a/tdutils/td/utils/FlatHashMap.h b/tdutils/td/utils/FlatHashMap.h index 6ba0bbfac..f5c905ea1 100644 --- a/tdutils/td/utils/FlatHashMap.h +++ b/tdutils/td/utils/FlatHashMap.h @@ -360,33 +360,31 @@ class FlatHashMapImpl { void erase(Iterator it) { DCHECK(it != end()); DCHECK(!is_key_empty(it->key())); - size_t empty_i = it.it_ - nodes_.begin(); - auto empty_bucket = empty_i; - DCHECK(0 <= empty_i && empty_i < nodes_.size()); - nodes_[empty_bucket].clear(); - used_nodes_--; + erase_node(it.it_); + } - for (size_t test_i = empty_i + 1;; test_i++) { - auto test_bucket = test_i; - if (test_bucket >= nodes_.size()) { - test_bucket -= nodes_.size(); - } - - if (nodes_[test_bucket].empty()) { - break; - } - - auto want_i = calc_bucket(nodes_[test_bucket].key()); - if (want_i < empty_i) { - want_i += nodes_.size(); - } - - if (want_i <= empty_i || want_i > test_i) { - nodes_[empty_bucket] = std::move(nodes_[test_bucket]); - empty_i = test_i; - empty_bucket = test_bucket; + template + void remove_if(F &&f) { + auto it = nodes_.begin(); + while (it != nodes_.end() && !it->empty()) { + ++it; + } + auto first_empty = it; + for (; it != nodes_.end();) { + if (!it->empty() && f(*it)) { + erase_node(it); + } else { + ++it; } } + for (it = nodes_.begin(); it != first_empty;) { + if (!it->empty() && f(*it)) { + erase_node(it); + } else { + ++it; + } + } + // TODO: resize hashtable is necessary } private: @@ -449,8 +447,43 @@ class FlatHashMapImpl { bucket = 0; } } + + void erase_node(NodeIterator it) { + size_t empty_i = it - nodes_.begin(); + auto empty_bucket = empty_i; + DCHECK(0 <= empty_i && empty_i < nodes_.size()); + nodes_[empty_bucket].clear(); + used_nodes_--; + + for (size_t test_i = empty_i + 1;; test_i++) { + auto test_bucket = test_i; + if (test_bucket >= nodes_.size()) { + test_bucket -= nodes_.size(); + } + + if (nodes_[test_bucket].empty()) { + break; + } + + auto want_i = calc_bucket(nodes_[test_bucket].key()); + if (want_i < empty_i) { + want_i += nodes_.size(); + } + + if (want_i <= empty_i || want_i > test_i) { + nodes_[empty_bucket] = std::move(nodes_[test_bucket]); + empty_i = test_i; + empty_bucket = test_bucket; + } + } + } }; +template +void table_remove_if(FlatHashMapImpl &table, FuncT &&func) { + table.remove_if(func); +} + template > using FlatHashMap = FlatHashMapImpl; //using FlatHashMap = std::unordered_map; diff --git a/tdutils/td/utils/algorithm.h b/tdutils/td/utils/algorithm.h index f8005b9dd..85623a727 100644 --- a/tdutils/td/utils/algorithm.h +++ b/tdutils/td/utils/algorithm.h @@ -195,4 +195,15 @@ detail::reversion_wrapper reversed(T &iterable) { return {iterable}; } +template +void table_remove_if(TableT &table, FuncT &&func) { + for (auto it = table.begin(); it != table.end();) { + if (func(*it)) { + it = table.erase(it); + } else { + ++it; + } + } +} + } // namespace td diff --git a/tdutils/test/HashSet.cpp b/tdutils/test/HashSet.cpp index cac8a1196..b4ae45756 100644 --- a/tdutils/test/HashSet.cpp +++ b/tdutils/test/HashSet.cpp @@ -9,6 +9,8 @@ #include "td/utils/tests.h" #include +#include "td/utils/Random.h" +#include "td/utils/algorithm.h" TEST(FlatHashMap, basic) { { @@ -47,3 +49,36 @@ TEST(FlatHashMap, basic) { ASSERT_EQ(1u, map.size()); } } + +template +auto extract_kv(const T &reference) { + auto expected = td::transform(reference, [](auto &it) {return std::make_pair(it.first, it.second);}); + std::sort(expected.begin(), expected.end()); + return expected; +} + +TEST(FlatHashMap, remove_if_basic) { + td::Random::Xorshift128plus rnd(123); + + for (int test_i = 0; test_i < 1000000; test_i++) { + std::unordered_map reference; + td::FlatHashMap table; + LOG_IF(ERROR, test_i % 1000 == 0) << test_i; + int N = rnd.fast(1, 1000); + for (int i = 0; i < N; i++) { + auto key = rnd(); + auto value = i; + reference[key] = value; + table[key] = value; + } + ASSERT_EQ(extract_kv(reference), extract_kv(table)); + + std::vector> kv; + td::table_remove_if(table, [&](auto &it) {kv.emplace_back(it.first, it.second); return it.second % 2 == 0; }); + std::sort(kv.begin(), kv.end()); + ASSERT_EQ(extract_kv(reference), kv); + + td::table_remove_if(reference, [](auto &it) {return it.second % 2 == 0;}); + ASSERT_EQ(extract_kv(reference), extract_kv(table)); + } +} diff --git a/tdutils/test/hashset_benchmark.cpp b/tdutils/test/hashset_benchmark.cpp index c67948048..f1a739ccd 100644 --- a/tdutils/test/hashset_benchmark.cpp +++ b/tdutils/test/hashset_benchmark.cpp @@ -4,6 +4,7 @@ // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // +#include "td/utils/algorithm.h" #include "td/utils/common.h" #include "td/utils/FlatHashMap.h" #include "td/utils/format.h" @@ -251,6 +252,43 @@ static void BM_emplace_same(benchmark::State &state) { } } +namespace td { +template +void table_remove_if(absl::flat_hash_map &table, FunctT &&func) { + for (auto it = table.begin(); it != table.end();) { + if (func(*it)) { + auto copy = it; + ++it; + table.erase(copy); + } else { + ++it; + } + } +} +} + + +template +static void BM_remove_if(benchmark::State &state) { + constexpr size_t N = 100000; + constexpr size_t BATCH_SIZE = N; + + TableT table; + reserve(table, N); + while (state.KeepRunningBatch(BATCH_SIZE)) { + state.PauseTiming(); + td::Random::Xorshift128plus rnd(123); + for (size_t i = 0; i < N; i++) { + table.emplace(rnd(), i); + } + state.ResumeTiming(); + + td::table_remove_if(table, [](auto &it) { + return it.second % 2 == 0; + }); + } +} + template static void benchmark_create(td::Slice name) { td::Random::Xorshift128plus rnd(123); @@ -296,6 +334,7 @@ static void benchmark_create(td::Slice name) { //BENCHMARK(BM_Get>)->Range(1, 1 << 26); #define REGISTER_GET_BENCHMARK(HT) BENCHMARK(BM_Get>)->Range(1, 1 << 23); +#define REGISTER_REMOVE_IF_BENCHMARK(HT) BENCHMARK(BM_remove_if>); #define REGISTER_FIND_BENCHMARK(HT) \ BENCHMARK(BM_find_same>) \ @@ -308,12 +347,13 @@ static void benchmark_create(td::Slice name) { #define RUN_CREATE_BENCHMARK(HT) benchmark_create>(#HT); +FOR_EACH_TABLE(REGISTER_REMOVE_IF_BENCHMARK) FOR_EACH_TABLE(REGISTER_FIND_BENCHMARK) FOR_EACH_TABLE(REGISTER_EMPLACE_BENCHMARK) FOR_EACH_TABLE(REGISTER_GET_BENCHMARK) int main(int argc, char **argv) { - FOR_EACH_TABLE(RUN_CREATE_BENCHMARK); +// FOR_EACH_TABLE(RUN_CREATE_BENCHMARK); benchmark::Initialize(&argc, argv); benchmark::RunSpecifiedBenchmarks();