// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once #include "td/utils/common.h" #include "td/utils/HazardPointers.h" #include "td/utils/logging.h" #include "td/utils/port/thread_local.h" #include <atomic> #include <condition_variable> #include <mutex> namespace td { // AtomicHashArray<KeyT, ValueT> // Building block for other concurrent hash maps // // Support one operation: // template <class F> // bool with_value(KeyT key, bool should_create, F &&func); // // Finds slot for key, and call func(value) // Creates slot if should_create is true. // Returns true if func was called. // // Concurrent calls with the same key may result in concurrent calls to func(value) // It is responsibility of the caller to handle such races. // // Key should already be random // It is responsibility of the caller to provide unique random key. // One may use injective hash function, or handle collisions in some other way. template <class KeyT, class ValueT> class AtomicHashArray { public: explicit AtomicHashArray(size_t n) : nodes_(n) { } struct Node { std::atomic<KeyT> key{KeyT{}}; ValueT value{}; }; size_t size() const { return nodes_.size(); } Node &node_at(size_t i) { return nodes_[i]; } static KeyT empty_key() { return KeyT{}; } template <class F> bool with_value(KeyT key, bool should_create, F &&f) { DCHECK(key != empty_key()); auto pos = static_cast<size_t>(key) % nodes_.size(); auto n = td::min(td::max(static_cast<size_t>(300), nodes_.size() / 16 + 2), nodes_.size()); for (size_t i = 0; i < n; i++) { pos++; if (pos >= nodes_.size()) { pos = 0; } auto &node = nodes_[pos]; while (true) { auto node_key = node.key.load(std::memory_order_acquire); if (node_key == empty_key()) { if (!should_create) { return false; } KeyT expected_key = empty_key(); if (node.key.compare_exchange_strong(expected_key, key, std::memory_order_relaxed, std::memory_order_relaxed)) { f(node.value); return true; } } else if (node_key == key) { f(node.value); return true; } else { break; } } } return false; } private: std::vector<Node> nodes_; }; // Simple concurrent hash map with multiple limitations template <class KeyT, class ValueT> class ConcurrentHashMap { using HashMap = AtomicHashArray<KeyT, std::atomic<ValueT>>; static HazardPointers<HashMap> hp_; public: explicit ConcurrentHashMap(size_t n = 32) { n = 1; hash_map_.store(make_unique<HashMap>(n).release()); } ConcurrentHashMap(const ConcurrentHashMap &) = delete; ConcurrentHashMap &operator=(const ConcurrentHashMap &) = delete; ConcurrentHashMap(ConcurrentHashMap &&) = delete; ConcurrentHashMap &operator=(ConcurrentHashMap &&) = delete; ~ConcurrentHashMap() { unique_ptr<HashMap>(hash_map_.load()).reset(); } static std::string get_name() { return "ConcurrentHashMap"; } static KeyT empty_key() { return KeyT{}; } static ValueT empty_value() { return ValueT{}; } static ValueT migrate_value() { return (ValueT)(1); // c-style conversion because reinterpret_cast<int>(1) is CE in MSVC } ValueT insert(KeyT key, ValueT value) { CHECK(key != empty_key()); CHECK(value != migrate_value()); typename HazardPointers<HashMap>::Holder holder(hp_, get_thread_id(), 0); while (true) { auto hash_map = holder.protect(hash_map_); if (!hash_map) { do_migrate(nullptr); continue; } bool ok = false; ValueT inserted_value; hash_map->with_value(key, true, [&](auto &node_value) { ValueT expected_value = this->empty_value(); if (node_value.compare_exchange_strong(expected_value, value, std::memory_order_release, std::memory_order_acquire)) { ok = true; inserted_value = value; } else { if (expected_value == this->migrate_value()) { ok = false; } else { ok = true; inserted_value = expected_value; } } }); if (ok) { return inserted_value; } do_migrate(hash_map); } } ValueT find(KeyT key, ValueT value) { typename HazardPointers<HashMap>::Holder holder(hp_, get_thread_id(), 0); while (true) { auto hash_map = holder.protect(hash_map_); if (!hash_map) { do_migrate(nullptr); continue; } bool has_value = hash_map->with_value( key, false, [&](auto &node_value) { value = node_value.load(std::memory_order_acquire); }); if (!has_value || value != migrate_value()) { return value; } do_migrate(hash_map); } } template <class F> void for_each(F &&f) { auto hash_map = hash_map_.load(); CHECK(hash_map); auto size = hash_map->size(); for (size_t i = 0; i < size; i++) { auto &node = hash_map->node_at(i); auto key = node.key.load(std::memory_order_relaxed); auto value = node.value.load(std::memory_order_relaxed); if (key != empty_key()) { CHECK(value != migrate_value()); if (value != empty_value()) { f(key, value); } } } } private: // use no padding intentionally std::atomic<HashMap *> hash_map_{nullptr}; std::mutex migrate_mutex_; std::condition_variable migrate_cv_; int migrate_cnt_{0}; int migrate_generation_{0}; HashMap *migrate_from_hash_map_{nullptr}; HashMap *migrate_to_hash_map_{nullptr}; struct Task { size_t begin; size_t end; bool empty() const { return begin >= end; } size_t size() const { if (empty()) { return 0; } return end - begin; } }; struct TaskCreator { size_t chunk_size; size_t size; std::atomic<size_t> pos{0}; Task create() { auto i = pos++; auto begin = i * chunk_size; auto end = begin + chunk_size; if (end > size) { end = size; } return {begin, end}; } }; TaskCreator task_creator; void do_migrate(HashMap *ptr) { //LOG(ERROR) << "In do_migrate: " << ptr; std::unique_lock<std::mutex> lock(migrate_mutex_); if (hash_map_.load() != ptr) { return; } init_migrate(); CHECK(!ptr || migrate_from_hash_map_ == ptr); migrate_cnt_++; auto migrate_generation = migrate_generation_; lock.unlock(); run_migrate(); lock.lock(); migrate_cnt_--; if (migrate_cnt_ == 0) { finish_migrate(); } migrate_cv_.wait(lock, [&] { return migrate_generation_ != migrate_generation; }); } void finish_migrate() { //LOG(ERROR) << "In finish_migrate"; hash_map_.store(migrate_to_hash_map_); hp_.retire(get_thread_id(), migrate_from_hash_map_); migrate_from_hash_map_ = nullptr; migrate_to_hash_map_ = nullptr; migrate_generation_++; migrate_cv_.notify_all(); } void init_migrate() { if (migrate_from_hash_map_ != nullptr) { return; } //LOG(ERROR) << "In init_migrate"; CHECK(migrate_cnt_ == 0); migrate_generation_++; migrate_from_hash_map_ = hash_map_.exchange(nullptr); auto new_size = migrate_from_hash_map_->size() * 2; migrate_to_hash_map_ = make_unique<HashMap>(new_size).release(); task_creator.chunk_size = 100; task_creator.size = migrate_from_hash_map_->size(); task_creator.pos = 0; } void run_migrate() { //LOG(ERROR) << "In run_migrate"; size_t cnt = 0; while (true) { auto task = task_creator.create(); cnt += task.size(); if (task.empty()) { break; } run_task(task); } //LOG(ERROR) << "In run_migrate " << cnt; } void run_task(Task task) { for (auto i = task.begin; i < task.end; i++) { auto &node = migrate_from_hash_map_->node_at(i); auto old_value = node.value.exchange(migrate_value(), std::memory_order_acq_rel); if (old_value == 0) { continue; } auto node_key = node.key.load(std::memory_order_relaxed); auto ok = migrate_to_hash_map_->with_value( node_key, true, [&](auto &node_value) { node_value.store(old_value, std::memory_order_relaxed); }); LOG_CHECK(ok) << "Migration overflow"; } } }; template <class KeyT, class ValueT> HazardPointers<typename ConcurrentHashMap<KeyT, ValueT>::HashMap> ConcurrentHashMap<KeyT, ValueT>::hp_(64); } // namespace td