Reduce size of an empty FlatHashTable.

levlam 2022-02-23 22:13:40 +03:00
parent 3da16b4501
commit e1909b018e
2 changed files with 117 additions and 62 deletions
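Summary of the change: an empty FlatHashTable used to carry a fixed_vector<NodeT> nodes_ plus a separate size_t used_nodes_. After this commit the table holds only a single NodeT *nodes_, which is nullptr while the table is empty; the bookkeeping fields (used_node_count_ and bucket_count_mask_) move into a FlatHashTableInner header placed OFFSET = 2 * sizeof(uint32) bytes in front of the node array, inside the same std::malloc allocation. The standalone sketch below illustrates that header-before-array layout; it is not TDLib code, the names Table, Node and Header are invented for the example, and (like the original) it assumes the node type's alignment does not exceed the 8-byte header.

// Illustrative sketch only: a table object that is a single pointer,
// with its bookkeeping stored in a header right before the bucket array.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <new>

struct Node {
  uint64_t key = 0;  // 0 means "empty bucket" in this toy example
};

struct Header {
  uint32_t used_count;
  uint32_t bucket_mask;  // bucket_count - 1; bucket_count is a power of two
};

class Table {
 public:
  Table() = default;
  ~Table() {
    if (nodes_ != nullptr) {
      // Node is trivially destructible here; the real code destroys each NodeT
      // explicitly before freeing the single allocation.
      std::free(reinterpret_cast<char *>(nodes_) - OFFSET);
    }
  }

  size_t bucket_count() const {
    return nodes_ == nullptr ? 0 : static_cast<size_t>(header()->bucket_mask) + 1;
  }
  size_t size() const {
    return nodes_ == nullptr ? 0 : header()->used_count;
  }

  // Call once on an empty table; bucket_count must be a power of two.
  void allocate(uint32_t bucket_count) {
    // One allocation holds the header followed by the bucket array.
    // Nodes start OFFSET bytes in, so the header is recovered by pointer
    // arithmetic; this assumes alignof(Node) <= OFFSET.
    auto *raw = static_cast<char *>(std::malloc(OFFSET + sizeof(Node) * bucket_count));
    new (raw) Header{0, bucket_count - 1};
    nodes_ = reinterpret_cast<Node *>(raw + OFFSET);
    for (uint32_t i = 0; i < bucket_count; i++) {
      new (nodes_ + i) Node();
    }
  }

 private:
  static constexpr size_t OFFSET = sizeof(Header);  // 2 * sizeof(uint32_t)

  Header *header() const {
    return reinterpret_cast<Header *>(reinterpret_cast<char *>(nodes_) - OFFSET);
  }

  Node *nodes_ = nullptr;  // the only data member: an empty table costs one pointer
};

int main() {
  Table table;
  std::printf("sizeof(Table) = %zu, empty bucket_count = %zu\n", sizeof(Table), table.bucket_count());
  table.allocate(8);
  std::printf("after allocate(8): bucket_count = %zu, size = %zu\n", table.bucket_count(), table.size());
}

With this layout the table's only data member is the node pointer, which is what reduces the footprint of an empty FlatHashTable.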

View File

@@ -8,6 +8,7 @@
 #include "td/utils/bits.h"
 #include "td/utils/common.h"
+#include "td/utils/fixed_vector.h"
 #include "td/utils/FlatHashMapLinear.h"
 #include <cstddef>

View File

@@ -8,7 +8,6 @@
 #include "td/utils/bits.h"
 #include "td/utils/common.h"
-#include "td/utils/fixed_vector.h"
 #include <cstddef>
 #include <functional>
@@ -160,9 +159,60 @@ struct SetNode {
 template <class NodeT, class HashT, class EqT>
 class FlatHashTable {
- public:
-  using NodeIterator = NodeT *;
+  struct FlatHashTableInner {
+    uint32 used_node_count_;
+    uint32 bucket_count_mask_;
+    NodeT nodes_[1];
+  };
+
+  static constexpr size_t OFFSET = 2 * sizeof(uint32);
+
+  static NodeT *allocate_nodes(size_t size) {
+    DCHECK(size >= 8);
+    DCHECK((size & (size - 1)) == 0);
+    CHECK(size <= (1 << 29));
+    auto inner = static_cast<FlatHashTableInner *>(std::malloc(OFFSET + sizeof(NodeT) * size));
+    NodeT *nodes = &inner->nodes_[0];
+    for (size_t i = 0; i < size; i++) {
+      new (nodes + i) NodeT();
+    }
+    // inner->used_node_count_ = 0;
+    inner->bucket_count_mask_ = static_cast<uint32>(size - 1);
+    return nodes;
+  }
+
+  static void clear_inner(FlatHashTableInner *inner) {
+    auto size = inner->bucket_count_mask_ + 1;
+    NodeT *nodes = &inner->nodes_[0];
+    for (size_t i = 0; i < size; i++) {
+      nodes[i].~NodeT();
+    }
+    std::free(inner);
+  }
+
+  inline FlatHashTableInner *get_inner() {
+    DCHECK(nodes_ != nullptr);
+    return reinterpret_cast<FlatHashTableInner *>(reinterpret_cast<char *>(nodes_) - OFFSET);
+  }
+
+  inline const FlatHashTableInner *get_inner() const {
+    DCHECK(nodes_ != nullptr);
+    return reinterpret_cast<const FlatHashTableInner *>(reinterpret_cast<const char *>(nodes_) - OFFSET);
+  }
+
+  inline uint32 &used_node_count() {
+    return get_inner()->used_node_count_;
+  }
+
+  inline uint32 get_used_node_count() const {
+    return get_inner()->used_node_count_;
+  }
+
+  inline uint32 get_bucket_count_mask() const {
+    return get_inner()->bucket_count_mask_;
+  }
+
+ public:
   using KeyT = typename NodeT::public_key_type;
   using key_type = typename NodeT::public_key_type;
   using value_type = typename NodeT::public_type;
@@ -203,11 +253,11 @@ class FlatHashTable {
     }
 
     Iterator() = default;
-    Iterator(NodeIterator it, FlatHashTable *map) : it_(std::move(it)), end_(map->nodes_.end()) {
+    Iterator(NodeT *it, FlatHashTable *map) : it_(it), end_(map->nodes_ + map->bucket_count()) {
     }
 
    private:
-    NodeIterator it_;
+    NodeT *it_;
     NodeT *end_;
   };
@@ -264,6 +314,7 @@ class FlatHashTable {
       return;
     }
    reserve(nodes.size());
+    uint32 used_nodes = 0;
    for (auto &new_node : nodes) {
      CHECK(!new_node.empty());
      auto bucket = calc_bucket(new_node.key());
@@ -271,7 +322,7 @@
        auto &node = nodes_[bucket];
        if (node.empty()) {
          node.copy_from(new_node);
-          used_nodes_++;
+          used_nodes++;
          break;
        }
        if (EqT()(node.key(), new_node.key())) {
@@ -280,28 +331,28 @@
        next_bucket(bucket);
      }
    }
+    used_node_count() = used_nodes;
  }
 
-  FlatHashTable(FlatHashTable &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) {
-    other.clear();
+  FlatHashTable(FlatHashTable &&other) noexcept : nodes_(other.nodes_) {
+    other.nodes_ = nullptr;
  }
  void operator=(FlatHashTable &&other) noexcept {
-    nodes_ = std::move(other.nodes_);
-    used_nodes_ = other.used_nodes_;
-    other.clear();
+    clear();
+    nodes_ = other.nodes_;
+    other.nodes_ = nullptr;
  }
  void swap(FlatHashTable &other) noexcept {
-    nodes_.swap(other.nodes_);
-    std::swap(used_nodes_, other.used_nodes_);
+    std::swap(nodes_, other.nodes_);
  }
  ~FlatHashTable() = default;
 
  size_t bucket_count() const {
-    return nodes_.size();
+    return unlikely(nodes_ == nullptr) ? 0 : get_bucket_count_mask() + 1;
  }
 
  Iterator find(const KeyT &key) {
-    if (empty() || is_key_empty(key)) {
+    if (unlikely(nodes_ == nullptr) || is_key_empty(key)) {
      return end();
    }
    auto bucket = calc_bucket(key);
@@ -322,25 +373,25 @@ class FlatHashTable {
  }
 
  size_t size() const {
-    return used_nodes_;
+    return unlikely(nodes_ == nullptr) ? 0 : get_used_node_count();
  }
 
  bool empty() const {
-    return size() == 0;
+    return unlikely(nodes_ == nullptr) || get_used_node_count() == 0;
  }
 
  Iterator begin() {
    if (empty()) {
      return end();
    }
-    auto it = nodes_.begin();
+    auto it = nodes_;
    while (it->empty()) {
      ++it;
    }
    return Iterator(it, this);
  }
 
  Iterator end() {
-    return Iterator(nodes_.end(), this);
+    return Iterator(nodes_ + bucket_count(), this);
  }
 
  ConstIterator begin() const {
@@ -351,8 +402,10 @@ class FlatHashTable {
  }
 
  void reserve(size_t size) {
+    if (size == 0) {
+      return;
+    }
    size_t want_size = normalize(size * 5 / 3 + 1);
-    // size_t want_size = size * 2;
    if (want_size > bucket_count()) {
      resize(want_size);
    }
@@ -370,7 +423,7 @@
      }
      if (node.empty()) {
        node.emplace(std::move(key), std::forward<ArgsT>(args)...);
-        used_nodes_++;
+        used_node_count()++;
        return {Iterator{&node, this}, true};
      }
      next_bucket(bucket);
@@ -407,8 +460,10 @@ class FlatHashTable {
  }
 
  void clear() {
-    used_nodes_ = 0;
-    nodes_ = {};
+    if (nodes_ != nullptr) {
+      clear_inner(get_inner());
+      nodes_ = nullptr;
+    }
  }
 
  void erase(Iterator it) {
@@ -420,19 +475,24 @@ class FlatHashTable {
  template <class F>
  void remove_if(F &&f) {
+    if (empty()) {
+      return;
+    }
+
    auto it = begin().it_;
-    while (it != nodes_.end() && !it->empty()) {
+    auto end = nodes_ + bucket_count();
+    while (it != end && !it->empty()) {
      ++it;
    }
    auto first_empty = it;
-    for (; it != nodes_.end();) {
+    while (it != end) {
      if (!it->empty() && f(it->get_public())) {
        erase_node(it);
      } else {
        ++it;
      }
    }
-    for (it = nodes_.begin(); it != first_empty;) {
+    for (it = nodes_; it != first_empty;) {
      if (!it->empty() && f(it->get_public())) {
        erase_node(it);
      } else {
@@ -443,11 +503,13 @@ class FlatHashTable {
  }
 
 private:
-  fixed_vector<NodeT> nodes_;
-  size_t used_nodes_{};
+  NodeT *nodes_ = nullptr;
 
  void assign(const FlatHashTable &other) {
-    resize(other.size());
+    if (other.size() == 0) {
+      return;
+    }
+    resize(other.bucket_count());
    for (const auto &new_node : other) {
      auto bucket = calc_bucket(new_node.key());
      while (true) {
@@ -459,49 +521,45 @@ class FlatHashTable {
        next_bucket(bucket);
      }
    }
-    used_nodes_ = other.size();
+    used_node_count() = other.get_used_node_count();
  }
 
  void try_grow() {
-    if (should_grow(used_nodes_ + 1, bucket_count())) {
-      grow();
+    if (unlikely(nodes_ == nullptr)) {
+      resize(8);
+    } else if (unlikely(get_used_node_count() * 5 > get_bucket_count_mask() * 3)) {
+      resize(2 * get_bucket_count_mask() + 2);
    }
  }
-  static bool should_grow(size_t used_count, size_t bucket_count) {
-    return used_count * 5 > bucket_count * 3;
-  }
+
  void try_shrink() {
-    if (should_shrink(used_nodes_, bucket_count())) {
-      shrink();
+    DCHECK(nodes_ != nullptr);
+    if (unlikely(get_used_node_count() * 10 < get_bucket_count_mask() && get_bucket_count_mask() > 7)) {
+      resize(normalize((get_used_node_count() + 1) * 5 / 3 + 1));
    }
  }
-  static bool should_shrink(size_t used_count, size_t bucket_count) {
-    return used_count * 10 < bucket_count;
-  }
+
  static size_t normalize(size_t size) {
-    return static_cast<size_t>(1) << (64 - count_leading_zeroes64(size | 7));
-  }
-
-  void shrink() {
-    size_t want_size = normalize((used_nodes_ + 1) * 5 / 3 + 1);
-    resize(want_size);
-  }
-
-  void grow() {
-    size_t want_size = normalize(2 * bucket_count() - !nodes_.empty());
-    resize(want_size);
+    return td::max(static_cast<size_t>(1) << (64 - count_leading_zeroes64(size)), static_cast<size_t>(8));
  }
 
  uint32 calc_bucket(const KeyT &key) const {
-    return randomize_hash(HashT()(key)) & static_cast<uint32>(bucket_count() - 1);
+    return randomize_hash(HashT()(key)) & get_bucket_count_mask();
+  }
+
+  inline void next_bucket(uint32 &bucket) const {
+    bucket = (bucket + 1) & get_bucket_count_mask();
  }
 
  void resize(size_t new_size) {
-    fixed_vector<NodeT> old_nodes(new_size);
-    old_nodes.swap(nodes_);
-    for (auto &old_node : old_nodes) {
+    auto old_nodes = nodes_;
+    auto old_size = size();
+    size_t old_bucket_count = bucket_count();
+    nodes_ = allocate_nodes(new_size);
+    used_node_count() = static_cast<uint32>(old_size);
+    for (size_t i = 0; i < old_bucket_count; i++) {
+      auto &old_node = old_nodes[i];
      if (old_node.empty()) {
        continue;
@@ -513,16 +571,12 @@ class FlatHashTable {
    }
  }
 
-  void next_bucket(uint32 &bucket) const {
-    bucket = (bucket + 1) & static_cast<uint32>(bucket_count() - 1);
-  }
-
-  void erase_node(NodeIterator it) {
-    size_t empty_i = it - nodes_.begin();
+  void erase_node(NodeT *it) {
+    size_t empty_i = it - nodes_;
    auto empty_bucket = empty_i;
    DCHECK(0 <= empty_i && empty_i < bucket_count());
    nodes_[empty_bucket].clear();
-    used_nodes_--;
+    used_node_count()--;
    for (size_t test_i = empty_i + 1;; test_i++) {
      auto test_bucket = test_i;