Return NodePointer instead of iterator in find/emplace.

This commit is contained in:
levlam 2022-03-12 15:27:14 +03:00
parent 62f463b421
commit daef14ade1
3 changed files with 78 additions and 10 deletions

View File

@ -136,6 +136,64 @@ class FlatHashTable {
using iterator = Iterator;
using const_iterator = ConstIterator;
struct NodePointer {
value_type &operator*() {
return it_->get_public();
}
const value_type &operator*() const {
return it_->get_public();
}
value_type *operator->() {
return &it_->get_public();
}
const value_type *operator->() const {
return &it_->get_public();
}
NodeT *get() {
return it_;
}
bool operator==(const EndSentinel &other) const {
return it_ == nullptr;
}
bool operator!=(const EndSentinel &other) const {
return it_ != nullptr;
}
explicit NodePointer(NodeT *it) : it_(it) {
}
private:
NodeT *it_ = nullptr;
};
struct ConstNodePointer {
const value_type &operator*() const {
return it_->get_public();
}
const value_type *operator->() const {
return &it_->get_public();
}
bool operator==(const EndSentinel &other) const {
return it_ == nullptr;
}
bool operator!=(const EndSentinel &other) const {
return it_ != nullptr;
}
const NodeT *get() const {
return it_;
}
explicit ConstNodePointer(const NodeT *it) : it_(it) {
}
private:
const NodeT *it_ = nullptr;
};
FlatHashTable() = default;
FlatHashTable(const FlatHashTable &other) {
assign(other);
@ -205,12 +263,12 @@ class FlatHashTable {
return bucket_count_;
}
Iterator find(const KeyT &key) {
return create_iterator(find_impl(key));
NodePointer find(const KeyT &key) {
return NodePointer(find_impl(key));
}
ConstIterator find(const KeyT &key) const {
return ConstIterator(const_cast<FlatHashTable *>(this)->find(key));
ConstNodePointer find(const KeyT &key) const {
return ConstNodePointer(const_cast<FlatHashTable *>(this)->find_impl(key));
}
size_t size() const {
@ -244,25 +302,25 @@ class FlatHashTable {
}
template <class... ArgsT>
std::pair<Iterator, bool> emplace(KeyT key, ArgsT &&...args) {
std::pair<NodePointer, bool> emplace(KeyT key, ArgsT &&...args) {
try_grow();
CHECK(!is_hash_table_key_empty(key));
auto bucket = calc_bucket(key);
while (true) {
auto &node = nodes_[bucket];
if (EqT()(node.key(), key)) {
return {create_iterator(&node), false};
return {NodePointer(&node), false};
}
if (node.empty()) {
node.emplace(std::move(key), std::forward<ArgsT>(args)...);
used_node_count_++;
return {create_iterator(&node), true};
return {NodePointer(&node), true};
}
next_bucket(bucket);
}
}
std::pair<Iterator, bool> insert(KeyT key) {
std::pair<NodePointer, bool> insert(KeyT key) {
return emplace(std::move(key));
}
@ -305,6 +363,12 @@ class FlatHashTable {
try_shrink();
}
void erase(NodePointer it) {
DCHECK(it != end());
erase_node(it.get());
try_shrink();
}
template <class F>
void remove_if(F &&f) {
if (empty()) {

View File

@ -34,6 +34,10 @@ struct MapNode {
return *this;
}
const MapNode &get_public() const {
return *this;
}
MapNode() {
}
MapNode(KeyT key, ValueT value) : first(std::move(key)) {

View File

@ -14,7 +14,7 @@ namespace td {
template <class KeyT>
struct SetNode {
using public_key_type = KeyT;
using public_type = KeyT;
using public_type = const KeyT;
using second_type = KeyT; // TODO: remove second_type?
KeyT first{};
@ -23,7 +23,7 @@ struct SetNode {
return first;
}
KeyT &get_public() {
const KeyT &get_public() {
return first;
}