//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once

#include "td/utils/common.h"
#include "td/utils/HashTableUtils.h"

#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <utility>

namespace td {

namespace detail {
uint32 normalize_flat_hash_table_size(uint32 size);
uint32 get_random_flat_hash_table_bucket(uint32 bucket_count_mask);
}  // namespace detail

template <class NodeT, class HashT, class EqT>
class FlatHashTable {
  static constexpr uint32 INVALID_BUCKET = 0xFFFFFFFF;

  void allocate_nodes(uint32 size) {
    DCHECK(size >= 8);
    DCHECK((size & (size - 1)) == 0);
    CHECK(size <= min(static_cast<uint32>(1) << 29, static_cast<uint32>(0x7FFFFFFF / sizeof(NodeT))));
    nodes_ = new NodeT[size];
    // used_node_count_ = 0;
    bucket_count_mask_ = size - 1;
    bucket_count_ = size;
    begin_bucket_ = INVALID_BUCKET;
  }

  static void clear_nodes(NodeT *nodes) {
    delete[] nodes;
  }

 public:
  using KeyT = typename NodeT::public_key_type;
  using key_type = typename NodeT::public_key_type;
  using value_type = typename NodeT::public_type;

  // TODO use EndSentinel for end() after switching to C++17
  // struct EndSentinel {};

  struct Iterator {
    using iterator_category = std::forward_iterator_tag;
    using difference_type = std::ptrdiff_t;
    using value_type = typename NodeT::public_type;
    using pointer = value_type *;
    using reference = value_type &;

    Iterator &operator++() {
      DCHECK(it_ != nullptr);
      do {
        if (unlikely(++it_ == end_)) {
          it_ = begin_;
        }
        if (unlikely(it_ == start_)) {
          it_ = nullptr;
          break;
        }
      } while (it_->empty());
      return *this;
    }
    reference operator*() {
      return it_->get_public();
    }
    const value_type &operator*() const {
      return it_->get_public();
    }
    pointer operator->() {
      return &it_->get_public();
    }
    const value_type *operator->() const {
      return &it_->get_public();
    }

    NodeT *get() {
      return it_;
    }

    bool operator==(const Iterator &other) const {
      DCHECK(other.it_ == nullptr);
      return it_ == nullptr;
    }
    bool operator!=(const Iterator &other) const {
      DCHECK(other.it_ == nullptr);
      return it_ != nullptr;
    }

    Iterator() = default;
    Iterator(NodeT *it, NodeT *begin, NodeT *end) : it_(it), begin_(begin), start_(it), end_(end) {
    }

   private:
    NodeT *it_ = nullptr;
    NodeT *begin_ = nullptr;
    NodeT *start_ = nullptr;
    NodeT *end_ = nullptr;
  };

  struct ConstIterator {
    using iterator_category = std::forward_iterator_tag;
    using difference_type = std::ptrdiff_t;
    using value_type = typename NodeT::public_type;
    using pointer = const value_type *;
    using reference = const value_type &;

    ConstIterator &operator++() {
      ++it_;
      return *this;
    }
    reference operator*() const {
      return *it_;
    }
    pointer operator->() const {
      return &*it_;
    }
    bool operator==(const ConstIterator &other) const {
      return it_ == other.it_;
    }
    bool operator!=(const ConstIterator &other) const {
      return it_ != other.it_;
    }

    ConstIterator() = default;
    ConstIterator(Iterator it) : it_(std::move(it)) {
    }

   private:
    Iterator it_;
  };
  using iterator = Iterator;
  using const_iterator = ConstIterator;

  struct NodePointer {
    value_type &operator*() {
      return it_->get_public();
    }
    const value_type &operator*() const {
      return it_->get_public();
    }
    value_type *operator->() {
      return &it_->get_public();
    }
    const value_type *operator->() const {
      return &it_->get_public();
    }

    NodeT *get() {
      return it_;
    }

    bool operator==(const Iterator &) const {
      return it_ == nullptr;
    }
    bool operator!=(const Iterator &) const {
      return it_ != nullptr;
    }

    explicit NodePointer(NodeT *it) : it_(it) {
    }

   private:
    NodeT *it_ = nullptr;
  };

  struct ConstNodePointer {
    const value_type &operator*() const {
      return it_->get_public();
    }
    const value_type *operator->() const {
      return &it_->get_public();
    }

    bool operator==(const ConstIterator &) const {
      return it_ == nullptr;
    }
    bool operator!=(const ConstIterator &) const {
      return it_ != nullptr;
    }

    const NodeT *get() const {
      return it_;
    }

    explicit ConstNodePointer(const NodeT *it) : it_(it) {
    }

   private:
    const NodeT *it_ = nullptr;
  };

  FlatHashTable() = default;
  FlatHashTable(const FlatHashTable &) = delete;
  FlatHashTable &operator=(const FlatHashTable &) = delete;

  FlatHashTable(std::initializer_list<NodeT> nodes) {
    if (nodes.size() == 0) {
      return;
    }
    reserve(nodes.size());
    uint32 used_nodes = 0;
    for (auto &new_node : nodes) {
      CHECK(!new_node.empty());
      auto bucket = calc_bucket(new_node.key());
      while (true) {
        auto &node = nodes_[bucket];
        if (node.empty()) {
          node.copy_from(new_node);
          used_nodes++;
          break;
        }
        if (EqT()(node.key(), new_node.key())) {
          break;
        }
        next_bucket(bucket);
      }
    }
    used_node_count_ = used_nodes;
  }

  template <class T>
  FlatHashTable(std::initializer_list<T> keys) {
    for (auto &key : keys) {
      emplace(KeyT(key));
    }
  }

  FlatHashTable(FlatHashTable &&other) noexcept
      : nodes_(other.nodes_)
      , used_node_count_(other.used_node_count_)
      , bucket_count_mask_(other.bucket_count_mask_)
      , bucket_count_(other.bucket_count_)
      , begin_bucket_(other.begin_bucket_) {
    other.drop();
  }
  void operator=(FlatHashTable &&other) noexcept {
    clear();
    nodes_ = other.nodes_;
    used_node_count_ = other.used_node_count_;
    bucket_count_mask_ = other.bucket_count_mask_;
    bucket_count_ = other.bucket_count_;
    begin_bucket_ = other.begin_bucket_;
    other.drop();
  }
  ~FlatHashTable() {
    clear_nodes(nodes_);
  }

  void swap(FlatHashTable &other) noexcept {
    std::swap(nodes_, other.nodes_);
    std::swap(used_node_count_, other.used_node_count_);
    std::swap(bucket_count_mask_, other.bucket_count_mask_);
    std::swap(bucket_count_, other.bucket_count_);
    std::swap(begin_bucket_, other.begin_bucket_);
  }

  uint32 bucket_count() const {
    return bucket_count_;
  }

  NodePointer find(const KeyT &key) {
    return NodePointer(find_impl(key));
  }

  ConstNodePointer find(const KeyT &key) const {
    return ConstNodePointer(const_cast<FlatHashTable *>(this)->find_impl(key));
  }

  size_t size() const {
    return used_node_count_;
  }

  bool empty() const {
    return used_node_count_ == 0;
  }

  Iterator begin() {
    return create_iterator(begin_impl());
  }
  Iterator end() {
    return Iterator();
  }
  ConstIterator begin() const {
    return ConstIterator(const_cast<FlatHashTable *>(this)->begin());
  }
  ConstIterator end() const {
    return ConstIterator();
  }

  void reserve(size_t size) {
    if (size == 0) {
      return;
    }
    CHECK(size <= (1u << 29));
    uint32 want_size = detail::normalize_flat_hash_table_size(static_cast<uint32>(size) * 5 / 3 + 1);
    if (want_size > bucket_count()) {
      resize(want_size);
    }
  }

  template <class... ArgsT>
  std::pair<NodePointer, bool> emplace(KeyT key, ArgsT &&...args) {
    CHECK(!is_hash_table_key_empty(key));
    if (unlikely(bucket_count_mask_ == 0)) {
      CHECK(used_node_count_ == 0);
      resize(8);
    }
    auto bucket = calc_bucket(key);
    while (true) {
      auto &node = nodes_[bucket];
      if (node.empty()) {
        if (unlikely(used_node_count_ * 5 >= bucket_count_mask_ * 3)) {
          resize(2 * bucket_count_);
          CHECK(used_node_count_ * 5 < bucket_count_mask_ * 3);
          return emplace(std::move(key), std::forward<ArgsT>(args)...);
        }
        invalidate_iterators();

        node.emplace(std::move(key), std::forward<ArgsT>(args)...);
        used_node_count_++;
        return {NodePointer(&node), true};
      }
      if (EqT()(node.key(), key)) {
        return {NodePointer(&node), false};
      }
      next_bucket(bucket);
    }
  }

  std::pair<NodePointer, bool> insert(KeyT key) {
    return emplace(std::move(key));
  }

  template <class ItT>
  void insert(ItT begin, ItT end) {
    for (; begin != end; ++begin) {
      emplace(*begin);
    }
  }

  template <class T = typename NodeT::second_type>
  T &operator[](const KeyT &key) {
    return emplace(key).first->second;
  }

  size_t erase(const KeyT &key) {
    auto *node = find_impl(key);
    if (node == nullptr) {
      return 0;
    }
    erase_node(node);
    try_shrink();
    return 1;
  }

  size_t count(const KeyT &key) const {
    return const_cast<FlatHashTable *>(this)->find_impl(key) != nullptr;
  }

  void clear() {
    if (nodes_ != nullptr) {
      clear_nodes(nodes_);
      drop();
    }
  }

  void erase(Iterator it) {
    DCHECK(it != end());
    erase_node(it.get());
    try_shrink();
  }

  void erase(NodePointer it) {
    DCHECK(it != end());
    erase_node(it.get());
    try_shrink();
  }

  template <class F>
  void remove_if(F &&f) {
    if (empty()) {
      return;
    }

    auto it = begin_impl();
    auto end = nodes_ + bucket_count();
    while (it != end && !it->empty()) {
      ++it;
    }
    if (it == end) {
      do {
        --it;
      } while (!it->empty());
    }
    auto first_empty = it;
    while (it != end) {
      if (!it->empty() && f(it->get_public())) {
        erase_node(it);
      } else {
        ++it;
      }
    }
    for (it = nodes_; it != first_empty;) {
      if (!it->empty() && f(it->get_public())) {
        erase_node(it);
      } else {
        ++it;
      }
    }
    try_shrink();
  }

 private:
  NodeT *nodes_ = nullptr;
  uint32 used_node_count_ = 0;
  uint32 bucket_count_mask_ = 0;
  uint32 bucket_count_ = 0;
  uint32 begin_bucket_ = 0;

  void drop() {
    nodes_ = nullptr;
    used_node_count_ = 0;
    bucket_count_mask_ = 0;
    bucket_count_ = 0;
    begin_bucket_ = 0;
  }

  NodeT *begin_impl() {
    if (empty()) {
      return nullptr;
    }
    if (begin_bucket_ == INVALID_BUCKET) {
      begin_bucket_ = detail::get_random_flat_hash_table_bucket(bucket_count_mask_);
      while (nodes_[begin_bucket_].empty()) {
        next_bucket(begin_bucket_);
      }
    }
    return nodes_ + begin_bucket_;
  }

  NodeT *find_impl(const KeyT &key) {
    if (unlikely(nodes_ == nullptr) || is_hash_table_key_empty(key)) {
      return nullptr;
    }
    auto bucket = calc_bucket(key);
    while (true) {
      auto &node = nodes_[bucket];
      if (node.empty()) {
        return nullptr;
      }
      if (EqT()(node.key(), key)) {
        return &node;
      }
      next_bucket(bucket);
    }
  }

  void try_shrink() {
    DCHECK(nodes_ != nullptr);
    if (unlikely(used_node_count_ * 10 < bucket_count_mask_ && bucket_count_mask_ > 7)) {
      resize(detail::normalize_flat_hash_table_size((used_node_count_ + 1) * 5 / 3 + 1));
    }
    invalidate_iterators();
  }

  uint32 calc_bucket(const KeyT &key) const {
    return HashT()(key) & bucket_count_mask_;
  }

  inline void next_bucket(uint32 &bucket) const {
    bucket = (bucket + 1) & bucket_count_mask_;
  }

  void resize(uint32 new_size) {
    if (unlikely(nodes_ == nullptr)) {
      allocate_nodes(new_size);
      used_node_count_ = 0;
      return;
    }

    auto old_nodes = nodes_;
    uint32 old_size = used_node_count_;
    uint32 old_bucket_count = bucket_count_;
    allocate_nodes(new_size);
    used_node_count_ = old_size;

    auto old_nodes_end = old_nodes + old_bucket_count;
    for (NodeT *old_node = old_nodes; old_node != old_nodes_end; ++old_node) {
      if (old_node->empty()) {
        continue;
      }
      auto bucket = calc_bucket(old_node->key());
      while (!nodes_[bucket].empty()) {
        next_bucket(bucket);
      }
      nodes_[bucket] = std::move(*old_node);
    }
    clear_nodes(old_nodes);
  }

  void erase_node(NodeT *it) {
    DCHECK(nodes_ <= it && static_cast<size_t>(it - nodes_) < bucket_count());
    it->clear();
    used_node_count_--;

    const auto bucket_count = bucket_count_;
    const auto *end = nodes_ + bucket_count;
    for (auto *test_node = it + 1; test_node != end; test_node++) {
      if (likely(test_node->empty())) {
        return;
      }

      auto want_node = nodes_ + calc_bucket(test_node->key());
      if (want_node <= it || want_node > test_node) {
        *it = std::move(*test_node);
        it = test_node;
      }
    }

    auto empty_i = static_cast<uint32>(it - nodes_);
    auto empty_bucket = empty_i;
    for (uint32 test_i = bucket_count;; test_i++) {
      auto test_bucket = test_i - bucket_count_;
      if (nodes_[test_bucket].empty()) {
        return;
      }

      auto want_i = calc_bucket(nodes_[test_bucket].key());
      if (want_i < empty_i) {
        want_i += bucket_count;
      }

      if (want_i <= empty_i || want_i > test_i) {
        nodes_[empty_bucket] = std::move(nodes_[test_bucket]);
        empty_i = test_i;
        empty_bucket = test_bucket;
      }
    }
  }

  Iterator create_iterator(NodeT *node) {
    return Iterator(node, nodes_, nodes_ + bucket_count());
  }

  void invalidate_iterators() {
    begin_bucket_ = INVALID_BUCKET;
  }
};

}  // namespace td