FlatHashSet

This commit is contained in:
Arseny Smirnov 2022-02-11 17:40:16 +01:00
parent 15f27455c5
commit ecceb51881
5 changed files with 231 additions and 107 deletions

View File

@ -33,7 +33,7 @@ static bool use_memprof() {
#endif #endif
} }
static auto get_memory() { static td::uint64 get_memory() {
#if USE_MEMPROF #if USE_MEMPROF
if (use_memprof()) { if (use_memprof()) {
return get_used_memory_size(); return get_used_memory_size();

View File

@ -73,12 +73,19 @@ class fixed_vector {
size_t size_{0}; size_t size_{0};
}; };
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>> // TODO: move
class FlatHashMapImpl { template <class KeyT>
public: bool is_key_empty(const KeyT &key) {
struct Node { return key == KeyT();
}
template <class KeyT, class ValueT>
struct MapNode {
using first_type = KeyT; using first_type = KeyT;
using second_type = ValueT; using second_type = ValueT;
using key_type = KeyT;
using public_type = MapNode<KeyT, ValueT>;
using value_type = ValueT;
KeyT first{}; KeyT first{};
union { union {
ValueT second; ValueT second;
@ -89,22 +96,25 @@ class FlatHashMapImpl {
auto &value() { auto &value() {
return second; return second;
} }
auto &get_public() {
Node() { return *this;
} }
Node(KeyT key, ValueT value) : first(std::move(key)) {
MapNode() {
}
MapNode(KeyT key, ValueT value) : first(std::move(key)) {
new (&second) ValueT(std::move(value)); new (&second) ValueT(std::move(value));
DCHECK(!empty()); DCHECK(!empty());
} }
~Node() { ~MapNode() {
if (!empty()) { if (!empty()) {
second.~ValueT(); second.~ValueT();
} }
} }
Node(Node &&other) noexcept { MapNode(MapNode &&other) noexcept {
*this = std::move(other); *this = std::move(other);
} }
Node &operator=(Node &&other) noexcept { MapNode &operator=(MapNode &&other) noexcept {
DCHECK(empty()); DCHECK(empty());
DCHECK(!other.empty()); DCHECK(!other.empty());
first = std::move(other.first); first = std::move(other.first);
@ -132,22 +142,73 @@ class FlatHashMapImpl {
new (&second) ValueT(std::forward<ArgsT>(args)...); new (&second) ValueT(std::forward<ArgsT>(args)...);
DCHECK(!empty()); DCHECK(!empty());
} }
}; };
using Self = FlatHashMapImpl<KeyT, ValueT, HashT>;
template <class KeyT>
struct SetNode {
using first_type = KeyT;
using key_type = KeyT;
using public_type = KeyT;
using value_type = KeyT;
KeyT first{};
const auto &key() const {
return first;
}
const auto &value() const {
return first;
}
auto &get_public() {
return first;
}
SetNode() = default;
explicit SetNode(KeyT key) : first(std::move(key)) {
}
SetNode(SetNode &&other) noexcept {
*this = std::move(other);
}
SetNode &operator=(SetNode &&other) noexcept {
DCHECK(empty());
DCHECK(!other.empty());
first = std::move(other.first);
other.first = KeyT{};
return *this;
}
bool empty() const {
return is_key_empty(key());
}
void clear() {
first = KeyT();
CHECK(empty());
}
void emplace(KeyT key) {
first = std::move(key);
}
};
template <class NodeT, class HashT, class EqT>
class FlatHashTable {
public:
using Self = FlatHashTable<NodeT, HashT, EqT>;
using Node = NodeT;
using NodeIterator = typename fixed_vector<Node>::iterator; using NodeIterator = typename fixed_vector<Node>::iterator;
using ConstNodeIterator = typename fixed_vector<Node>::const_iterator; using ConstNodeIterator = typename fixed_vector<Node>::const_iterator;
using key_type = KeyT; using KeyT = typename Node::key_type;
using value_type = Node; using public_type = typename Node::public_type;
using value_type = typename Node::value_type;
struct Iterator { struct Iterator {
using iterator_category = std::bidirectional_iterator_tag; using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using value_type = Node; using value_type = public_type;
using pointer = Node *; using pointer = public_type *;
using reference = Node &; using reference = public_type &;
friend class FlatHashMapImpl; friend class FlatHashTable;
Iterator &operator++() { Iterator &operator++() {
do { do {
++it_; ++it_;
@ -160,10 +221,10 @@ class FlatHashMapImpl {
} while (it_->empty()); } while (it_->empty());
return *this; return *this;
} }
Node &operator*() { reference operator*() {
return *it_; return it_->get_public();
} }
Node *operator->() { pointer operator->() {
return &*it_; return &*it_;
} }
bool operator==(const Iterator &other) const { bool operator==(const Iterator &other) const {
@ -175,6 +236,7 @@ class FlatHashMapImpl {
return it_ != other.it_; return it_ != other.it_;
} }
Iterator() = default;
Iterator(NodeIterator it, Self *map) : it_(std::move(it)), map_(map) { Iterator(NodeIterator it, Self *map) : it_(std::move(it)), map_(map) {
} }
@ -186,11 +248,11 @@ class FlatHashMapImpl {
struct ConstIterator { struct ConstIterator {
using iterator_category = std::bidirectional_iterator_tag; using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using value_type = Node; using value_type = public_type;
using pointer = const Node *; using pointer = const value_type *;
using reference = const Node &; using reference = const value_type &;
friend class FlatHashMapImpl; friend class FlatHashTable;
ConstIterator &operator++() { ConstIterator &operator++() {
++it_; ++it_;
return *this; return *this;
@ -199,10 +261,10 @@ class FlatHashMapImpl {
--it_; --it_;
return *this; return *this;
} }
const Node &operator*() { reference operator*() {
return *it_; return *it_;
} }
const Node *operator->() { pointer operator->() {
return &*it_; return &*it_;
} }
bool operator==(const ConstIterator &other) const { bool operator==(const ConstIterator &other) const {
@ -212,25 +274,28 @@ class FlatHashMapImpl {
return it_ != other.it_; return it_ != other.it_;
} }
explicit ConstIterator(Iterator it) : it_(std::move(it)) { ConstIterator() = default;
ConstIterator(Iterator it) : it_(std::move(it)) {
} }
private: private:
Iterator it_; Iterator it_;
}; };
using iterator = Iterator;
using const_iterator = ConstIterator;
FlatHashMapImpl() = default; FlatHashTable() = default;
FlatHashMapImpl(const FlatHashMapImpl &other) : FlatHashMapImpl(other.begin(), other.end()) { FlatHashTable(const FlatHashTable &other) : FlatHashTable(other.begin(), other.end()) {
} }
FlatHashMapImpl &operator=(const FlatHashMapImpl &other) { FlatHashTable &operator=(const FlatHashTable &other) {
assign(other.begin(), other.end()); assign(other.begin(), other.end());
return *this; return *this;
} }
FlatHashMapImpl(std::initializer_list<Node> nodes) { FlatHashTable(std::initializer_list<Node> nodes) {
reserve(nodes.size()); reserve(nodes.size());
for (auto &node : nodes) { for (auto &node : nodes) {
CHECK(!is_key_empty(node.first)); CHECK(!node.empty());
auto bucket = calc_bucket(node.first); auto bucket = calc_bucket(node.first);
while (true) { while (true) {
if (nodes_[bucket].key() == node.first) { if (nodes_[bucket].key() == node.first) {
@ -247,29 +312,38 @@ class FlatHashMapImpl {
} }
} }
FlatHashMapImpl(FlatHashMapImpl &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) { FlatHashTable(FlatHashTable &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) {
other.used_nodes_ = 0; other.used_nodes_ = 0;
} }
FlatHashMapImpl &operator=(FlatHashMapImpl &&other) noexcept { FlatHashTable &operator=(FlatHashTable &&other) noexcept {
nodes_ = std::move(other.nodes_); nodes_ = std::move(other.nodes_);
used_nodes_ = other.used_nodes_; used_nodes_ = other.used_nodes_;
other.used_nodes_ = 0; other.used_nodes_ = 0;
return *this; return *this;
} }
~FlatHashMapImpl() = default; void swap(FlatHashTable &other) noexcept {
using std::swap;
swap(nodes_, other.nodes_);
swap(used_nodes_, other.used_nodes_);
}
~FlatHashTable() = default;
template <class ItT> template <class ItT>
FlatHashMapImpl(ItT begin, ItT end) { FlatHashTable(ItT begin, ItT end) {
assign(begin, end); assign(begin, end);
} }
size_t bucket_count() const {
return nodes_.size();
}
Iterator find(const KeyT &key) { Iterator find(const KeyT &key) {
if (empty() || is_key_empty(key)) { if (empty() || is_key_empty(key)) {
return end(); return end();
} }
auto bucket = calc_bucket(key); auto bucket = calc_bucket(key);
while (true) { while (true) {
if (nodes_[bucket].key() == key) { if (EqT()(nodes_[bucket].key(), key)) {
return Iterator{nodes_.begin() + bucket, this}; return Iterator{nodes_.begin() + bucket, this};
} }
if (nodes_[bucket].empty()) { if (nodes_[bucket].empty()) {
@ -326,7 +400,7 @@ class FlatHashMapImpl {
CHECK(!is_key_empty(key)); CHECK(!is_key_empty(key));
auto bucket = calc_bucket(key); auto bucket = calc_bucket(key);
while (true) { while (true) {
if (nodes_[bucket].key() == key) { if (EqT()(nodes_[bucket].key(), key)) {
return {Iterator{nodes_.begin() + bucket, this}, false}; return {Iterator{nodes_.begin() + bucket, this}, false};
} }
if (nodes_[bucket].empty()) { if (nodes_[bucket].empty()) {
@ -338,8 +412,19 @@ class FlatHashMapImpl {
} }
} }
ValueT &operator[](const KeyT &key) { std::pair<Iterator, bool> insert(KeyT key) {
return emplace(key).first->second; return emplace(std::move(key));
}
template <class ItT>
void insert(ItT begin, ItT end) {
for (; begin != end; ++begin) {
emplace(*begin);
}
}
value_type &operator[](const KeyT &key) {
return emplace(key).first->value();
} }
size_t erase(const KeyT &key) { size_t erase(const KeyT &key) {
@ -363,7 +448,7 @@ class FlatHashMapImpl {
void erase(Iterator it) { void erase(Iterator it) {
DCHECK(it != end()); DCHECK(it != end());
DCHECK(!is_key_empty(it->key())); DCHECK(!it.it_->empty());
erase_node(it.it_); erase_node(it.it_);
} }
@ -392,9 +477,6 @@ class FlatHashMapImpl {
} }
private: private:
static bool is_key_empty(const KeyT &key) {
return key == KeyT();
}
fixed_vector<Node> nodes_; fixed_vector<Node> nodes_;
size_t used_nodes_{}; size_t used_nodes_{};
@ -496,8 +578,18 @@ class FlatHashMapImpl {
} }
}; };
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>> template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashMap = FlatHashMapImpl<KeyT, ValueT, HashT>; using FlatHashMapImpl = FlatHashTable<MapNode<KeyT, ValueT>, HashT, EqT>;
//using FlatHashMap = std::unordered_map<KeyT, ValueT, HashT>; template <class KeyT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashSetImpl = FlatHashTable<SetNode<KeyT>, HashT, EqT>;
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashMap = FlatHashMapImpl<KeyT, ValueT, HashT, EqT>;
//using FlatHashMap = std::unordered_map<KeyT, ValueT, HashT, EqT>;
template <class KeyT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashSet = FlatHashSetImpl<KeyT, HashT, EqT>;
//using FlatHashSet = std::unordered_set<KeyT, HashT, EqT>;
} // namespace td } // namespace td

View File

@ -206,11 +206,11 @@ void table_remove_if(TableT &table, FuncT &&func) {
} }
} }
template <class KeyT, class ValueT, class HashT> template <class NodeT, class HashT, class EqT>
class FlatHashMapImpl; class FlatHashTable;
template <class KeyT, class ValueT, class HashT, class FuncT> template <class NodeT, class HashT, class EqT, class FuncT>
void table_remove_if(FlatHashMapImpl<KeyT, ValueT, HashT> &table, FuncT &&func) { void table_remove_if(FlatHashTable<NodeT, HashT, EqT> &table, FuncT &&func) {
table.remove_if(func); table.remove_if(func);
} }

View File

@ -24,6 +24,17 @@ static auto extract_kv(const T &reference) {
} }
TEST(FlatHashMap, basic) { TEST(FlatHashMap, basic) {
{
td::FlatHashSet<int> s;
s.insert(5);
for (auto x : s) {
}
int N = 100000;
for (int i = 0; i < 10000000; i++) {
s.insert((i + N/2)%N);
s.erase(i%N);
}
}
{ {
td::FlatHashMap<int, int> map; td::FlatHashMap<int, int> map;
map[1] = 2; map[1] = 2;

View File

@ -248,7 +248,7 @@ static void BM_emplace_same(benchmark::State &state) {
while (state.KeepRunningBatch(BATCH_SIZE)) { while (state.KeepRunningBatch(BATCH_SIZE)) {
for (size_t i = 0; i < BATCH_SIZE; i++) { for (size_t i = 0; i < BATCH_SIZE; i++) {
benchmark::DoNotOptimize(table.emplace(key, 43784932)); benchmark::DoNotOptimize(table.emplace(key + (i & 15) * 100 , 43784932));
} }
} }
} }
@ -375,6 +375,28 @@ static void BM_remove_if_slow(benchmark::State &state) {
constexpr size_t N = 5000; constexpr size_t N = 5000;
constexpr size_t BATCH_SIZE = 500000; constexpr size_t BATCH_SIZE = 500000;
TableT table;
td::Random::Xorshift128plus rnd(123);
for (size_t i = 0; i < N; i++) {
table.emplace(rnd() + 1, i);
}
auto first_key = table.begin()->first;
{
size_t cnt = 0;
td::table_remove_if(table, [&cnt](auto &) { cnt += 2; return cnt <= N; });
}
while (state.KeepRunningBatch(BATCH_SIZE)) {
for (size_t i = 0; i < BATCH_SIZE; i++) {
table.emplace(first_key, i);
table.erase(first_key);
}
}
}
template <typename TableT>
static void BM_remove_if_slow_old(benchmark::State &state) {
constexpr size_t N = 100000;
constexpr size_t BATCH_SIZE = 500000;
TableT table; TableT table;
while (state.KeepRunningBatch(BATCH_SIZE)) { while (state.KeepRunningBatch(BATCH_SIZE)) {
td::Random::Xorshift128plus rnd(123); td::Random::Xorshift128plus rnd(123);
@ -452,14 +474,13 @@ static void benchmark_create(td::Slice name) {
#define REGISTER_ERASE_ALL_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_erase_all_with_begin, HT<td::uint64, td::uint64>); #define REGISTER_ERASE_ALL_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_erase_all_with_begin, HT<td::uint64, td::uint64>);
#define REGISTER_REMOVE_IF_SLOW_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_remove_if_slow, HT<td::uint64, td::uint64>); #define REGISTER_REMOVE_IF_SLOW_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_remove_if_slow, HT<td::uint64, td::uint64>);
FOR_EACH_TABLE(REGISTER_CACHE3_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE2_BENCHMARK)
FOR_EACH_TABLE(REGISTER_REMOVE_IF_SLOW_BENCHMARK) FOR_EACH_TABLE(REGISTER_REMOVE_IF_SLOW_BENCHMARK)
FOR_EACH_TABLE(REGISTER_ERASE_ALL_BENCHMARK) FOR_EACH_TABLE(REGISTER_CACHE3_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE2_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK) FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK)
FOR_EACH_TABLE(REGISTER_REMOVE_IF_BENCHMARK) FOR_EACH_TABLE(REGISTER_REMOVE_IF_BENCHMARK)
FOR_EACH_TABLE(REGISTER_EMPLACE_BENCHMARK) FOR_EACH_TABLE(REGISTER_EMPLACE_BENCHMARK)
FOR_EACH_TABLE(REGISTER_ERASE_ALL_BENCHMARK)
FOR_EACH_TABLE(REGISTER_GET_BENCHMARK) FOR_EACH_TABLE(REGISTER_GET_BENCHMARK)
FOR_EACH_TABLE(REGISTER_FIND_BENCHMARK) FOR_EACH_TABLE(REGISTER_FIND_BENCHMARK)