Reduce size of an empty FlatHashTable.

This commit is contained in:
levlam 2022-02-23 22:13:40 +03:00
parent 3da16b4501
commit e1909b018e
2 changed files with 117 additions and 62 deletions

View File

@ -8,6 +8,7 @@
#include "td/utils/bits.h"
#include "td/utils/common.h"
#include "td/utils/fixed_vector.h"
#include "td/utils/FlatHashMapLinear.h"
#include <cstddef>

View File

@ -8,7 +8,6 @@
#include "td/utils/bits.h"
#include "td/utils/common.h"
#include "td/utils/fixed_vector.h"
#include <cstddef>
#include <functional>
@ -160,9 +159,60 @@ struct SetNode {
template <class NodeT, class HashT, class EqT>
class FlatHashTable {
public:
using NodeIterator = NodeT *;
// Header that lives immediately before the node array inside a single
// allocation. Storing the counters out-of-line lets an empty FlatHashTable
// be represented by just a null `nodes_` pointer (no header allocated).
struct FlatHashTableInner {
// number of non-empty nodes currently stored in the table
uint32 used_node_count_;
// bucket count minus one; bucket count is a power of two, so this doubles as an AND-mask
uint32 bucket_count_mask_;
// flexible-array idiom: the real node array extends past the end of the struct
NodeT nodes_[1];
};
static constexpr size_t OFFSET = 2 * sizeof(uint32);
// Allocates the header and `size` buckets as one malloc block and returns a
// pointer to the first bucket; the header is later recovered via get_inner().
// `size` must be a power of two in [8, 2^29]. The caller is responsible for
// initializing used_node_count_ immediately afterwards.
static NodeT *allocate_nodes(size_t size) {
DCHECK(size >= 8);
DCHECK((size & (size - 1)) == 0);
CHECK(size <= (1 << 29));
// get_inner() recovers the header by subtracting OFFSET from the node
// pointer, so nodes_ must really begin OFFSET bytes into the header
// (this would break for an over-aligned NodeT).
static_assert(offsetof(FlatHashTableInner, nodes_) == OFFSET, "Unexpected offset of nodes_ in FlatHashTableInner");
auto inner = static_cast<FlatHashTableInner *>(std::malloc(OFFSET + sizeof(NodeT) * size));
// fail loudly on allocation failure instead of dereferencing a null pointer below
CHECK(inner != nullptr);
NodeT *nodes = &inner->nodes_[0];
for (size_t i = 0; i < size; i++) {
new (nodes + i) NodeT();  // default-construct every bucket as empty
}
// used_node_count_ is intentionally left unset: every caller assigns it right away
// inner->used_node_count_ = 0;
inner->bucket_count_mask_ = static_cast<uint32>(size - 1);
return nodes;
}
// Destroys every bucket stored in `inner` and releases the single
// allocation that holds both the header and the node array.
static void clear_inner(FlatHashTableInner *inner) {
NodeT *node = &inner->nodes_[0];
NodeT *node_end = node + (inner->bucket_count_mask_ + 1);
while (node != node_end) {
node->~NodeT();
++node;
}
std::free(inner);
}
// Recovers the header, which lives OFFSET bytes before the node array
// returned by allocate_nodes(). Must not be called on an empty table.
inline FlatHashTableInner *get_inner() {
DCHECK(nodes_ != nullptr);
return reinterpret_cast<FlatHashTableInner *>(reinterpret_cast<char *>(nodes_) - OFFSET);
}
// Const overload: recovers the header OFFSET bytes before the node array.
// Must not be called on an empty table.
inline const FlatHashTableInner *get_inner() const {
DCHECK(nodes_ != nullptr);
return reinterpret_cast<const FlatHashTableInner *>(reinterpret_cast<const char *>(nodes_) - OFFSET);
}
// Mutable reference to the stored element count; table must be non-empty.
inline uint32 &used_node_count() {
return get_inner()->used_node_count_;
}
// Read-only element count; table must be non-empty (callers check nodes_ first).
inline uint32 get_used_node_count() const {
return get_inner()->used_node_count_;
}
// Returns bucket_count - 1 (a power-of-two AND-mask); table must be non-empty.
inline uint32 get_bucket_count_mask() const {
return get_inner()->bucket_count_mask_;
}
public:
using KeyT = typename NodeT::public_key_type;
using key_type = typename NodeT::public_key_type;
using value_type = typename NodeT::public_type;
@ -203,11 +253,11 @@ class FlatHashTable {
}
Iterator() = default;
Iterator(NodeIterator it, FlatHashTable *map) : it_(std::move(it)), end_(map->nodes_.end()) {
Iterator(NodeT *it, FlatHashTable *map) : it_(it), end_(map->nodes_ + map->bucket_count()) {
}
private:
NodeIterator it_;
NodeT *it_;
NodeT *end_;
};
@ -264,6 +314,7 @@ class FlatHashTable {
return;
}
reserve(nodes.size());
uint32 used_nodes = 0;
for (auto &new_node : nodes) {
CHECK(!new_node.empty());
auto bucket = calc_bucket(new_node.key());
@ -271,7 +322,7 @@ class FlatHashTable {
auto &node = nodes_[bucket];
if (node.empty()) {
node.copy_from(new_node);
used_nodes_++;
used_nodes++;
break;
}
if (EqT()(node.key(), new_node.key())) {
@ -280,28 +331,28 @@ class FlatHashTable {
next_bucket(bucket);
}
}
used_node_count() = used_nodes;
}
FlatHashTable(FlatHashTable &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) {
other.clear();
FlatHashTable(FlatHashTable &&other) noexcept : nodes_(other.nodes_) {
other.nodes_ = nullptr;
}
void operator=(FlatHashTable &&other) noexcept {
nodes_ = std::move(other.nodes_);
used_nodes_ = other.used_nodes_;
other.clear();
clear();
nodes_ = other.nodes_;
other.nodes_ = nullptr;
}
void swap(FlatHashTable &other) noexcept {
nodes_.swap(other.nodes_);
std::swap(used_nodes_, other.used_nodes_);
std::swap(nodes_, other.nodes_);
}
~FlatHashTable() = default;
size_t bucket_count() const {
return nodes_.size();
return unlikely(nodes_ == nullptr) ? 0 : get_bucket_count_mask() + 1;
}
Iterator find(const KeyT &key) {
if (empty() || is_key_empty(key)) {
if (unlikely(nodes_ == nullptr) || is_key_empty(key)) {
return end();
}
auto bucket = calc_bucket(key);
@ -322,25 +373,25 @@ class FlatHashTable {
}
size_t size() const {
return used_nodes_;
return unlikely(nodes_ == nullptr) ? 0 : get_used_node_count();
}
bool empty() const {
return size() == 0;
return unlikely(nodes_ == nullptr) || get_used_node_count() == 0;
}
Iterator begin() {
if (empty()) {
return end();
}
auto it = nodes_.begin();
auto it = nodes_;
while (it->empty()) {
++it;
}
return Iterator(it, this);
}
Iterator end() {
return Iterator(nodes_.end(), this);
return Iterator(nodes_ + bucket_count(), this);
}
ConstIterator begin() const {
@ -351,8 +402,10 @@ class FlatHashTable {
}
void reserve(size_t size) {
// reserve(0) must not allocate: an empty table stays a null pointer
if (size == 0) {
return;
}
// grow so that `size` insertions keep the load factor below 3/5
size_t want_size = normalize(size * 5 / 3 + 1);
// size_t want_size = size * 2;
if (want_size > bucket_count()) {
resize(want_size);
}
}
@ -370,7 +423,7 @@ class FlatHashTable {
}
if (node.empty()) {
node.emplace(std::move(key), std::forward<ArgsT>(args)...);
used_nodes_++;
used_node_count()++;
return {Iterator{&node, this}, true};
}
next_bucket(bucket);
@ -407,8 +460,10 @@ class FlatHashTable {
}
void clear() {
used_nodes_ = 0;
nodes_ = {};
if (nodes_ != nullptr) {
clear_inner(get_inner());
nodes_ = nullptr;
}
}
void erase(Iterator it) {
@ -420,19 +475,24 @@ class FlatHashTable {
template <class F>
void remove_if(F &&f) {
if (empty()) {
return;
}
auto it = begin().it_;
while (it != nodes_.end() && !it->empty()) {
auto end = nodes_ + bucket_count();
while (it != end && !it->empty()) {
++it;
}
auto first_empty = it;
for (; it != nodes_.end();) {
while (it != end) {
if (!it->empty() && f(it->get_public())) {
erase_node(it);
} else {
++it;
}
}
for (it = nodes_.begin(); it != first_empty;) {
for (it = nodes_; it != first_empty;) {
if (!it->empty() && f(it->get_public())) {
erase_node(it);
} else {
@ -443,11 +503,13 @@ class FlatHashTable {
}
private:
fixed_vector<NodeT> nodes_;
size_t used_nodes_{};
NodeT *nodes_ = nullptr;
void assign(const FlatHashTable &other) {
resize(other.size());
if (other.size() == 0) {
return;
}
resize(other.bucket_count());
for (const auto &new_node : other) {
auto bucket = calc_bucket(new_node.key());
while (true) {
@ -459,49 +521,45 @@ class FlatHashTable {
next_bucket(bucket);
}
}
used_nodes_ = other.size();
used_node_count() = other.get_used_node_count();
}
void try_grow() {
if (should_grow(used_nodes_ + 1, bucket_count())) {
grow();
if (unlikely(nodes_ == nullptr)) {
resize(8);
} else if (unlikely(get_used_node_count() * 5 > get_bucket_count_mask() * 3)) {
resize(2 * get_bucket_count_mask() + 2);
}
}
// True when inserting at this occupancy would push the load factor
// above 3/5 (60%), i.e. the table needs to grow.
static bool should_grow(size_t used_count, size_t bucket_count) {
const auto scaled_used = used_count * 5;
const auto scaled_buckets = bucket_count * 3;
return scaled_buckets < scaled_used;
}
void try_shrink() {
if (should_shrink(used_nodes_, bucket_count())) {
shrink();
DCHECK(nodes_ != nullptr);
if (unlikely(get_used_node_count() * 10 < get_bucket_count_mask() && get_bucket_count_mask() > 7)) {
resize(normalize((get_used_node_count() + 1) * 5 / 3 + 1));
}
}
// True when the table is less than 10% full and should release memory.
static bool should_shrink(size_t used_count, size_t bucket_count) {
const auto occupancy_times_ten = used_count * 10;
return occupancy_times_ten < bucket_count;
}
// Returns the smallest power of two strictly greater than (size | 7),
// so the result is always at least 8; the `| 7` forces a minimum of
// three significant bits before the leading-zero count.
static size_t normalize(size_t size) {
return static_cast<size_t>(1) << (64 - count_leading_zeroes64(size | 7));
}
// Reallocates down to a capacity sized for the current element count,
// using the same 5/3 load-factor headroom as reserve().
void shrink() {
size_t want_size = normalize((used_nodes_ + 1) * 5 / 3 + 1);
resize(want_size);
}
void grow() {
size_t want_size = normalize(2 * bucket_count() - !nodes_.empty());
resize(want_size);
return td::max(static_cast<size_t>(1) << (64 - count_leading_zeroes64(size)), static_cast<size_t>(8));
}
uint32 calc_bucket(const KeyT &key) const {
return randomize_hash(HashT()(key)) & static_cast<uint32>(bucket_count() - 1);
return randomize_hash(HashT()(key)) & get_bucket_count_mask();
}
// Advances `bucket` by one with power-of-two wrap-around (linear probing).
inline void next_bucket(uint32 &bucket) const {
bucket = (bucket + 1) & get_bucket_count_mask();
}
void resize(size_t new_size) {
fixed_vector<NodeT> old_nodes(new_size);
old_nodes.swap(nodes_);
auto old_nodes = nodes_;
auto old_size = size();
size_t old_bucket_count = bucket_count();
nodes_ = allocate_nodes(new_size);
used_node_count() = static_cast<uint32>(old_size);
for (auto &old_node : old_nodes) {
for (size_t i = 0; i < old_bucket_count; i++) {
auto &old_node = old_nodes[i];
if (old_node.empty()) {
continue;
}
@ -513,16 +571,12 @@ class FlatHashTable {
}
}
// Advances `bucket` by one, wrapping around the power-of-two bucket count
// (linear probing).
void next_bucket(uint32 &bucket) const {
bucket = (bucket + 1) & static_cast<uint32>(bucket_count() - 1);
}
void erase_node(NodeIterator it) {
size_t empty_i = it - nodes_.begin();
void erase_node(NodeT *it) {
size_t empty_i = it - nodes_;
auto empty_bucket = empty_i;
DCHECK(0 <= empty_i && empty_i < bucket_count());
nodes_[empty_bucket].clear();
used_nodes_--;
used_node_count()--;
for (size_t test_i = empty_i + 1;; test_i++) {
auto test_bucket = test_i;