//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once

#include "td/utils/algorithm.h"
#include "td/utils/common.h"
#include "td/utils/Container.h"
#include "td/utils/FlatHashSet.h"
#include "td/utils/List.h"
#include "td/utils/logging.h"
#include "td/utils/optional.h"
#include "td/utils/Span.h"
#include "td/utils/StringBuilder.h"
#include "td/utils/VectorQueue.h"

#include <functional>
#include <unordered_map>

namespace td {

struct ChainSchedulerBase {
  struct TaskWithParents {
    uint64 task_id{};
    vector<uint64> parents;
  };
};

template <class ExtraT = Unit>
class ChainScheduler final : public ChainSchedulerBase {
 public:
  using TaskId = uint64;
  using ChainId = uint64;

  TaskId create_task(Span<ChainId> chains, ExtraT extra = {});

  ExtraT *get_task_extra(TaskId task_id);

  optional<TaskWithParents> start_next_task();

  void pause_task(TaskId task_id);

  void finish_task(TaskId task_id);

  void reset_task(TaskId task_id);

  template <class F>
  void for_each(F &&f) {
    tasks_.for_each([&f](auto, Task &task) { f(task.extra); });
  }

  template <class F>
  void for_each_dependent(TaskId task_id, F &&f) {
    auto *task = tasks_.get(task_id);
    CHECK(task != nullptr);
    FlatHashSet<TaskId> visited;
    bool check_for_collisions = task->chains.size() > 1;
    for (TaskChainInfo &task_chain_info : task->chains) {
      ChainInfo &chain_info = *task_chain_info.chain_info;
      chain_info.chain.foreach_child(&task_chain_info.chain_node, [&](TaskId task_id, uint64) {
        if (check_for_collisions && !visited.insert(task_id).second) {
          return;
        }
        f(task_id);
      });
    }
  }

 private:
  struct ChainNode : ListNode {
    TaskId task_id{};
    uint64 generation{};
  };

  class Chain {
   public:
    void add_task(ChainNode *node) {
      head_.put_back(node);
    }

    optional<TaskId> get_first() {
      if (head_.empty()) {
        return {};
      }
      return static_cast<ChainNode &>(*head_.get_next()).task_id;
    }

    optional<TaskId> get_child(ChainNode *chain_node) {
      if (chain_node->get_next() == head_.end()) {
        return {};
      }
      return static_cast<ChainNode &>(*chain_node->get_next()).task_id;
    }
    optional<ChainNode *> get_parent(ChainNode *chain_node) {
      if (chain_node->get_prev() == head_.end()) {
        return {};
      }
      return static_cast<ChainNode *>(chain_node->get_prev());
    }

    void finish_task(ChainNode *node) {
      node->remove();
    }

    bool empty() const {
      return head_.empty();
    }

    void foreach(std::function<void(TaskId, uint64)> f) const {
      for (auto it = head_.begin(); it != head_.end(); it = it->get_next()) {
        auto &node = static_cast<const ChainNode &>(*it);
        f(node.task_id, node.generation);
      }
    }
    void foreach_child(ListNode *start_node, std::function<void(TaskId, uint64)> f) const {
      for (auto it = start_node; it != head_.end(); it = it->get_next()) {
        auto &node = static_cast<const ChainNode &>(*it);
        f(node.task_id, node.generation);
      }
    }

   private:
    ListNode head_;
  };
  struct ChainInfo {
    Chain chain;
    uint32 active_tasks{};
    uint64 generation{1};
  };
  struct TaskChainInfo {
    ChainNode chain_node;
    ChainId chain_id{};
    ChainInfo *chain_info{};
  };
  struct Task {
    enum class State { Pending, Active, Paused } state{State::Pending};
    vector<TaskChainInfo> chains;
    ExtraT extra;
  };
  std::unordered_map<ChainId, ChainInfo> chains_;
  std::unordered_map<ChainId, TaskId> limited_tasks_;
  Container<Task> tasks_;
  VectorQueue<TaskId> pending_tasks_;

  void try_start_task(TaskId task_id) {
    auto *task = tasks_.get(task_id);
    CHECK(task != nullptr);
    if (task->state != Task::State::Pending) {
      return;
    }
    for (TaskChainInfo &task_chain_info : task->chains) {
      auto o_parent = task_chain_info.chain_info->chain.get_parent(&task_chain_info.chain_node);

      if (o_parent) {
        if (o_parent.value()->generation != task_chain_info.chain_info->generation) {
          return;
        }
      }

      if (task_chain_info.chain_info->active_tasks >= 10) {
        limited_tasks_[task_chain_info.chain_id] = task_id;
        return;
      }
    }

    do_start_task(task_id, task);
  }

  void do_start_task(TaskId task_id, Task *task) {
    for (TaskChainInfo &task_chain_info : task->chains) {
      ChainInfo &chain_info = chains_[task_chain_info.chain_id];
      chain_info.active_tasks++;
      task_chain_info.chain_node.generation = chain_info.generation;
    }
    task->state = Task::State::Active;

    pending_tasks_.push(task_id);
    for_each_child(task, [&](TaskId task_id) { try_start_task(task_id); });
  }

  template <class F>
  void for_each_child(Task *task, F &&f) {
    for (TaskChainInfo &task_chain_info : task->chains) {
      ChainInfo &chain_info = *task_chain_info.chain_info;
      auto o_child = chain_info.chain.get_child(&task_chain_info.chain_node);
      if (o_child) {
        f(o_child.value());
      }
    }
  }

  void inactivate_task(TaskId task_id, bool failed) {
    LOG(DEBUG) << "Inactivate " << task_id << " " << (failed ? "failed" : "finished");
    auto *task = tasks_.get(task_id);
    CHECK(task != nullptr);
    bool was_active = task->state == Task::State::Active;
    task->state = Task::State::Pending;
    for (TaskChainInfo &task_chain_info : task->chains) {
      ChainInfo &chain_info = *task_chain_info.chain_info;
      if (was_active) {
        chain_info.active_tasks--;
      }
      if (was_active && failed) {
        chain_info.generation = td::max(chain_info.generation, task_chain_info.chain_node.generation + 1);
      }

      auto it = limited_tasks_.find(task_chain_info.chain_id);
      if (it != limited_tasks_.end()) {
        auto limited_task_id = it->second;
        limited_tasks_.erase(it);
        if (limited_task_id != task_id) {
          try_start_task_later(limited_task_id);
        }
      }

      auto o_first = chain_info.chain.get_first();
      if (o_first) {
        auto first_task_id = o_first.unwrap();
        if (first_task_id != task_id) {
          try_start_task_later(first_task_id);
        }
      }
    }
  }

  void finish_chain_task(TaskChainInfo &task_chain_info) {
    auto &chain = task_chain_info.chain_info->chain;
    chain.finish_task(&task_chain_info.chain_node);
    if (chain.empty()) {
      chains_.erase(task_chain_info.chain_id);
    }
  }

  vector<TaskId> to_start_;

  void try_start_task_later(TaskId task_id) {
    LOG(DEBUG) << "Start later " << task_id;
    to_start_.push_back(task_id);
  }

  void flush_try_start_task() {
    auto moved_to_start = std::move(to_start_);
    for (auto task_id : moved_to_start) {
      try_start_task(task_id);
    }
    CHECK(to_start_.empty());
  }

  template <class ExtraTT>
  friend StringBuilder &operator<<(StringBuilder &sb, ChainScheduler<ExtraTT> &scheduler);
};

template <class ExtraT>
typename ChainScheduler<ExtraT>::TaskId ChainScheduler<ExtraT>::create_task(Span<ChainScheduler::ChainId> chains,
                                                                            ExtraT extra) {
  auto task_id = tasks_.create();
  Task &task = *tasks_.get(task_id);
  task.extra = std::move(extra);
  task.chains = transform(chains, [&](auto chain_id) {
    TaskChainInfo task_chain_info;
    ChainInfo &chain_info = chains_[chain_id];
    task_chain_info.chain_id = chain_id;
    task_chain_info.chain_info = &chain_info;
    task_chain_info.chain_node.task_id = task_id;
    task_chain_info.chain_node.generation = 0;
    return task_chain_info;
  });

  for (TaskChainInfo &task_chain_info : task.chains) {
    ChainInfo &chain_info = *task_chain_info.chain_info;
    chain_info.chain.add_task(&task_chain_info.chain_node);
  }

  try_start_task(task_id);
  return task_id;
}

// TODO: return reference
template <class ExtraT>
ExtraT *ChainScheduler<ExtraT>::get_task_extra(ChainScheduler::TaskId task_id) {  // may return nullptr
  auto *task = tasks_.get(task_id);
  if (task == nullptr) {
    return nullptr;
  }
  return &task->extra;
}

template <class ExtraT>
optional<ChainSchedulerBase::TaskWithParents> ChainScheduler<ExtraT>::start_next_task() {
  if (pending_tasks_.empty()) {
    return {};
  }
  auto task_id = pending_tasks_.pop();
  TaskWithParents res;
  res.task_id = task_id;
  auto *task = tasks_.get(task_id);
  CHECK(task != nullptr);
  for (TaskChainInfo &task_chain_info : task->chains) {
    Chain &chain = task_chain_info.chain_info->chain;
    auto o_parent = chain.get_parent(&task_chain_info.chain_node);
    if (o_parent) {
      res.parents.push_back(o_parent.value()->task_id);
    }
  }
  return res;
}

template <class ExtraT>
void ChainScheduler<ExtraT>::finish_task(ChainScheduler::TaskId task_id) {
  auto *task = tasks_.get(task_id);
  CHECK(task != nullptr);
  CHECK(to_start_.empty());

  inactivate_task(task_id, false);

  for_each_child(task, [&](TaskId task_id) { try_start_task_later(task_id); });

  for (TaskChainInfo &task_chain_info : task->chains) {
    finish_chain_task(task_chain_info);
  }

  tasks_.erase(task_id);
  flush_try_start_task();
}

template <class ExtraT>
void ChainScheduler<ExtraT>::reset_task(ChainScheduler::TaskId task_id) {
  CHECK(to_start_.empty());
  auto *task = tasks_.get(task_id);
  CHECK(task != nullptr);
  inactivate_task(task_id, true);
  try_start_task_later(task_id);
  flush_try_start_task();
}

template <class ExtraT>
void ChainScheduler<ExtraT>::pause_task(TaskId task_id) {
  auto *task = tasks_.get(task_id);
  CHECK(task != nullptr);
  inactivate_task(task_id, true);
  task->state = Task::State::Paused;
  flush_try_start_task();
}

template <class ExtraT>
StringBuilder &operator<<(StringBuilder &sb, ChainScheduler<ExtraT> &scheduler) {
  // 1 print chains
  sb << '\n';
  for (auto &it : scheduler.chains_) {
    sb << "ChainId{" << it.first << "}";
    sb << " active_cnt = " << it.second.active_tasks;
    sb << " g = " << it.second.generation;
    sb << ':';
    it.second.chain.foreach(
        [&](auto task_id, auto generation) { sb << ' ' << *scheduler.get_task_extra(task_id) << ':' << generation; });
    sb << '\n';
  }
  scheduler.tasks_.for_each([&](auto id, auto &task) {
    sb << "Task: " << task.extra;
    sb << " state = " << static_cast<int>(task.state);
    for (auto &task_chain_info : task.chains) {
      sb << " g = " << task_chain_info.chain_node.generation;
      if (task_chain_info.chain_info->generation != task_chain_info.chain_node.generation) {
        sb << " chain_g = " << task_chain_info.chain_info->generation;
      }
    }
    sb << '\n';
  });
  return sb;
}

}  // namespace td