tdlight/tdutils/td/utils/ConcurrentHashTable.h
2024-01-01 03:07:21 +03:00

323 lines
8.7 KiB
C++

//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/common.h"
#include "td/utils/HazardPointers.h"
#include "td/utils/logging.h"
#include "td/utils/port/thread_local.h"
#include <atomic>
#include <condition_variable>
#include <mutex>
namespace td {
// AtomicHashArray<KeyT, ValueT>
// Building block for other concurrent hash maps
//
// Supports one operation:
// template <class F>
// bool with_value(KeyT key, bool should_create, F &&func);
//
// Finds the slot for key and calls func(value).
// Creates the slot if should_create is true.
// Returns true if func was called.
//
// Concurrent calls with the same key may result in concurrent calls to func(value).
// It is the responsibility of the caller to handle such races.
//
// Key should already be random.
// It is the responsibility of the caller to provide a unique random key.
// One may use an injective hash function, or handle collisions in some other way.
template <class KeyT, class ValueT>
class AtomicHashArray {
 public:
  // Creates a table with n slots; every slot starts with the empty key and a value-initialized value.
  explicit AtomicHashArray(size_t n) : nodes_(n) {
  }

  // One slot of the table: an atomically claimed key and the value that belongs to it.
  struct Node {
    std::atomic<KeyT> key{KeyT{}};
    ValueT value{};
  };

  // Number of slots in the table.
  size_t size() const {
    return nodes_.size();
  }

  // Direct access to slot i; the index is not bounds-checked.
  Node &node_at(size_t i) {
    return nodes_[i];
  }

  // Key that marks a slot as unused; it can't be stored in the table.
  static KeyT empty_key() {
    return KeyT{};
  }

  // Looks for the slot of `key` with linear probing and calls f(node.value) on it.
  //
  // key           - must differ from empty_key() (checked with DCHECK only)
  // should_create - if true, the first unused slot met during probing is claimed for key;
  //                 if false, probing stops at the first unused slot and false is returned
  // f             - callable invoked with a reference to the slot's value
  //
  // Returns true if f was called, false if the key wasn't found (or no slot could be claimed)
  // within the probe limit. An empty table never contains anything, so false is returned for it.
  template <class F>
  bool with_value(KeyT key, bool should_create, F &&f) {
    DCHECK(key != empty_key());
    if (nodes_.empty()) {
      // nothing can be found or created; also keeps the modulo below from dividing by zero
      return false;
    }
    auto pos = static_cast<size_t>(key) % nodes_.size();
    // probe limit: at least 300 slots or 1/16 of the table (+2), but never more than the whole table
    auto n = td::min(td::max(static_cast<size_t>(300), nodes_.size() / 16 + 2), nodes_.size());

    for (size_t i = 0; i < n; i++) {
      // probing starts from the slot following the hashed position and wraps around
      pos++;
      if (pos >= nodes_.size()) {
        pos = 0;
      }
      auto &node = nodes_[pos];
      while (true) {
        auto node_key = node.key.load(std::memory_order_acquire);
        if (node_key == empty_key()) {
          if (!should_create) {
            return false;
          }
          KeyT expected_key = empty_key();
          if (node.key.compare_exchange_strong(expected_key, key, std::memory_order_relaxed,
                                               std::memory_order_relaxed)) {
            f(node.value);
            return true;
          }
          // another thread has claimed the slot in the meantime; reload the key and check it again
        } else if (node_key == key) {
          f(node.value);
          return true;
        } else {
          // the slot belongs to a different key; go to the next one
          break;
        }
      }
    }
    return false;
  }

 private:
  std::vector<Node> nodes_;
};
// Simple concurrent hash map with multiple limitations
//
// Values are kept in an AtomicHashArray<KeyT, std::atomic<ValueT>>. Two values are reserved:
//   - empty_value() (ValueT{}) marks a slot that has no value;
//   - migrate_value() ((ValueT)1) marks a slot whose content was moved to a bigger table.
// When an insert can't be completed in the current table, or an operation meets a migrated slot,
// the calling thread takes part in moving all values into a table of twice the size and then retries.
// There is no operation for removing a key.
template <class KeyT, class ValueT>
class ConcurrentHashMap {
  using HashMap = AtomicHashArray<KeyT, std::atomic<ValueT>>;
  static HazardPointers<HashMap> hp_;

 public:
  explicit ConcurrentHashMap(size_t n = 32) {
    // NOTE(review): the requested size is overridden, so every map starts with a single slot and
    // grows only through migration. This looks deliberate (it forces migration to be exercised);
    // confirm with the authors before honouring n.
    n = 1;
    hash_map_.store(make_unique<HashMap>(n).release());
  }
  ConcurrentHashMap(const ConcurrentHashMap &) = delete;
  ConcurrentHashMap &operator=(const ConcurrentHashMap &) = delete;
  ConcurrentHashMap(ConcurrentHashMap &&) = delete;
  ConcurrentHashMap &operator=(ConcurrentHashMap &&) = delete;
  ~ConcurrentHashMap() {
    // take back the ownership released in the constructor or in init_migrate and destroy the table
    unique_ptr<HashMap>(hash_map_.load()).reset();
  }

  // Human-readable name of the container.
  static std::string get_name() {
    return "ConcurrentHashMap";
  }

  // Reserved key; it can't be inserted.
  static KeyT empty_key() {
    return KeyT{};
  }
  // Reserved value that denotes "no value in this slot".
  static ValueT empty_value() {
    return ValueT{};
  }
  // Reserved value that denotes "this slot was moved to the new table"; it can't be inserted.
  static ValueT migrate_value() {
    return (ValueT)(1);  // c-style conversion because reinterpret_cast<int>(1) is CE in MSVC
  }

  // Stores value for key unless the key already has a value.
  // Returns the value associated with key after the call: `value` if it was stored,
  // otherwise the value that was already present.
  ValueT insert(KeyT key, ValueT value) {
    CHECK(key != empty_key());
    CHECK(value != migrate_value());
    typename HazardPointers<HashMap>::Holder holder(hp_, get_thread_id(), 0);
    while (true) {
      auto hash_map = holder.protect(hash_map_);
      if (!hash_map) {
        // the table pointer is taken away for the duration of a migration; join it and retry
        do_migrate(nullptr);
        continue;
      }

      bool ok = false;
      ValueT inserted_value{};
      hash_map->with_value(key, true, [&](auto &node_value) {
        ValueT expected_value = this->empty_value();
        if (node_value.compare_exchange_strong(expected_value, value, std::memory_order_release,
                                               std::memory_order_acquire)) {
          ok = true;
          inserted_value = value;
        } else {
          if (expected_value == this->migrate_value()) {
            // the slot has already been moved away; the insert must be repeated in the new table
            ok = false;
          } else {
            ok = true;
            inserted_value = expected_value;
          }
        }
      });
      if (ok) {
        return inserted_value;
      }
      // either no slot could be obtained for the key or the slot was migrated
      do_migrate(hash_map);
    }
  }

  // Returns the value stored in the slot of key, or the passed `value` if there is no slot for key.
  // If the slot exists but no value was stored in it yet, empty_value() is returned.
  ValueT find(KeyT key, ValueT value) {
    typename HazardPointers<HashMap>::Holder holder(hp_, get_thread_id(), 0);
    while (true) {
      auto hash_map = holder.protect(hash_map_);
      if (!hash_map) {
        // the table pointer is taken away for the duration of a migration; join it and retry
        do_migrate(nullptr);
        continue;
      }

      bool has_value = hash_map->with_value(
          key, false, [&](auto &node_value) { value = node_value.load(std::memory_order_acquire); });
      if (!has_value || value != migrate_value()) {
        return value;
      }
      // the slot was moved to the new table; help with the migration and look there
      do_migrate(hash_map);
    }
  }

  // Calls f(key, value) for every slot that has both a non-empty key and a non-empty value.
  // The current table is read without hazard pointer protection, and the function CHECKs that
  // the table is present and that no slot is marked as migrated, so it is usable only while
  // no migration is in progress.
  template <class F>
  void for_each(F &&f) {
    auto hash_map = hash_map_.load();
    CHECK(hash_map);
    auto size = hash_map->size();
    for (size_t i = 0; i < size; i++) {
      auto &node = hash_map->node_at(i);
      auto key = node.key.load(std::memory_order_relaxed);
      auto value = node.value.load(std::memory_order_relaxed);

      if (key != empty_key()) {
        CHECK(value != migrate_value());
        if (value != empty_value()) {
          f(key, value);
        }
      }
    }
  }

 private:
  // use no padding intentionally
  std::atomic<HashMap *> hash_map_{nullptr};  // current table; nullptr while a migration is running

  // state of the migration; the fields below are modified under migrate_mutex_
  std::mutex migrate_mutex_;
  std::condition_variable migrate_cv_;  // notified when a migration is finished

  int migrate_cnt_{0};         // number of threads that have joined the migration and not left yet
  int migrate_generation_{0};  // incremented when a migration starts and when it finishes
  HashMap *migrate_from_hash_map_{nullptr};
  HashMap *migrate_to_hash_map_{nullptr};

  // Half-open range [begin, end) of slot indices of the old table to be moved by one thread.
  struct Task {
    size_t begin;
    size_t end;
    bool empty() const {
      return begin >= end;
    }
    size_t size() const {
      if (empty()) {
        return 0;
      }
      return end - begin;
    }
  };

  // Hands out consecutive chunks of [0, size) to the migrating threads.
  struct TaskCreator {
    size_t chunk_size;
    size_t size;
    std::atomic<size_t> pos{0};
    // Returns the next chunk; the chunk is empty once the whole range has been handed out.
    Task create() {
      auto i = pos++;
      auto begin = i * chunk_size;
      auto end = begin + chunk_size;
      if (end > size) {
        end = size;
      }
      return {begin, end};
    }
  };
  TaskCreator task_creator;

  // Takes part in the migration away from table `ptr` and returns when that migration is finished.
  // `ptr` is the table the caller has observed (nullptr if it observed that the table was taken away).
  // Does nothing if hash_map_ no longer equals ptr.
  void do_migrate(HashMap *ptr) {
    //LOG(ERROR) << "In do_migrate: " << ptr;
    std::unique_lock<std::mutex> lock(migrate_mutex_);
    if (hash_map_.load() != ptr) {
      return;
    }
    init_migrate();
    CHECK(!ptr || migrate_from_hash_map_ == ptr);
    migrate_cnt_++;
    auto migrate_generation = migrate_generation_;
    lock.unlock();

    // the values are moved without holding the mutex, so several threads can work in parallel
    run_migrate();

    lock.lock();
    migrate_cnt_--;
    if (migrate_cnt_ == 0) {
      // the last thread to leave publishes the new table
      finish_migrate();
    }
    // wait until the migration joined above is finished
    migrate_cv_.wait(lock, [&] { return migrate_generation_ != migrate_generation; });
  }

  // Publishes the new table, retires the old one and wakes up the waiting threads.
  // Called by do_migrate with migrate_mutex_ held.
  void finish_migrate() {
    //LOG(ERROR) << "In finish_migrate";
    hash_map_.store(migrate_to_hash_map_);
    hp_.retire(get_thread_id(), migrate_from_hash_map_);
    migrate_from_hash_map_ = nullptr;
    migrate_to_hash_map_ = nullptr;
    migrate_generation_++;
    migrate_cv_.notify_all();
  }

  // Starts a migration if none is running: takes the current table away from hash_map_,
  // allocates a table of twice the size and resets the task creator.
  // Called by do_migrate with migrate_mutex_ held.
  void init_migrate() {
    if (migrate_from_hash_map_ != nullptr) {
      return;
    }
    //LOG(ERROR) << "In init_migrate";
    CHECK(migrate_cnt_ == 0);
    migrate_generation_++;
    migrate_from_hash_map_ = hash_map_.exchange(nullptr);
    auto new_size = migrate_from_hash_map_->size() * 2;
    migrate_to_hash_map_ = make_unique<HashMap>(new_size).release();
    task_creator.chunk_size = 100;
    task_creator.size = migrate_from_hash_map_->size();
    task_creator.pos = 0;
  }

  // Processes chunks of the old table until none are left.
  void run_migrate() {
    //LOG(ERROR) << "In run_migrate";
    size_t cnt = 0;
    while (true) {
      auto task = task_creator.create();
      cnt += task.size();
      if (task.empty()) {
        break;
      }
      run_task(task);
    }
    //LOG(ERROR) << "In run_migrate " << cnt;
  }

  // Moves the values of slots [task.begin, task.end) of the old table into the new table,
  // leaving migrate_value() in every processed slot of the old table.
  void run_task(Task task) {
    for (auto i = task.begin; i < task.end; i++) {
      auto &node = migrate_from_hash_map_->node_at(i);
      auto old_value = node.value.exchange(migrate_value(), std::memory_order_acq_rel);
      if (old_value == empty_value()) {
        // there was nothing to move from this slot
        continue;
      }
      auto node_key = node.key.load(std::memory_order_relaxed);
      //LOG(ERROR) << node_key << " " << node_key;
      auto ok = migrate_to_hash_map_->with_value(
          node_key, true, [&](auto &node_value) { node_value.store(old_value, std::memory_order_relaxed); });
      LOG_CHECK(ok) << "Migration overflow";
    }
  }
};
// Out-of-class definition of the static hazard pointer storage; one object is shared by all
// ConcurrentHashMap instances that have the same KeyT and ValueT.
// NOTE(review): 64 is presumably the maximum number of threads HazardPointers can serve - confirm
// against td/utils/HazardPointers.h.
template <class KeyT, class ValueT>
HazardPointers<typename ConcurrentHashMap<KeyT, ValueT>::HashMap> ConcurrentHashMap<KeyT, ValueT>::hp_(64);
} // namespace td