// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #pragma once #include "td/utils/common.h" namespace td { struct HeapNode { bool in_heap() const { return pos_ != -1; } bool is_top() const { return pos_ == 0; } void remove() { pos_ = -1; } int32 pos_ = -1; }; template class KHeap { public: bool empty() const { return array_.empty(); } size_t size() const { return array_.size(); } KeyT top_key() const { return array_[0].key_; } KeyT get_key(const HeapNode *node) const { auto pos = static_cast(node->pos_); CHECK(pos < array_.size()); return array_[pos].key_; } const HeapNode *top() const { return array_[0].node_; } HeapNode *pop() { CHECK(!empty()); HeapNode *result = array_[0].node_; result->remove(); erase(static_cast(0)); return result; } void insert(KeyT key, HeapNode *node) { CHECK(!node->in_heap()); array_.push_back({key, node}); fix_up(array_.size() - 1); } void fix(KeyT key, HeapNode *node) { auto pos = static_cast(node->pos_); CHECK(pos < array_.size()); KeyT old_key = array_[pos].key_; array_[pos].key_ = key; if (key < old_key) { fix_up(pos); } else { fix_down(pos); } } void erase(HeapNode *node) { auto pos = static_cast(node->pos_); node->remove(); CHECK(pos < array_.size()); erase(pos); } template void for_each(F &&f) const { for (auto &it : array_) { f(it.key_, it.node_); } } template void for_each(F &&f) { for (auto &it : array_) { f(it.key_, it.node_); } } void check() const { for (size_t i = 0; i < array_.size(); i++) { for (size_t j = i * K + 1; j < i * K + 1 + K && j < array_.size(); j++) { CHECK(array_[i].key_ <= array_[j].key_); } } } private: struct Item { KeyT key_; HeapNode *node_; }; vector array_; void fix_up(size_t pos) { auto item = array_[pos]; while (pos) { auto parent_pos = (pos - 1) / K; auto parent_item = array_[parent_pos]; if (parent_item.key_ < item.key_) { break; } parent_item.node_->pos_ = static_cast(pos); array_[pos] = parent_item; pos = parent_pos; } item.node_->pos_ = static_cast(pos); array_[pos] = item; } void fix_down(size_t pos) { auto item = array_[pos]; while (true) { auto left_pos = pos * K + 1; auto right_pos = min(left_pos + K, array_.size()); auto next_pos = pos; KeyT next_key = item.key_; for (auto i = left_pos; i < right_pos; i++) { KeyT i_key = array_[i].key_; if (i_key < next_key) { next_key = i_key; next_pos = i; } } if (next_pos == pos) { break; } array_[pos] = array_[next_pos]; array_[pos].node_->pos_ = static_cast(pos); pos = next_pos; } item.node_->pos_ = static_cast(pos); array_[pos] = item; } void erase(size_t pos) { array_[pos] = array_.back(); array_.pop_back(); if (pos < array_.size()) { fix_down(pos); fix_up(pos); } if (array_.capacity() > 50 && array_.size() < array_.capacity() / 4) { array_.shrink_to_fit(); } } }; } // namespace td