Consistently use uint32 in FlatHashTable.

This commit is contained in:
levlam 2022-02-23 22:51:43 +03:00
parent fe06a1d4fc
commit 35cf57eed8

View File

@ -167,24 +167,24 @@ class FlatHashTable {
// Size in bytes of the FlatHashTableInner header that precedes the node
// array: two uint32 fields (bucket_count_mask_ and used_node_count_).
static constexpr size_t OFFSET = 2 * sizeof(uint32);

// Allocates a FlatHashTableInner with `size` default-constructed nodes and
// returns a pointer to the first node.
//
// Preconditions: `size` is a power of two, size >= 8 and size <= 2^29.
// Initializes bucket_count_mask_ to size - 1; used_node_count_ is left
// uninitialized — the caller is expected to set it (see the commented-out
// assignment below).
// NOTE(review): the std::malloc result is not null-checked — presumably the
// project aborts on OOM elsewhere; confirm.
static NodeT *allocate_nodes(uint32 size) {
  DCHECK(size >= 8);
  DCHECK((size & (size - 1)) == 0);
  CHECK(size <= (1 << 29));
  auto inner = static_cast<FlatHashTableInner *>(std::malloc(OFFSET + sizeof(NodeT) * size));
  NodeT *nodes = &inner->nodes_[0];
  for (uint32 i = 0; i < size; i++) {
    // Placement-new each node into the raw malloc'ed storage.
    new (nodes + i) NodeT();
  }
  // inner->used_node_count_ = 0;
  inner->bucket_count_mask_ = size - 1;
  return nodes;
}
// Destroys every node of `inner` (bucket_count_mask_ + 1 of them) and
// releases the malloc'ed storage. Counterpart of allocate_nodes.
static void clear_inner(FlatHashTableInner *inner) {
  auto size = inner->bucket_count_mask_ + 1;
  NodeT *nodes = &inner->nodes_[0];
  for (uint32 i = 0; i < size; i++) {
    // Explicit destructor call: the nodes were placement-new'ed, so they
    // must be destroyed manually before freeing the raw buffer.
    nodes[i].~NodeT();
  }
  std::free(inner);
}
@ -347,7 +347,7 @@ class FlatHashTable {
} }
~FlatHashTable() = default; ~FlatHashTable() = default;
size_t bucket_count() const { uint32 bucket_count() const {
return unlikely(nodes_ == nullptr) ? 0 : get_bucket_count_mask() + 1; return unlikely(nodes_ == nullptr) ? 0 : get_bucket_count_mask() + 1;
} }
@ -405,7 +405,8 @@ class FlatHashTable {
if (size == 0) { if (size == 0) {
return; return;
} }
size_t want_size = normalize(size * 5 / 3 + 1); CHECK(size <= (1u << 29));
uint32 want_size = normalize(static_cast<uint32>(size) * 5 / 3 + 1);
if (want_size > bucket_count()) { if (want_size > bucket_count()) {
resize(want_size); resize(want_size);
} }
@ -539,8 +540,8 @@ class FlatHashTable {
} }
} }
static size_t normalize(size_t size) { static uint32 normalize(uint32 size) {
return td::max(static_cast<size_t>(1) << (64 - count_leading_zeroes64(size)), static_cast<size_t>(8)); return td::max(static_cast<uint32>(1) << (32 - count_leading_zeroes32(size)), static_cast<uint32>(8));
} }
uint32 calc_bucket(const KeyT &key) const { uint32 calc_bucket(const KeyT &key) const {
@ -551,14 +552,20 @@ class FlatHashTable {
bucket = (bucket + 1) & get_bucket_count_mask(); bucket = (bucket + 1) & get_bucket_count_mask();
} }
void resize(size_t new_size) { void resize(uint32 new_size) {
auto old_nodes = nodes_; if (unlikely(nodes_ == nullptr)) {
auto old_size = size(); nodes_ = allocate_nodes(new_size);
size_t old_bucket_count = bucket_count(); return;
nodes_ = allocate_nodes(new_size); }
used_node_count() = static_cast<uint32>(old_size);
for (size_t i = 0; i < old_bucket_count; i++) { auto old_nodes = nodes_;
uint32 old_size = get_used_node_count();
uint32 old_bucket_count = get_bucket_count_mask() + 1;
;
nodes_ = allocate_nodes(new_size);
used_node_count() = old_size;
for (uint32 i = 0; i < old_bucket_count; i++) {
auto &old_node = old_nodes[i]; auto &old_node = old_nodes[i];
if (old_node.empty()) { if (old_node.empty()) {
continue; continue;
@ -572,13 +579,13 @@ class FlatHashTable {
} }
void erase_node(NodeT *it) { void erase_node(NodeT *it) {
size_t empty_i = it - nodes_; DCHECK(nodes_ <= it && it - nodes_ < bucket_count());
uint32 empty_i = static_cast<uint32>(it - nodes_);
auto empty_bucket = empty_i; auto empty_bucket = empty_i;
DCHECK(0 <= empty_i && empty_i < bucket_count());
nodes_[empty_bucket].clear(); nodes_[empty_bucket].clear();
used_node_count()--; used_node_count()--;
for (size_t test_i = empty_i + 1;; test_i++) { for (uint32 test_i = empty_i + 1;; test_i++) {
auto test_bucket = test_i; auto test_bucket = test_i;
if (test_bucket >= bucket_count()) { if (test_bucket >= bucket_count()) {
test_bucket -= bucket_count(); test_bucket -= bucket_count();