//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2021
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once

#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/port/thread.h"

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <mutex>

namespace td {

// Waiter that yields for a bounded number of rounds before blocking on a
// condition variable.
//
// Encoding of state_:
//   bit 0   - set when a worker has committed to blocking on condition_variable_
//   bits 1+ - (worker_id + 1) of the single registered "sleepy" worker, 0 if none
//
// A worker calls wait() repeatedly while it has nothing to do and stop_wait()
// once it has found work. Producers call notify() after publishing work.
class MpmcEagerWaiter {
 public:
  // Per-worker bookkeeping; must be prepared with init_slot() before use.
  struct Slot {
   private:
    friend class MpmcEagerWaiter;
    int yields{0};        // number of consecutive wait() rounds spent so far
    uint32 worker_id{0};  // identifier stored (as worker_id + 1) in state_
  };

  // Resets the slot and binds it to the given worker identifier.
  void init_slot(Slot &slot, uint32 worker_id) {
    slot.yields = 0;
    slot.worker_id = worker_id;
  }

  // Performs one round of waiting. Depending on how many rounds the slot has
  // already spent, this either yields, tries to register the worker as the
  // "sleepy" one, or blocks on the condition variable. May return at any time;
  // the caller is expected to re-check for work and call wait() again.
  void wait(Slot &slot) {
    if (slot.yields < RoundsTillSleepy) {
      // Phase 1: plain spinning.
      td::this_thread::yield();
      slot.yields++;
      return;
    } else if (slot.yields == RoundsTillSleepy) {
      // Phase 2: try to become the registered sleepy worker.
      auto state = state_.load(std::memory_order_relaxed);
      if (!State::has_worker(state)) {
        auto new_state = State::with_worker(state, slot.worker_id);
        if (state_.compare_exchange_strong(state, new_state, std::memory_order_acq_rel)) {
          td::this_thread::yield();
          slot.yields++;
          return;
        }
        // The failed CAS reloaded `state`; an awake state means somebody has
        // notified in the meantime, so return immediately without yielding.
        if (state == State::awake()) {
          slot.yields = 0;
          return;
        }
      }
      // Another worker is registered (or the state changed): start over.
      td::this_thread::yield();
      slot.yields = 0;
      return;
    } else if (slot.yields < RoundsTillAsleep) {
      // Phase 3: keep spinning while still registered as the sleepy worker.
      auto state = state_.load(std::memory_order_acquire);
      if (State::still_sleepy(state, slot.worker_id)) {
        td::this_thread::yield();
        slot.yields++;
        return;
      }
      slot.yields = 0;
      return;
    } else {
      // Phase 4: block, unless the registration was cleared in the meantime.
      auto state = state_.load(std::memory_order_acquire);
      if (State::still_sleepy(state, slot.worker_id)) {
        // mutex_ is taken before publishing asleep() so that notify_cold(),
        // which locks the same mutex before notifying, cannot signal between
        // the CAS and the start of the wait.
        std::unique_lock<std::mutex> lock(mutex_);
        if (state_.compare_exchange_strong(state, State::asleep(), std::memory_order_acq_rel)) {
          condition_variable_.wait(lock);
        }
      }
      slot.yields = 0;
      return;
    }
  }

  // Called by a worker that has found work. If the worker got past the
  // registration round, state_ is reset (waking a blocked worker, if any).
  void stop_wait(Slot &slot) {
    if (slot.yields > RoundsTillSleepy) {
      notify_cold();
    }
    slot.yields = 0;
    return;
  }

  // Nothing to release.
  void close() {
  }

  // Called after new work has been published. Cheap when state_ is awake;
  // otherwise falls through to notify_cold().
  void notify() {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    if (state_.load(std::memory_order_acquire) == State::awake()) {
      return;
    }
    notify_cold();
  }

 private:
  // Helpers for reading and building values of state_.
  struct State {
    // No sleepy worker registered and nobody blocked.
    static constexpr uint32 awake() {
      return 0;
    }
    // Asleep bit set, no worker identifier.
    static constexpr uint32 asleep() {
      return 1;
    }
    static bool is_asleep(uint32 state) {
      return (state & 1) != 0;
    }
    static bool has_worker(uint32 state) {
      return (state >> 1) != 0;
    }
    // Returns `state` with (worker + 1) stored above the asleep bit.
    // The result is unsigned to match state_ and avoid a signed round-trip.
    static uint32 with_worker(uint32 state, uint32 worker) {
      return state | ((worker + 1) << 1);
    }
    // True while `worker` is the one recorded in `state`.
    static bool still_sleepy(uint32 state, uint32 worker) {
      return (state >> 1) == (worker + 1);
    }
  };
  enum { RoundsTillSleepy = 32, RoundsTillAsleep = 64 };
  // enum { RoundsTillSleepy = 1, RoundsTillAsleep = 2 };
  std::atomic<uint32> state_{State::awake()};
  std::mutex mutex_;
  std::condition_variable condition_variable_;

  // Resets state_ to awake and, if the previous value had the asleep bit set,
  // wakes every thread blocked on condition_variable_.
  void notify_cold() {
    auto old_state = state_.exchange(State::awake(), std::memory_order_release);
    if (State::is_asleep(old_state)) {
      std::lock_guard<std::mutex> guard(mutex_);
      condition_variable_.notify_all();
    }
  }
};

// Waiter that tracks each worker as Search/Work/Sleep and parks idle workers
// on a per-slot condition variable. notify() wakes at most one parked worker,
// and only when no worker is currently searching.
class MpmcSleepyWaiter {
 public:
  // Per-worker state. Each slot owns the mutex/condition variable it parks on.
  struct Slot {
   private:
    friend class MpmcSleepyWaiter;

    enum State { Search, Work, Sleep } state_{Work};

    // Blocks until unpark_flag_ becomes true, then consumes the flag.
    void park() {
      std::unique_lock<std::mutex> guard(mutex_);
      condition_variable_.wait(guard, [&] { return unpark_flag_; });
      unpark_flag_ = false;
    }

    // Consumes the flag without blocking and without taking mutex_;
    // returns whether an unpark had been delivered.
    bool cancel_park() {
      auto res = unpark_flag_;
      unpark_flag_ = false;
      return res;
    }

    // Sets the flag under mutex_ and wakes the parked thread.
    void unpark() {
      //TODO: try to unlock guard before notify_all
      std::unique_lock<std::mutex> guard(mutex_);
      unpark_flag_ = true;
      condition_variable_.notify_all();
    }

    std::mutex mutex_;
    std::condition_variable condition_variable_;
    bool unpark_flag_{false};  // TODO: move out of lock
    int yield_cnt{0};
    int32 worker_id{0};

   public:
    // NOTE(review): presumably keeps neighbouring slots on separate cache lines - confirm.
    char padding[TD_CONCURRENCY_PAD];
  };

  // There are a lot of workers
  // Each has a slot
  //
  // States of a worker:
  //   - searching for work | Search
  //   - processing work    | Work
  //   - sleeping           | Sleep
  //
  // When somebody adds a work it calls notify
  //
  // notify
  //   if there are workers in search phase do nothing.
  //   if all workers are awake do nothing
  //   otherwise wake some random worker
  //
  // Initially all workers are in Work mode (see init_slot);
  // the first call to wait moves a worker to Search mode.
  //
  // When worker found nothing it may try to call wait.
  // This may put it in a Sleep for some time.
  // After wait returns worker will be in Search state again.
  //
  // Suppose worker found a work and ready to process it.
  // Then it may call stop_wait. This will cause transition from
  // Search to Work state.
  //
  // Main invariant:
  // After notify is called there should be at least one worker in Search or Work state.
  // If possible - in Search state
  //

  // Puts the slot into the Work state, clears any pending unpark and records worker_id.
  void init_slot(Slot &slot, int32 worker_id) {
    slot.state_ = Slot::State::Work;
    slot.unpark_flag_ = false;
    slot.worker_id = worker_id;
    VLOG(waiter) << "Init slot " << worker_id;
  }

  static constexpr int VERBOSITY_NAME(waiter) = VERBOSITY_NAME(DEBUG) + 10;

  // Advances the slot one step towards sleeping:
  //   Work   -> Search (returns immediately)
  //   Search -> Sleep  (registers in sleepers_; parks unless it was the last searcher)
  //   Sleep  -> parks, and is back in Search once unparked
  void wait(Slot &slot) {
    if (slot.state_ == Slot::State::Work) {
      VLOG(waiter) << "Work -> Search";
      // The low bits of state_ hold the searching count.
      state_++;
      slot.state_ = Slot::State::Search;
      slot.yield_cnt = 0;
      return;
    }
    if (slot.state_ == Slot::Search) {
      // The `&& false` disables this yield phase; yield_cnt is still incremented.
      if (slot.yield_cnt++ < 10 && false) {
        td::this_thread::yield();
        return;
      }

      slot.state_ = Slot::State::Sleep;
      std::unique_lock<std::mutex> guard(sleepers_mutex_);
      // Adding (1 << PARKING_SHIFT) - 1 increments the parked count and
      // decrements the searching count in a single atomic step.
      // state_view reflects the value before the update.
      auto state_view = StateView(state_.fetch_add((1 << PARKING_SHIFT) - 1));
      CHECK(state_view.searching_count != 0);
      // The last searcher gets one more search round before it actually parks.
      bool should_search = state_view.searching_count == 1;
      // closed_ is never set to true in the code visible here.
      if (closed_) {
        return;
      }
      sleepers_.push_back(&slot);
      LOG_CHECK(slot.unpark_flag_ == false) << slot.worker_id;
      VLOG(waiter) << "Add to sleepers " << slot.worker_id;
      //guard.unlock();
      if (should_search) {
        // Return with state_ == Sleep; the next wait() call goes straight to park().
        VLOG(waiter) << "Search -> Search once, then Sleep ";
        return;
      }
      VLOG(waiter) << "Search -> Sleep " << state_view.searching_count << " " << state_view.parked_count;
    }

    CHECK(slot.state_ == Slot::State::Sleep);
    VLOG(waiter) << "Park " << slot.worker_id;
    slot.park();
    VLOG(waiter) << "Resume " << slot.worker_id;
    // notify() has already moved this worker from the parked to the searching
    // count in state_ before unparking it.
    slot.state_ = Slot::State::Search;
    slot.yield_cnt = 0;
  }

  // Called by a worker that has found work: moves the slot to Work and, if it
  // was the last searcher, calls notify() so that another worker may be woken.
  void stop_wait(Slot &slot) {
    if (slot.state_ == Slot::State::Work) {
      return;
    }
    if (slot.state_ == Slot::State::Sleep) {
      // The slot was registered as a sleeper but did not park yet
      // ("search once" case in wait()); undo the registration.
      VLOG(waiter) << "Search once, then Sleep -> Work/Search " << slot.worker_id;
      slot.state_ = Slot::State::Work;
      std::unique_lock<std::mutex> guard(sleepers_mutex_);
      auto it = std::find(sleepers_.begin(), sleepers_.end(), &slot);
      if (it != sleepers_.end()) {
        sleepers_.erase(it);
        VLOG(waiter) << "Remove from sleepers " << slot.worker_id;
        // Move this worker back from the parked count to the searching count.
        state_.fetch_sub((1 << PARKING_SHIFT) - 1);
        guard.unlock();
      } else {
        // notify() already removed the slot, adjusted state_ and called unpark()
        // while holding sleepers_mutex_; consume that pending unpark.
        guard.unlock();
        VLOG(waiter) << "Not in sleepers" << slot.worker_id;
        CHECK(slot.cancel_park());
      }
    }
    VLOG(waiter) << "Search once, then Sleep -> Work " << slot.worker_id;
    slot.state_ = Slot::State::Search;
    // Leave the searching set; state_view reflects the value before the decrement.
    auto state_view = StateView(state_.fetch_sub(1));
    CHECK(state_view.searching_count != 0);
    CHECK(state_view.searching_count < 1000);
    bool should_notify = state_view.searching_count == 1;
    if (should_notify) {
      VLOG(waiter) << "Notify others";
      notify();
    }
    VLOG(waiter) << "Search -> Work ";
    slot.state_ = Slot::State::Work;
  }

  // Wakes one parked worker if nobody is searching and somebody is parked.
  void notify() {
    // Lock-free fast path.
    auto view = StateView(state_.load());
    //LOG(ERROR) << view.parked_count;
    if (view.searching_count > 0 || view.parked_count == 0) {
      VLOG(waiter) << "Ingore notify: " << view.searching_count << " " << view.parked_count;
      return;
    }

    VLOG(waiter) << "Notify: " << view.searching_count << " " << view.parked_count;
    std::unique_lock<std::mutex> guard(sleepers_mutex_);

    // Re-check under the lock: a searcher may have appeared in the meantime.
    view = StateView(state_.load());
    if (view.searching_count > 0) {
      VLOG(waiter) << "Skip notify: got searching";
      return;
    }

    CHECK(view.parked_count == static_cast<int>(sleepers_.size()));
    if (sleepers_.empty()) {
      VLOG(waiter) << "Skip notify: no sleepers";
      return;
    }

    auto sleeper = sleepers_.back();
    sleepers_.pop_back();
    // Convert the sleeper from parked to searching on its behalf.
    state_.fetch_sub((1 << PARKING_SHIFT) - 1);
    VLOG(waiter) << "Unpark " << sleeper->worker_id;
    // sleepers_mutex_ is still held here (guard lives until the function returns).
    sleeper->unpark();
  }

  // Verifies that no worker is parked or searching; does not modify any state.
  void close() {
    StateView state(state_.load());
    LOG_CHECK(state.parked_count == 0) << state.parked_count;
    LOG_CHECK(state.searching_count == 0) << state.searching_count;
  }

 private:
  static constexpr int32 PARKING_SHIFT = 16;
  // Decoded view of state_: the bits above PARKING_SHIFT hold the parked count,
  // the low PARKING_SHIFT bits hold the searching count.
  struct StateView {
    int32 parked_count;
    int32 searching_count;
    explicit StateView(int32 x) {
      parked_count = x >> PARKING_SHIFT;
      searching_count = x & ((1 << PARKING_SHIFT) - 1);
    }
  };
  std::atomic<int32> state_{0};

  // Slots registered for parking; accessed under sleepers_mutex_.
  std::mutex sleepers_mutex_;
  vector<Slot *> sleepers_;

  bool closed_ = false;
};

// MpmcWaiter names the MpmcSleepyWaiter implementation.
using MpmcWaiter = MpmcSleepyWaiter;

}  // namespace td