Make WaitFreeHashMap recursive.

This commit is contained in:
levlam 2022-11-18 12:53:26 +03:00
parent a1f19371b0
commit e7b7217256
2 changed files with 65 additions and 46 deletions

View File

@ -16,72 +16,86 @@ namespace td {
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>> template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
class WaitFreeHashMap { class WaitFreeHashMap {
using Storage = FlatHashMap<KeyT, ValueT, HashT, EqT>; static constexpr size_t MAX_STORAGE_COUNT = 1 << 8;
static constexpr size_t MAX_STORAGE_COUNT = 1 << 11;
static_assert((MAX_STORAGE_COUNT & (MAX_STORAGE_COUNT - 1)) == 0, ""); static_assert((MAX_STORAGE_COUNT & (MAX_STORAGE_COUNT - 1)) == 0, "");
static constexpr size_t MAX_STORAGE_SIZE = 1 << 16; static constexpr size_t DEFAULT_STORAGE_SIZE = 1 << 14;
static_assert((MAX_STORAGE_SIZE & (MAX_STORAGE_SIZE - 1)) == 0, "");
Storage default_map_; FlatHashMap<KeyT, ValueT, HashT, EqT> default_map_;
struct WaitFreeStorage { struct WaitFreeStorage {
Storage maps_[MAX_STORAGE_COUNT]; WaitFreeHashMap maps_[MAX_STORAGE_COUNT];
}; };
unique_ptr<WaitFreeStorage> wait_free_storage_; unique_ptr<WaitFreeStorage> wait_free_storage_;
uint32 hash_mult_ = 1;
uint32 max_storage_size_ = DEFAULT_STORAGE_SIZE;
Storage &get_wait_free_storage(const KeyT &key) { uint32 get_wait_free_index(const KeyT &key) const {
return wait_free_storage_->maps_[randomize_hash(HashT()(key)) & (MAX_STORAGE_COUNT - 1)]; return randomize_hash(HashT()(key) * hash_mult_) & (MAX_STORAGE_COUNT - 1);
} }
Storage &get_storage(const KeyT &key) { WaitFreeHashMap &get_wait_free_storage(const KeyT &key) {
if (wait_free_storage_ == nullptr) { return wait_free_storage_->maps_[get_wait_free_index(key)];
return default_map_;
} }
return get_wait_free_storage(key); const WaitFreeHashMap &get_wait_free_storage(const KeyT &key) const {
} return wait_free_storage_->maps_[get_wait_free_index(key)];
const Storage &get_storage(const KeyT &key) const {
return const_cast<WaitFreeHashMap *>(this)->get_storage(key);
} }
void split_storage() { void split_storage() {
CHECK(wait_free_storage_ == nullptr); CHECK(wait_free_storage_ == nullptr);
wait_free_storage_ = make_unique<WaitFreeStorage>(); wait_free_storage_ = make_unique<WaitFreeStorage>();
auto next_hash_mult = hash_mult_ * 1000000007;
for (uint32 i = 0; i < MAX_STORAGE_COUNT; i++) {
auto &map = wait_free_storage_->maps_[i];
map.hash_mult_ = next_hash_mult;
map.max_storage_size_ = DEFAULT_STORAGE_SIZE + i * next_hash_mult % DEFAULT_STORAGE_SIZE;
}
for (auto &it : default_map_) { for (auto &it : default_map_) {
get_wait_free_storage(it.first).emplace(it.first, std::move(it.second)); get_wait_free_storage(it.first).set(it.first, std::move(it.second));
} }
default_map_.clear(); default_map_.clear();
} }
public: public:
void set(const KeyT &key, ValueT value) { void set(const KeyT &key, ValueT value) {
auto &storage = get_storage(key); if (wait_free_storage_ != nullptr) {
storage[key] = std::move(value); return get_wait_free_storage(key).set(key, std::move(value));
if (default_map_.size() == MAX_STORAGE_SIZE) { }
default_map_[key] = std::move(value);
if (default_map_.size() == max_storage_size_) {
split_storage(); split_storage();
} }
} }
ValueT get(const KeyT &key) const { ValueT get(const KeyT &key) const {
const auto &storage = get_storage(key); if (wait_free_storage_ != nullptr) {
auto it = storage.find(key); return get_wait_free_storage(key).get(key);
if (it == storage.end()) { }
auto it = default_map_.find(key);
if (it == default_map_.end()) {
return {}; return {};
} }
return it->second; return it->second;
} }
size_t count(const KeyT &key) const { size_t count(const KeyT &key) const {
const auto &storage = get_storage(key); if (wait_free_storage_ != nullptr) {
return storage.count(key); return get_wait_free_storage(key).count(key);
}
return default_map_.count(key);
} }
// specialization for WaitFreeHashMap<..., unique_ptr<T>> // specialization for WaitFreeHashMap<..., unique_ptr<T>>
template <class T = ValueT> template <class T = ValueT>
typename T::element_type *get_pointer(const KeyT &key) { typename T::element_type *get_pointer(const KeyT &key) {
auto &storage = get_storage(key); if (wait_free_storage_ != nullptr) {
auto it = storage.find(key); return get_wait_free_storage(key).get_pointer(key);
if (it == storage.end()) { }
auto it = default_map_.find(key);
if (it == default_map_.end()) {
return nullptr; return nullptr;
} }
return it->second.get(); return it->second.get();
@ -89,9 +103,12 @@ class WaitFreeHashMap {
template <class T = ValueT> template <class T = ValueT>
const typename T::element_type *get_pointer(const KeyT &key) const { const typename T::element_type *get_pointer(const KeyT &key) const {
auto &storage = get_storage(key); if (wait_free_storage_ != nullptr) {
auto it = storage.find(key); return get_wait_free_storage(key).get_pointer(key);
if (it == storage.end()) { }
auto it = default_map_.find(key);
if (it == default_map_.end()) {
return nullptr; return nullptr;
} }
return it->second.get(); return it->second.get();
@ -100,7 +117,7 @@ class WaitFreeHashMap {
ValueT &operator[](const KeyT &key) { ValueT &operator[](const KeyT &key) {
if (wait_free_storage_ == nullptr) { if (wait_free_storage_ == nullptr) {
ValueT &result = default_map_[key]; ValueT &result = default_map_[key];
if (default_map_.size() != MAX_STORAGE_SIZE) { if (default_map_.size() != max_storage_size_) {
return result; return result;
} }
@ -111,10 +128,14 @@ class WaitFreeHashMap {
} }
size_t erase(const KeyT &key) { size_t erase(const KeyT &key) {
return get_storage(key).erase(key); if (wait_free_storage_ != nullptr) {
return get_wait_free_storage(key).erase(key);
} }
void foreach(std::function<void(const KeyT &key, ValueT &value)> callback) { return default_map_.erase(key);
}
void foreach(const std::function<void(const KeyT &key, ValueT &value)> &callback) {
if (wait_free_storage_ == nullptr) { if (wait_free_storage_ == nullptr) {
for (auto &it : default_map_) { for (auto &it : default_map_) {
callback(it.first, it.second); callback(it.first, it.second);
@ -122,14 +143,12 @@ class WaitFreeHashMap {
return; return;
} }
for (size_t i = 0; i < MAX_STORAGE_COUNT; i++) { for (auto &it : wait_free_storage_->maps_) {
for (auto &it : wait_free_storage_->maps_[i]) { it.foreach(callback);
callback(it.first, it.second);
}
} }
} }
void foreach(std::function<void(const KeyT &key, const ValueT &value)> callback) const { void foreach(const std::function<void(const KeyT &key, const ValueT &value)> &callback) const {
if (wait_free_storage_ == nullptr) { if (wait_free_storage_ == nullptr) {
for (auto &it : default_map_) { for (auto &it : default_map_) {
callback(it.first, it.second); callback(it.first, it.second);
@ -137,10 +156,8 @@ class WaitFreeHashMap {
return; return;
} }
for (size_t i = 0; i < MAX_STORAGE_COUNT; i++) { for (auto &it : wait_free_storage_->maps_) {
for (auto &it : wait_free_storage_->maps_[i]) { it.foreach(callback);
callback(it.first, it.second);
}
} }
} }

View File

@ -24,8 +24,10 @@ TEST(WaitFreeHashMap, stress_test) {
return rnd() % 100000 + 1; return rnd() % 100000 + 1;
}; };
auto check = [&] { auto check = [&](bool check_size = false) {
if (check_size) {
ASSERT_EQ(reference.size(), map.size()); ASSERT_EQ(reference.size(), map.size());
}
ASSERT_EQ(reference.empty(), map.empty()); ASSERT_EQ(reference.empty(), map.empty());
if (reference.size() < 100) { if (reference.size() < 100) {
@ -58,7 +60,7 @@ TEST(WaitFreeHashMap, stress_test) {
add_step(200, [&] { add_step(200, [&] {
auto key = gen_key(); auto key = gen_key();
ASSERT_EQ(reference[key], map[key]); ASSERT_EQ(reference[key], map[key]);
check(); check(true);
}); });
add_step(2000, [&] { add_step(2000, [&] {