diff --git a/tdutils/td/utils/FlatHashTable.h b/tdutils/td/utils/FlatHashTable.h index ec90feb74..9d87bca98 100644 --- a/tdutils/td/utils/FlatHashTable.h +++ b/tdutils/td/utils/FlatHashTable.h @@ -136,6 +136,64 @@ class FlatHashTable { using iterator = Iterator; using const_iterator = ConstIterator; + struct NodePointer { + value_type &operator*() { + return it_->get_public(); + } + const value_type &operator*() const { + return it_->get_public(); + } + value_type *operator->() { + return &it_->get_public(); + } + const value_type *operator->() const { + return &it_->get_public(); + } + + NodeT *get() { + return it_; + } + + bool operator==(const EndSentinel &other) const { + return it_ == nullptr; + } + bool operator!=(const EndSentinel &other) const { + return it_ != nullptr; + } + + explicit NodePointer(NodeT *it) : it_(it) { + } + + private: + NodeT *it_ = nullptr; + }; + + struct ConstNodePointer { + const value_type &operator*() const { + return it_->get_public(); + } + const value_type *operator->() const { + return &it_->get_public(); + } + + bool operator==(const EndSentinel &other) const { + return it_ == nullptr; + } + bool operator!=(const EndSentinel &other) const { + return it_ != nullptr; + } + + const NodeT *get() const { + return it_; + } + + explicit ConstNodePointer(const NodeT *it) : it_(it) { + } + + private: + const NodeT *it_ = nullptr; + }; + FlatHashTable() = default; FlatHashTable(const FlatHashTable &other) { assign(other); @@ -205,12 +263,12 @@ class FlatHashTable { return bucket_count_; } - Iterator find(const KeyT &key) { - return create_iterator(find_impl(key)); + NodePointer find(const KeyT &key) { + return NodePointer(find_impl(key)); } - ConstIterator find(const KeyT &key) const { - return ConstIterator(const_cast(this)->find(key)); + ConstNodePointer find(const KeyT &key) const { + return ConstNodePointer(const_cast(this)->find_impl(key)); } size_t size() const { @@ -244,25 +302,25 @@ class FlatHashTable { } template - std::pair emplace(KeyT key, ArgsT &&...args) { + std::pair emplace(KeyT key, ArgsT &&...args) { try_grow(); CHECK(!is_hash_table_key_empty(key)); auto bucket = calc_bucket(key); while (true) { auto &node = nodes_[bucket]; if (EqT()(node.key(), key)) { - return {create_iterator(&node), false}; + return {NodePointer(&node), false}; } if (node.empty()) { node.emplace(std::move(key), std::forward(args)...); used_node_count_++; - return {create_iterator(&node), true}; + return {NodePointer(&node), true}; } next_bucket(bucket); } } - std::pair insert(KeyT key) { + std::pair insert(KeyT key) { return emplace(std::move(key)); } @@ -305,6 +363,12 @@ class FlatHashTable { try_shrink(); } + void erase(NodePointer it) { + DCHECK(it != end()); + erase_node(it.get()); + try_shrink(); + } + template void remove_if(F &&f) { if (empty()) { diff --git a/tdutils/td/utils/MapNode.h b/tdutils/td/utils/MapNode.h index 0c769e4dc..3304fcd56 100644 --- a/tdutils/td/utils/MapNode.h +++ b/tdutils/td/utils/MapNode.h @@ -34,6 +34,10 @@ struct MapNode { return *this; } + const MapNode &get_public() const { + return *this; + } + MapNode() { } MapNode(KeyT key, ValueT value) : first(std::move(key)) { diff --git a/tdutils/td/utils/SetNode.h b/tdutils/td/utils/SetNode.h index 05d22722c..3d1ca522a 100644 --- a/tdutils/td/utils/SetNode.h +++ b/tdutils/td/utils/SetNode.h @@ -14,7 +14,7 @@ namespace td { template struct SetNode { using public_key_type = KeyT; - using public_type = KeyT; + using public_type = const KeyT; using second_type = KeyT; // TODO: remove second_type? KeyT first{}; @@ -23,7 +23,7 @@ struct SetNode { return first; } - KeyT &get_public() { + const KeyT &get_public() { return first; }