diff --git a/tdutils/td/utils/WaitFreeHashMap.h b/tdutils/td/utils/WaitFreeHashMap.h index 04ff05bba..5392b4493 100644 --- a/tdutils/td/utils/WaitFreeHashMap.h +++ b/tdutils/td/utils/WaitFreeHashMap.h @@ -16,72 +16,86 @@ namespace td { template , class EqT = std::equal_to> class WaitFreeHashMap { - using Storage = FlatHashMap; - static constexpr size_t MAX_STORAGE_COUNT = 1 << 11; + static constexpr size_t MAX_STORAGE_COUNT = 1 << 8; static_assert((MAX_STORAGE_COUNT & (MAX_STORAGE_COUNT - 1)) == 0, ""); - static constexpr size_t MAX_STORAGE_SIZE = 1 << 16; - static_assert((MAX_STORAGE_SIZE & (MAX_STORAGE_SIZE - 1)) == 0, ""); + static constexpr size_t DEFAULT_STORAGE_SIZE = 1 << 14; - Storage default_map_; + FlatHashMap default_map_; struct WaitFreeStorage { - Storage maps_[MAX_STORAGE_COUNT]; + WaitFreeHashMap maps_[MAX_STORAGE_COUNT]; }; unique_ptr wait_free_storage_; + uint32 hash_mult_ = 1; + uint32 max_storage_size_ = DEFAULT_STORAGE_SIZE; - Storage &get_wait_free_storage(const KeyT &key) { - return wait_free_storage_->maps_[randomize_hash(HashT()(key)) & (MAX_STORAGE_COUNT - 1)]; + uint32 get_wait_free_index(const KeyT &key) const { + return randomize_hash(HashT()(key) * hash_mult_) & (MAX_STORAGE_COUNT - 1); } - Storage &get_storage(const KeyT &key) { - if (wait_free_storage_ == nullptr) { - return default_map_; - } - - return get_wait_free_storage(key); + WaitFreeHashMap &get_wait_free_storage(const KeyT &key) { + return wait_free_storage_->maps_[get_wait_free_index(key)]; } - const Storage &get_storage(const KeyT &key) const { - return const_cast(this)->get_storage(key); + const WaitFreeHashMap &get_wait_free_storage(const KeyT &key) const { + return wait_free_storage_->maps_[get_wait_free_index(key)]; } void split_storage() { CHECK(wait_free_storage_ == nullptr); wait_free_storage_ = make_unique(); + auto next_hash_mult = hash_mult_ * 1000000007; + for (uint32 i = 0; i < MAX_STORAGE_COUNT; i++) { + auto &map = wait_free_storage_->maps_[i]; + map.hash_mult_ = next_hash_mult; + map.max_storage_size_ = DEFAULT_STORAGE_SIZE + i * next_hash_mult % DEFAULT_STORAGE_SIZE; + } for (auto &it : default_map_) { - get_wait_free_storage(it.first).emplace(it.first, std::move(it.second)); + get_wait_free_storage(it.first).set(it.first, std::move(it.second)); } default_map_.clear(); } public: void set(const KeyT &key, ValueT value) { - auto &storage = get_storage(key); - storage[key] = std::move(value); - if (default_map_.size() == MAX_STORAGE_SIZE) { + if (wait_free_storage_ != nullptr) { + return get_wait_free_storage(key).set(key, std::move(value)); + } + + default_map_[key] = std::move(value); + if (default_map_.size() == max_storage_size_) { split_storage(); } } ValueT get(const KeyT &key) const { - const auto &storage = get_storage(key); - auto it = storage.find(key); - if (it == storage.end()) { + if (wait_free_storage_ != nullptr) { + return get_wait_free_storage(key).get(key); + } + + auto it = default_map_.find(key); + if (it == default_map_.end()) { return {}; } return it->second; } size_t count(const KeyT &key) const { - const auto &storage = get_storage(key); - return storage.count(key); + if (wait_free_storage_ != nullptr) { + return get_wait_free_storage(key).count(key); + } + + return default_map_.count(key); } // specialization for WaitFreeHashMap<..., unique_ptr> template typename T::element_type *get_pointer(const KeyT &key) { - auto &storage = get_storage(key); - auto it = storage.find(key); - if (it == storage.end()) { + if (wait_free_storage_ != nullptr) { + return get_wait_free_storage(key).get_pointer(key); + } + + auto it = default_map_.find(key); + if (it == default_map_.end()) { return nullptr; } return it->second.get(); @@ -89,9 +103,12 @@ class WaitFreeHashMap { template const typename T::element_type *get_pointer(const KeyT &key) const { - auto &storage = get_storage(key); - auto it = storage.find(key); - if (it == storage.end()) { + if (wait_free_storage_ != nullptr) { + return get_wait_free_storage(key).get_pointer(key); + } + + auto it = default_map_.find(key); + if (it == default_map_.end()) { return nullptr; } return it->second.get(); @@ -100,7 +117,7 @@ class WaitFreeHashMap { ValueT &operator[](const KeyT &key) { if (wait_free_storage_ == nullptr) { ValueT &result = default_map_[key]; - if (default_map_.size() != MAX_STORAGE_SIZE) { + if (default_map_.size() != max_storage_size_) { return result; } @@ -111,10 +128,14 @@ class WaitFreeHashMap { } size_t erase(const KeyT &key) { - return get_storage(key).erase(key); + if (wait_free_storage_ != nullptr) { + return get_wait_free_storage(key).erase(key); + } + + return default_map_.erase(key); } - void foreach(std::function callback) { + void foreach(const std::function &callback) { if (wait_free_storage_ == nullptr) { for (auto &it : default_map_) { callback(it.first, it.second); @@ -122,14 +143,12 @@ class WaitFreeHashMap { return; } - for (size_t i = 0; i < MAX_STORAGE_COUNT; i++) { - for (auto &it : wait_free_storage_->maps_[i]) { - callback(it.first, it.second); - } + for (auto &it : wait_free_storage_->maps_) { + it.foreach(callback); } } - void foreach(std::function callback) const { + void foreach(const std::function &callback) const { if (wait_free_storage_ == nullptr) { for (auto &it : default_map_) { callback(it.first, it.second); @@ -137,10 +156,8 @@ class WaitFreeHashMap { return; } - for (size_t i = 0; i < MAX_STORAGE_COUNT; i++) { - for (auto &it : wait_free_storage_->maps_[i]) { - callback(it.first, it.second); - } + for (auto &it : wait_free_storage_->maps_) { + it.foreach(callback); } } diff --git a/tdutils/test/WaitFreeHashMap.cpp b/tdutils/test/WaitFreeHashMap.cpp index 2b53e14a2..83a1345d9 100644 --- a/tdutils/test/WaitFreeHashMap.cpp +++ b/tdutils/test/WaitFreeHashMap.cpp @@ -24,8 +24,10 @@ TEST(WaitFreeHashMap, stress_test) { return rnd() % 100000 + 1; }; - auto check = [&] { - ASSERT_EQ(reference.size(), map.size()); + auto check = [&](bool check_size = false) { + if (check_size) { + ASSERT_EQ(reference.size(), map.size()); + } ASSERT_EQ(reference.empty(), map.empty()); if (reference.size() < 100) { @@ -58,7 +60,7 @@ TEST(WaitFreeHashMap, stress_test) { add_step(200, [&] { auto key = gen_key(); ASSERT_EQ(reference[key], map[key]); - check(); + check(true); }); add_step(2000, [&] {