572 lines
14 KiB
C++
572 lines
14 KiB
C++
//
|
|
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
|
|
//
|
|
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
|
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
|
//
|
|
#pragma once
|
|
|
|
#include "td/utils/bits.h"
|
|
#include "td/utils/common.h"
|
|
#include "td/utils/FlatHashMapLinear.h"
|
|
|
|
#include <cstddef>
|
|
#include <functional>
|
|
#include <initializer_list>
|
|
#include <iterator>
|
|
#include <limits>
|
|
#include <utility>
|
|
|
|
#if defined(__SSE2__) || (TD_MSVC && (defined(_M_X64) || (defined(_M_IX86) && _M_IX86_FP >= 2)))
|
|
#define TD_SSE2 1
|
|
#endif
|
|
|
|
#ifdef __aarch64__
|
|
#include <arm_neon.h>
|
|
#endif
|
|
|
|
#if TD_SSE2
|
|
#include <emmintrin.h>
|
|
#endif
|
|
|
|
namespace td {
|
|
template <int shift>
|
|
struct MaskIterator {
|
|
uint64 mask;
|
|
explicit operator bool() const {
|
|
return mask != 0;
|
|
}
|
|
int pos() const {
|
|
return count_trailing_zeroes64(mask) / shift;
|
|
}
|
|
void next() {
|
|
mask &= mask - 1;
|
|
}
|
|
|
|
// For foreach
|
|
bool operator!=(MaskIterator &other) const {
|
|
return mask != other.mask;
|
|
}
|
|
auto operator*() const {
|
|
return pos();
|
|
}
|
|
void operator++() {
|
|
next();
|
|
}
|
|
auto begin() {
|
|
return *this;
|
|
}
|
|
auto end() {
|
|
return MaskIterator{0u};
|
|
}
|
|
};
|
|
|
|
struct MaskPortable {
|
|
static MaskIterator<1> equal_mask(uint8 *bytes, uint8 needle) {
|
|
uint64 res = 0;
|
|
for (int i = 0; i < 16; i++) {
|
|
res |= (bytes[i] == needle) << i;
|
|
}
|
|
return {res & ((1u << 14) - 1)};
|
|
}
|
|
};
|
|
|
|
#ifdef __aarch64__
|
|
struct MaskNeonFolly {
|
|
static MaskIterator<4> equal_mask(uint8 *bytes, uint8 needle) {
|
|
uint8x16_t input_mask = vld1q_u8(bytes);
|
|
auto needle_mask = vdupq_n_u8(needle);
|
|
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
|
// get info from every byte into the bottom half of every uint16
|
|
// by shifting right 4, then round to get it into a 64-bit vector
|
|
uint8x8_t shifted_eq_mask = vshrn_n_u16(vreinterpretq_u16_u8(eq_mask), 4);
|
|
uint64 mask = vget_lane_u64(vreinterpret_u64_u8(shifted_eq_mask), 0);
|
|
return {mask & 0x11111111111111};
|
|
}
|
|
};
|
|
|
|
struct MaskNeon {
|
|
static MaskIterator<1> equal_mask(uint8 *bytes, uint8 needle) {
|
|
uint8x16_t input_mask = vld1q_u8(bytes);
|
|
auto needle_mask = vdupq_n_u8(needle);
|
|
auto eq_mask = vceqq_u8(input_mask, needle_mask);
|
|
uint16x8_t MASK = vdupq_n_u16(0x180);
|
|
uint16x8_t a_masked = vandq_u16(vreinterpretq_u16_u8(eq_mask), MASK);
|
|
const int16 __attribute__((aligned(16))) SHIFT_ARR[8] = {-7, -5, -3, -1, 1, 3, 5, 7};
|
|
int16x8_t SHIFT = vld1q_s16(SHIFT_ARR);
|
|
uint16x8_t a_shifted = vshlq_u16(a_masked, SHIFT);
|
|
return {vaddvq_u16(a_shifted) & ((1u << 14) - 1)};
|
|
}
|
|
};
|
|
#elif TD_SSE2
|
|
struct MaskSse2 {
|
|
static MaskIterator<1> equal_mask(uint8 *bytes, uint8 needle) {
|
|
auto input_mask = _mm_loadu_si128(reinterpret_cast<const __m128i *>(bytes));
|
|
auto needle_mask = _mm_set1_epi8(needle);
|
|
auto match_mask = _mm_cmpeq_epi8(needle_mask, input_mask);
|
|
return {static_cast<uint32>(_mm_movemask_epi8(match_mask)) & ((1u << 14) - 1)};
|
|
}
|
|
};
|
|
#endif
|
|
|
|
#ifdef __aarch64__
|
|
using MaskHelper = MaskNeonFolly;
|
|
#elif TD_SSE2
|
|
using MaskHelper = MaskSse2;
|
|
#else
|
|
using MaskHelper = MaskPortable;
|
|
#endif
|
|
|
|
template <class NodeT, class HashT, class EqT>
|
|
class FlatHashTableChunks {
|
|
public:
|
|
using Self = FlatHashTableChunks<NodeT, HashT, EqT>;
|
|
using Node = NodeT;
|
|
using NodeIterator = typename fixed_vector<Node>::iterator;
|
|
using ConstNodeIterator = typename fixed_vector<Node>::const_iterator;
|
|
|
|
using KeyT = typename Node::public_key_type;
|
|
using key_type = typename Node::public_key_type;
|
|
using value_type = typename Node::public_type;
|
|
|
|
struct Iterator {
|
|
using iterator_category = std::bidirectional_iterator_tag;
|
|
using difference_type = std::ptrdiff_t;
|
|
using value_type = FlatHashTableChunks::value_type;
|
|
using pointer = value_type *;
|
|
using reference = value_type &;
|
|
|
|
friend class FlatHashTableChunks;
|
|
Iterator &operator++() {
|
|
do {
|
|
++it_;
|
|
} while (it_ != map_->nodes_.end() && it_->empty());
|
|
return *this;
|
|
}
|
|
Iterator &operator--() {
|
|
do {
|
|
--it_;
|
|
} while (it_->empty());
|
|
return *this;
|
|
}
|
|
reference operator*() {
|
|
return it_->get_public();
|
|
}
|
|
pointer operator->() {
|
|
return &*it_;
|
|
}
|
|
bool operator==(const Iterator &other) const {
|
|
DCHECK(map_ == other.map_);
|
|
return it_ == other.it_;
|
|
}
|
|
bool operator!=(const Iterator &other) const {
|
|
DCHECK(map_ == other.map_);
|
|
return it_ != other.it_;
|
|
}
|
|
|
|
Iterator() = default;
|
|
Iterator(NodeIterator it, Self *map) : it_(std::move(it)), map_(map) {
|
|
}
|
|
|
|
private:
|
|
NodeIterator it_;
|
|
Self *map_;
|
|
};
|
|
|
|
struct ConstIterator {
|
|
using iterator_category = std::bidirectional_iterator_tag;
|
|
using difference_type = std::ptrdiff_t;
|
|
using value_type = FlatHashTableChunks::value_type;
|
|
using pointer = const value_type *;
|
|
using reference = const value_type &;
|
|
|
|
friend class FlatHashTableChunks;
|
|
ConstIterator &operator++() {
|
|
++it_;
|
|
return *this;
|
|
}
|
|
ConstIterator &operator--() {
|
|
--it_;
|
|
return *this;
|
|
}
|
|
reference operator*() {
|
|
return *it_;
|
|
}
|
|
pointer operator->() {
|
|
return &*it_;
|
|
}
|
|
bool operator==(const ConstIterator &other) const {
|
|
return it_ == other.it_;
|
|
}
|
|
bool operator!=(const ConstIterator &other) const {
|
|
return it_ != other.it_;
|
|
}
|
|
|
|
ConstIterator() = default;
|
|
ConstIterator(Iterator it) : it_(std::move(it)) {
|
|
}
|
|
|
|
private:
|
|
Iterator it_;
|
|
};
|
|
using iterator = Iterator;
|
|
using const_iterator = ConstIterator;
|
|
|
|
FlatHashTableChunks() = default;
|
|
FlatHashTableChunks(const FlatHashTableChunks &other) {
|
|
assign(other);
|
|
}
|
|
FlatHashTableChunks &operator=(const FlatHashTableChunks &other) {
|
|
clear();
|
|
assign(other);
|
|
return *this;
|
|
}
|
|
|
|
FlatHashTableChunks(std::initializer_list<Node> nodes) {
|
|
reserve(nodes.size());
|
|
for (auto &new_node : nodes) {
|
|
CHECK(!new_node.empty());
|
|
if (count(new_node.key()) > 0) {
|
|
continue;
|
|
}
|
|
Node node;
|
|
node.copy_from(new_node);
|
|
emplace_node(std::move(node));
|
|
}
|
|
}
|
|
|
|
FlatHashTableChunks(FlatHashTableChunks &&other) noexcept {
|
|
swap(other);
|
|
}
|
|
FlatHashTableChunks &operator=(FlatHashTableChunks &&other) noexcept {
|
|
swap(other);
|
|
return *this;
|
|
}
|
|
void swap(FlatHashTableChunks &other) noexcept {
|
|
nodes_.swap(other.nodes_);
|
|
chunks_.swap(other.chunks_);
|
|
std::swap(used_nodes_, other.used_nodes_);
|
|
}
|
|
~FlatHashTableChunks() = default;
|
|
|
|
size_t bucket_count() const {
|
|
return nodes_.size();
|
|
}
|
|
|
|
Iterator find(const KeyT &key) {
|
|
if (empty() || is_key_empty(key)) {
|
|
return end();
|
|
}
|
|
const auto hash = calc_hash(key);
|
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
|
while (true) {
|
|
auto chunk_i = chunk_it.pos();
|
|
auto chunk_begin = nodes_.begin() + chunk_i * Chunk::CHUNK_SIZE;
|
|
//__builtin_prefetch(chunk_begin);
|
|
auto &chunk = chunks_[chunk_i];
|
|
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, hash.small_hash);
|
|
for (auto pos : mask_it) {
|
|
auto it = chunk_begin + pos;
|
|
if (likely(EqT()(it->key(), key))) {
|
|
return Iterator{it, this};
|
|
}
|
|
}
|
|
if (chunk.skipped_cnt == 0) {
|
|
break;
|
|
}
|
|
chunk_it.next();
|
|
}
|
|
return end();
|
|
}
|
|
|
|
ConstIterator find(const KeyT &key) const {
|
|
return ConstIterator(const_cast<Self *>(this)->find(key));
|
|
}
|
|
|
|
size_t size() const {
|
|
return used_nodes_;
|
|
}
|
|
|
|
bool empty() const {
|
|
return size() == 0;
|
|
}
|
|
|
|
Iterator begin() {
|
|
if (empty()) {
|
|
return end();
|
|
}
|
|
auto it = nodes_.begin();
|
|
while (it->empty()) {
|
|
++it;
|
|
}
|
|
return Iterator(it, this);
|
|
}
|
|
Iterator end() {
|
|
return Iterator(nodes_.end(), this);
|
|
}
|
|
|
|
ConstIterator begin() const {
|
|
return ConstIterator(const_cast<Self *>(this)->begin());
|
|
}
|
|
ConstIterator end() const {
|
|
return ConstIterator(const_cast<Self *>(this)->end());
|
|
}
|
|
|
|
void reserve(size_t size) {
|
|
//size_t want_size = normalize(size * 5 / 3 + 1);
|
|
size_t want_size = normalize(size * 14 / 12 + 1);
|
|
// size_t want_size = size * 2;
|
|
if (want_size > nodes_.size()) {
|
|
resize(want_size);
|
|
}
|
|
}
|
|
|
|
template <class... ArgsT>
|
|
std::pair<Iterator, bool> emplace(KeyT key, ArgsT &&...args) {
|
|
CHECK(!is_key_empty(key));
|
|
auto it = find(key);
|
|
if (it != end()) {
|
|
return {it, false};
|
|
}
|
|
try_grow();
|
|
|
|
auto hash = calc_hash(key);
|
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
|
while (true) {
|
|
auto chunk_i = chunk_it.pos();
|
|
auto &chunk = chunks_[chunk_i];
|
|
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
|
|
if (mask_it) {
|
|
auto shift = mask_it.pos();
|
|
DCHECK(chunk.ctrl[shift] == 0);
|
|
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
|
DCHECK(node_it->empty());
|
|
node_it->emplace(std::move(key), std::forward<ArgsT>(args)...);
|
|
DCHECK(!node_it->empty());
|
|
chunk.ctrl[shift] = hash.small_hash;
|
|
used_nodes_++;
|
|
return {{node_it, this}, true};
|
|
}
|
|
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16>::max());
|
|
chunk.skipped_cnt++;
|
|
chunk_it.next();
|
|
}
|
|
}
|
|
|
|
std::pair<Iterator, bool> insert(KeyT key) {
|
|
return emplace(std::move(key));
|
|
}
|
|
|
|
template <class ItT>
|
|
void insert(ItT begin, ItT end) {
|
|
for (; begin != end; ++begin) {
|
|
emplace(*begin);
|
|
}
|
|
}
|
|
|
|
template <class T = typename Node::second_type>
|
|
T &operator[](const KeyT &key) {
|
|
return emplace(key).first->second;
|
|
}
|
|
|
|
size_t erase(const KeyT &key) {
|
|
auto it = find(key);
|
|
if (it == end()) {
|
|
return 0;
|
|
}
|
|
erase(it);
|
|
try_shrink();
|
|
return 1;
|
|
}
|
|
|
|
size_t count(const KeyT &key) const {
|
|
return find(key) != end();
|
|
}
|
|
|
|
void clear() {
|
|
used_nodes_ = 0;
|
|
nodes_ = {};
|
|
chunks_ = {};
|
|
}
|
|
|
|
void erase(Iterator it) {
|
|
DCHECK(it != end());
|
|
DCHECK(!it.it_->empty());
|
|
erase_node(it.it_);
|
|
}
|
|
|
|
template <class F>
|
|
void remove_if(F &&f) {
|
|
for (auto it = nodes_.begin(), end = nodes_.end(); it != end; ++it) {
|
|
if (!it->empty() && f(it->get_public())) {
|
|
erase_node(it);
|
|
}
|
|
}
|
|
try_shrink();
|
|
}
|
|
|
|
private:
|
|
struct Chunk {
|
|
static constexpr int CHUNK_SIZE = 14;
|
|
static constexpr int MASK = (1 << CHUNK_SIZE) - 1;
|
|
// 0x0 - empty
|
|
uint8 ctrl[CHUNK_SIZE] = {};
|
|
uint16 skipped_cnt{0};
|
|
};
|
|
fixed_vector<Node> nodes_;
|
|
fixed_vector<Chunk> chunks_;
|
|
size_t used_nodes_{};
|
|
|
|
void assign(const FlatHashTableChunks &other) {
|
|
reserve(other.size());
|
|
for (const auto &new_node : other) {
|
|
Node node;
|
|
node.copy_from(new_node);
|
|
emplace_node(std::move(node));
|
|
}
|
|
}
|
|
|
|
void try_grow() {
|
|
if (should_grow(used_nodes_ + 1, nodes_.size())) {
|
|
grow();
|
|
}
|
|
}
|
|
static bool should_grow(size_t used_count, size_t bucket_count) {
|
|
return used_count * 14 > bucket_count * 12;
|
|
}
|
|
void try_shrink() {
|
|
if (should_shrink(used_nodes_, nodes_.size())) {
|
|
shrink();
|
|
}
|
|
}
|
|
static bool should_shrink(size_t used_count, size_t bucket_count) {
|
|
return used_count * 10 < bucket_count;
|
|
}
|
|
|
|
static size_t normalize(size_t size) {
|
|
auto x = (size / Chunk::CHUNK_SIZE) | 1;
|
|
auto y = static_cast<size_t>(1) << (64 - count_leading_zeroes64(x));
|
|
return y * Chunk::CHUNK_SIZE;
|
|
}
|
|
|
|
void shrink() {
|
|
size_t want_size = normalize((used_nodes_ + 1) * 5 / 3 + 1);
|
|
resize(want_size);
|
|
}
|
|
|
|
void grow() {
|
|
size_t want_size = normalize(2 * nodes_.size() - !nodes_.empty());
|
|
resize(want_size);
|
|
}
|
|
|
|
struct HashInfo {
|
|
size_t chunk_i;
|
|
uint8 small_hash;
|
|
};
|
|
struct ChunkIt {
|
|
size_t chunk_i;
|
|
size_t chunk_mask;
|
|
size_t shift{};
|
|
size_t pos() const {
|
|
return chunk_i;
|
|
}
|
|
void next() {
|
|
DCHECK((chunk_mask & (chunk_mask + 1)) == 0);
|
|
shift++;
|
|
chunk_i += shift;
|
|
chunk_i &= chunk_mask;
|
|
}
|
|
};
|
|
|
|
ChunkIt get_chunk_it(size_t chunk_i) {
|
|
return {chunk_i, chunks_.size() - 1};
|
|
}
|
|
|
|
HashInfo calc_hash(const KeyT &key) {
|
|
auto h = HashT()(key);
|
|
// TODO: will be problematic with current hash.
|
|
return {(h >> 8) % chunks_.size(), static_cast<uint8>(0x80 | h)};
|
|
}
|
|
|
|
void resize(size_t new_size) {
|
|
CHECK(new_size >= Chunk::CHUNK_SIZE);
|
|
fixed_vector<Node> old_nodes(new_size);
|
|
fixed_vector<Chunk> chunks(new_size / Chunk::CHUNK_SIZE);
|
|
old_nodes.swap(nodes_);
|
|
chunks_ = std::move(chunks);
|
|
used_nodes_ = 0;
|
|
|
|
for (auto &node : old_nodes) {
|
|
if (node.empty()) {
|
|
continue;
|
|
}
|
|
emplace_node(std::move(node));
|
|
}
|
|
}
|
|
|
|
void emplace_node(Node &&node) {
|
|
DCHECK(!node.empty());
|
|
auto hash = calc_hash(node.key());
|
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
|
while (true) {
|
|
auto chunk_i = chunk_it.pos();
|
|
auto &chunk = chunks_[chunk_i];
|
|
auto mask_it = MaskHelper::equal_mask(chunk.ctrl, 0);
|
|
if (mask_it) {
|
|
auto shift = mask_it.pos();
|
|
auto node_it = nodes_.begin() + shift + chunk_i * Chunk::CHUNK_SIZE;
|
|
DCHECK(node_it->empty());
|
|
*node_it = std::move(node);
|
|
DCHECK(chunk.ctrl[shift] == 0);
|
|
chunk.ctrl[shift] = hash.small_hash;
|
|
DCHECK(chunk.ctrl[shift] != 0);
|
|
used_nodes_++;
|
|
break;
|
|
}
|
|
CHECK(chunk.skipped_cnt != std::numeric_limits<uint16>::max());
|
|
chunk.skipped_cnt++;
|
|
chunk_it.next();
|
|
}
|
|
}
|
|
|
|
void next_bucket(size_t &bucket) const {
|
|
bucket++;
|
|
if (unlikely(bucket == nodes_.size())) {
|
|
bucket = 0;
|
|
}
|
|
}
|
|
|
|
void erase_node(NodeIterator it) {
|
|
DCHECK(!it->empty());
|
|
size_t empty_i = it - nodes_.begin();
|
|
DCHECK(0 <= empty_i && empty_i < nodes_.size());
|
|
auto empty_chunk_i = empty_i / Chunk::CHUNK_SIZE;
|
|
auto hash = calc_hash(it->key());
|
|
auto chunk_it = get_chunk_it(hash.chunk_i);
|
|
while (true) {
|
|
auto chunk_i = chunk_it.pos();
|
|
auto &chunk = chunks_[chunk_i];
|
|
if (chunk_i == empty_chunk_i) {
|
|
chunk.ctrl[empty_i - empty_chunk_i * Chunk::CHUNK_SIZE] = 0;
|
|
break;
|
|
}
|
|
chunk.skipped_cnt--;
|
|
chunk_it.next();
|
|
}
|
|
it->clear();
|
|
used_nodes_--;
|
|
}
|
|
};
|
|
|
|
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
|
|
using FlatHashMapChunks = FlatHashTableChunks<MapNode<KeyT, ValueT>, HashT, EqT>;
|
|
template <class KeyT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
|
|
using FlatHashSetChunks = FlatHashTableChunks<SetNode<KeyT>, HashT, EqT>;
|
|
|
|
template <class NodeT, class HashT, class EqT, class FuncT>
|
|
void table_remove_if(FlatHashTableChunks<NodeT, HashT, EqT> &table, FuncT &&func) {
|
|
table.remove_if(func);
|
|
}
|
|
|
|
} // namespace td
|