FlatHashSet

This commit is contained in:
Arseny Smirnov 2022-02-11 17:40:16 +01:00
parent 15f27455c5
commit ecceb51881
5 changed files with 231 additions and 107 deletions

View File

@ -33,7 +33,7 @@ static bool use_memprof() {
#endif
}
static auto get_memory() {
static td::uint64 get_memory() {
#if USE_MEMPROF
if (use_memprof()) {
return get_used_memory_size();

View File

@ -73,81 +73,142 @@ class fixed_vector {
size_t size_{0};
};
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>>
class FlatHashMapImpl {
public:
struct Node {
using first_type = KeyT;
using second_type = ValueT;
KeyT first{};
union {
ValueT second;
};
const auto &key() const {
return first;
}
auto &value() {
return second;
}
// TODO: move
// Tells whether `key` is the reserved "empty" key, i.e. whether it compares
// equal (via operator==) to a value-initialized KeyT.
template <class KeyT>
bool is_key_empty(const KeyT &key) {
  const bool matches_default = (key == KeyT());
  return matches_default;
}
Node() {
}
Node(KeyT key, ValueT value) : first(std::move(key)) {
new (&second) ValueT(std::move(value));
DCHECK(!empty());
}
~Node() {
if (!empty()) {
second.~ValueT();
}
}
Node(Node &&other) noexcept {
*this = std::move(other);
}
Node &operator=(Node &&other) noexcept {
DCHECK(empty());
DCHECK(!other.empty());
first = std::move(other.first);
other.first = KeyT{};
new (&second) ValueT(std::move(other.second));
other.second.~ValueT();
return *this;
}
bool empty() const {
return is_key_empty(key());
}
void clear() {
DCHECK(!empty());
first = KeyT();
second.~ValueT();
DCHECK(empty());
}
template <class... ArgsT>
void emplace(KeyT key, ArgsT &&...args) {
DCHECK(empty());
first = std::move(key);
new (&second) ValueT(std::forward<ArgsT>(args)...);
DCHECK(!empty());
}
// Slot type for the key/value flavour of the hash table.
//
// A slot always holds a key (`first`); the mapped value (`second`) only exists
// while the slot is non-empty, where "empty" is decided by is_key_empty(key()).
// The value's lifetime is managed by hand with placement new and explicit
// destructor calls, so the order of statements in the members below matters.
template <class KeyT, class ValueT>
struct MapNode {
using first_type = KeyT;
using second_type = ValueT;
using key_type = KeyT;
// What iterators expose to users: for a map that is the node itself.
using public_type = MapNode<KeyT, ValueT>;
using value_type = ValueT;
KeyT first{};
// Kept inside an anonymous union so that `second` is neither constructed nor
// destroyed automatically; this struct does both explicitly.
union {
ValueT second;
};
// NOTE(review): this alias names FlatHashMapImpl and HashT, neither of which
// is declared by this struct - it looks like residue from the surrounding
// change rather than part of MapNode; confirm before relying on it.
using Self = FlatHashMapImpl<KeyT, ValueT, HashT>;
// Read-only access to the stored key.
const auto &key() const {
return first;
}
// Mutable access to the mapped value; only meaningful while `second` has been
// constructed (i.e. the node is non-empty).
auto &value() {
return second;
}
// Returns the object handed out to users of the table (the node itself).
auto &get_public() {
return *this;
}
// Creates an empty slot: `first` takes its default member initializer and
// `second` is left unconstructed.
MapNode() {
}
// Creates a filled slot. `key` must not be the empty key (checked by DCHECK).
MapNode(KeyT key, ValueT value) : first(std::move(key)) {
new (&second) ValueT(std::move(value));
DCHECK(!empty());
}
// Destroys the value only if one was constructed, i.e. the slot is non-empty.
~MapNode() {
if (!empty()) {
second.~ValueT();
}
}
// Move construction is implemented through move assignment; the assignment's
// DCHECKs therefore apply here too (the source must be non-empty).
MapNode(MapNode &&other) noexcept {
*this = std::move(other);
}
// Transfers key and value from `other` into this slot.
// Preconditions (DCHECK): this slot is empty and `other` is not.
// Afterwards `other` holds a KeyT{} key and its value has been destroyed.
MapNode &operator=(MapNode &&other) noexcept {
DCHECK(empty());
DCHECK(!other.empty());
first = std::move(other.first);
other.first = KeyT{};
// `second` is not alive in an empty slot, so construct rather than assign.
new (&second) ValueT(std::move(other.second));
other.second.~ValueT();
return *this;
}
// True when the stored key is the empty key according to is_key_empty().
bool empty() const {
return is_key_empty(key());
}
// Empties a filled slot: resets the key to KeyT() and destroys the value.
void clear() {
DCHECK(!empty());
first = KeyT();
second.~ValueT();
DCHECK(empty());
}
// Fills an empty slot: stores `key` and constructs the value in place from
// `args`. The slot must be empty before and non-empty after (DCHECK).
template <class... ArgsT>
void emplace(KeyT key, ArgsT &&...args) {
DCHECK(empty());
first = std::move(key);
new (&second) ValueT(std::forward<ArgsT>(args)...);
DCHECK(!empty());
}
};
// Slot type for the set flavour of the hash table: it stores nothing but a
// key, which also serves as the slot's value and as the object exposed to
// users. A slot counts as empty when is_key_empty() accepts its key.
template <class KeyT>
struct SetNode {
  using first_type = KeyT;
  using key_type = KeyT;
  using public_type = KeyT;
  using value_type = KeyT;

  KeyT first{};

  SetNode() = default;

  explicit SetNode(KeyT key) : first(std::move(key)) {
  }

  // Move construction goes through move assignment, so the assignment's
  // DCHECKs (source must be non-empty) apply here as well.
  SetNode(SetNode &&rhs) noexcept {
    *this = std::move(rhs);
  }

  // Takes the key out of `rhs`, leaving `rhs` with a KeyT{} key.
  // Preconditions (DCHECK): this slot is empty and `rhs` is not.
  SetNode &operator=(SetNode &&rhs) noexcept {
    DCHECK(empty());
    DCHECK(!rhs.empty());
    first = std::move(rhs.first);
    rhs.first = KeyT{};
    return *this;
  }

  // Key accessor.
  const KeyT &key() const {
    return first;
  }

  // For a set the value is the key itself.
  const KeyT &value() const {
    return first;
  }

  // Object handed out to users of the table: the stored key.
  KeyT &get_public() {
    return first;
  }

  // True when the stored key is the empty key.
  bool empty() const {
    return is_key_empty(first);
  }

  // Resets the slot to the empty key.
  void clear() {
    first = KeyT();
    CHECK(empty());
  }

  // Stores `key` in the slot.
  void emplace(KeyT key) {
    first = std::move(key);
  }
};
template <class NodeT, class HashT, class EqT>
class FlatHashTable {
public:
using Self = FlatHashTable<NodeT, HashT, EqT>;
using Node = NodeT;
using NodeIterator = typename fixed_vector<Node>::iterator;
using ConstNodeIterator = typename fixed_vector<Node>::const_iterator;
using key_type = KeyT;
using value_type = Node;
using KeyT = typename Node::key_type;
using public_type = typename Node::public_type;
using value_type = typename Node::value_type;
struct Iterator {
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Node;
using pointer = Node *;
using reference = Node &;
using value_type = public_type;
using pointer = public_type *;
using reference = public_type &;
friend class FlatHashMapImpl;
friend class FlatHashTable;
Iterator &operator++() {
do {
++it_;
@ -160,21 +221,22 @@ class FlatHashMapImpl {
} while (it_->empty());
return *this;
}
Node &operator*() {
return *it_;
reference operator*() {
return it_->get_public();
}
Node *operator->() {
pointer operator->() {
return &*it_;
}
bool operator==(const Iterator &other) const {
DCHECK(map_ == other.map_);
DCHECK(map_ == other.map_);
return it_ == other.it_;
}
bool operator!=(const Iterator &other) const {
DCHECK(map_ == other.map_);
DCHECK(map_ == other.map_);
return it_ != other.it_;
}
Iterator() = default;
Iterator(NodeIterator it, Self *map) : it_(std::move(it)), map_(map) {
}
@ -186,11 +248,11 @@ class FlatHashMapImpl {
struct ConstIterator {
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Node;
using pointer = const Node *;
using reference = const Node &;
using value_type = public_type;
using pointer = const value_type *;
using reference = const value_type &;
friend class FlatHashMapImpl;
friend class FlatHashTable;
ConstIterator &operator++() {
++it_;
return *this;
@ -199,10 +261,10 @@ class FlatHashMapImpl {
--it_;
return *this;
}
const Node &operator*() {
reference operator*() {
return *it_;
}
const Node *operator->() {
pointer operator->() {
return &*it_;
}
bool operator==(const ConstIterator &other) const {
@ -212,25 +274,28 @@ class FlatHashMapImpl {
return it_ != other.it_;
}
explicit ConstIterator(Iterator it) : it_(std::move(it)) {
ConstIterator() = default;
ConstIterator(Iterator it) : it_(std::move(it)) {
}
private:
Iterator it_;
};
using iterator = Iterator;
using const_iterator = ConstIterator;
FlatHashMapImpl() = default;
FlatHashMapImpl(const FlatHashMapImpl &other) : FlatHashMapImpl(other.begin(), other.end()) {
FlatHashTable() = default;
FlatHashTable(const FlatHashTable &other) : FlatHashTable(other.begin(), other.end()) {
}
FlatHashMapImpl &operator=(const FlatHashMapImpl &other) {
FlatHashTable &operator=(const FlatHashTable &other) {
assign(other.begin(), other.end());
return *this;
}
FlatHashMapImpl(std::initializer_list<Node> nodes) {
FlatHashTable(std::initializer_list<Node> nodes) {
reserve(nodes.size());
for (auto &node : nodes) {
CHECK(!is_key_empty(node.first));
CHECK(!node.empty());
auto bucket = calc_bucket(node.first);
while (true) {
if (nodes_[bucket].key() == node.first) {
@ -247,29 +312,38 @@ class FlatHashMapImpl {
}
}
FlatHashMapImpl(FlatHashMapImpl &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) {
FlatHashTable(FlatHashTable &&other) noexcept : nodes_(std::move(other.nodes_)), used_nodes_(other.used_nodes_) {
other.used_nodes_ = 0;
}
FlatHashMapImpl &operator=(FlatHashMapImpl &&other) noexcept {
FlatHashTable &operator=(FlatHashTable &&other) noexcept {
nodes_ = std::move(other.nodes_);
used_nodes_ = other.used_nodes_;
other.used_nodes_ = 0;
return *this;
}
~FlatHashMapImpl() = default;
// Exchanges the contents of this table with `other`: both the node storage
// and the used-node counter are swapped.
void swap(FlatHashTable &other) noexcept {
  using std::swap;  // unqualified calls below may also find a swap via ADL
  swap(used_nodes_, other.used_nodes_);
  swap(nodes_, other.nodes_);
}
~FlatHashTable() = default;
template <class ItT>
FlatHashMapImpl(ItT begin, ItT end) {
FlatHashTable(ItT begin, ItT end) {
assign(begin, end);
}
// Returns the size of the node storage (nodes_.size()).
size_t bucket_count() const {
return nodes_.size();
}
Iterator find(const KeyT &key) {
if (empty() || is_key_empty(key)) {
return end();
}
auto bucket = calc_bucket(key);
while (true) {
if (nodes_[bucket].key() == key) {
if (EqT()(nodes_[bucket].key(), key)) {
return Iterator{nodes_.begin() + bucket, this};
}
if (nodes_[bucket].empty()) {
@ -326,7 +400,7 @@ class FlatHashMapImpl {
CHECK(!is_key_empty(key));
auto bucket = calc_bucket(key);
while (true) {
if (nodes_[bucket].key() == key) {
if (EqT()(nodes_[bucket].key(), key)) {
return {Iterator{nodes_.begin() + bucket, this}, false};
}
if (nodes_[bucket].empty()) {
@ -338,8 +412,19 @@ class FlatHashMapImpl {
}
}
ValueT &operator[](const KeyT &key) {
return emplace(key).first->second;
// Inserts `key` by handing it to emplace() and returns emplace()'s result:
// an iterator paired with a bool. In the visible part of emplace() the bool is
// false when an equal key is already stored.
std::pair<Iterator, bool> insert(KeyT key) {
return emplace(std::move(key));
}
// Range insert: calls emplace() on every element of [begin, end), in order.
template <class ItT>
void insert(ItT begin, ItT end) {
  while (begin != end) {
    emplace(*begin);
    ++begin;
  }
}
// Returns a reference to the value of the node that emplace(key) yields.
// emplace() returns the existing node when an equal key is already stored;
// presumably it creates a new node otherwise - confirm against emplace().
value_type &operator[](const KeyT &key) {
return emplace(key).first->value();
}
size_t erase(const KeyT &key) {
@ -363,7 +448,7 @@ class FlatHashMapImpl {
void erase(Iterator it) {
DCHECK(it != end());
DCHECK(!is_key_empty(it->key()));
DCHECK(!it.it_->empty());
erase_node(it.it_);
}
@ -392,9 +477,6 @@ class FlatHashMapImpl {
}
private:
static bool is_key_empty(const KeyT &key) {
return key == KeyT();
}
fixed_vector<Node> nodes_;
size_t used_nodes_{};
@ -496,8 +578,18 @@ class FlatHashMapImpl {
}
};
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>>
using FlatHashMap = FlatHashMapImpl<KeyT, ValueT, HashT>;
//using FlatHashMap = std::unordered_map<KeyT, ValueT, HashT>;
// Map implementation: FlatHashTable instantiated with key/value nodes.
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashMapImpl = FlatHashTable<MapNode<KeyT, ValueT>, HashT, EqT>;
// Set implementation: FlatHashTable instantiated with key-only nodes.
template <class KeyT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashSetImpl = FlatHashTable<SetNode<KeyT>, HashT, EqT>;
// Public map alias. The commented-out line below is an alternative definition
// in terms of std::unordered_map.
template <class KeyT, class ValueT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashMap = FlatHashMapImpl<KeyT, ValueT, HashT, EqT>;
//using FlatHashMap = std::unordered_map<KeyT, ValueT, HashT, EqT>;
// Public set alias. The commented-out line below is an alternative definition
// in terms of std::unordered_set.
template <class KeyT, class HashT = std::hash<KeyT>, class EqT = std::equal_to<KeyT>>
using FlatHashSet = FlatHashSetImpl<KeyT, HashT, EqT>;
//using FlatHashSet = std::unordered_set<KeyT, HashT, EqT>;
} // namespace td

View File

@ -206,11 +206,11 @@ void table_remove_if(TableT &table, FuncT &&func) {
}
}
template <class KeyT, class ValueT, class HashT>
class FlatHashMapImpl;
template <class NodeT, class HashT, class EqT>
class FlatHashTable;
template <class KeyT, class ValueT, class HashT, class FuncT>
void table_remove_if(FlatHashMapImpl<KeyT, ValueT, HashT> &table, FuncT &&func) {
// Overload of table_remove_if for FlatHashTable: delegates to the table's own
// remove_if member, passing `func` through as an lvalue.
template <class NodeT, class HashT, class EqT, class FuncT>
void table_remove_if(FlatHashTable<NodeT, HashT, EqT> &table, FuncT &&func) {
table.remove_if(func);
}

View File

@ -24,6 +24,17 @@ static auto extract_kv(const T &reference) {
}
TEST(FlatHashMap, basic) {
{
td::FlatHashSet<int> s;
s.insert(5);
for (auto x : s) {
}
int N = 100000;
for (int i = 0; i < 10000000; i++) {
s.insert((i + N/2)%N);
s.erase(i%N);
}
}
{
td::FlatHashMap<int, int> map;
map[1] = 2;

View File

@ -248,7 +248,7 @@ static void BM_emplace_same(benchmark::State &state) {
while (state.KeepRunningBatch(BATCH_SIZE)) {
for (size_t i = 0; i < BATCH_SIZE; i++) {
benchmark::DoNotOptimize(table.emplace(key, 43784932));
benchmark::DoNotOptimize(table.emplace(key + (i & 15) * 100 , 43784932));
}
}
}
@ -375,6 +375,28 @@ static void BM_remove_if_slow(benchmark::State &state) {
constexpr size_t N = 5000;
constexpr size_t BATCH_SIZE = 500000;
TableT table;
td::Random::Xorshift128plus rnd(123);
for (size_t i = 0; i < N; i++) {
table.emplace(rnd() + 1, i);
}
auto first_key = table.begin()->first;
{
size_t cnt = 0;
td::table_remove_if(table, [&cnt](auto &) { cnt += 2; return cnt <= N; });
}
while (state.KeepRunningBatch(BATCH_SIZE)) {
for (size_t i = 0; i < BATCH_SIZE; i++) {
table.emplace(first_key, i);
table.erase(first_key);
}
}
}
template <typename TableT>
static void BM_remove_if_slow_old(benchmark::State &state) {
constexpr size_t N = 100000;
constexpr size_t BATCH_SIZE = 500000;
TableT table;
while (state.KeepRunningBatch(BATCH_SIZE)) {
td::Random::Xorshift128plus rnd(123);
@ -452,14 +474,13 @@ static void benchmark_create(td::Slice name) {
#define REGISTER_ERASE_ALL_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_erase_all_with_begin, HT<td::uint64, td::uint64>);
#define REGISTER_REMOVE_IF_SLOW_BENCHMARK(HT) BENCHMARK_TEMPLATE(BM_remove_if_slow, HT<td::uint64, td::uint64>);
FOR_EACH_TABLE(REGISTER_CACHE3_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE2_BENCHMARK)
FOR_EACH_TABLE(REGISTER_REMOVE_IF_SLOW_BENCHMARK)
FOR_EACH_TABLE(REGISTER_ERASE_ALL_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE3_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE2_BENCHMARK)
FOR_EACH_TABLE(REGISTER_CACHE_BENCHMARK)
FOR_EACH_TABLE(REGISTER_REMOVE_IF_BENCHMARK)
FOR_EACH_TABLE(REGISTER_EMPLACE_BENCHMARK)
FOR_EACH_TABLE(REGISTER_ERASE_ALL_BENCHMARK)
FOR_EACH_TABLE(REGISTER_GET_BENCHMARK)
FOR_EACH_TABLE(REGISTER_FIND_BENCHMARK)