//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once

#include "td/mtproto/AuthKey.h"
#include "td/mtproto/MessageId.h"

#include "td/utils/common.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"

#include <array>

namespace td {
namespace mtproto {

struct ServerSalt {
  int64 salt;
  double valid_since;
  double valid_until;
};

template <class StorerT>
void store(const ServerSalt &salt, StorerT &storer) {
  storer.template store_binary<int64>(salt.salt);
  storer.template store_binary<double>(salt.valid_since);
  storer.template store_binary<double>(salt.valid_until);
}

template <class ParserT>
void parse(ServerSalt &salt, ParserT &parser) {
  salt.salt = parser.fetch_long();
  salt.valid_since = parser.fetch_double();
  salt.valid_until = parser.fetch_double();
}

Status check_message_id_duplicates(MessageId *saved_message_ids, size_t max_size, size_t &end_pos,
                                   MessageId message_id);

template <size_t max_size>
class MessageIdDuplicateChecker {
 public:
  Status check(MessageId message_id) {
    return check_message_id_duplicates(&saved_message_ids_[0], max_size, end_pos_, message_id);
  }

 private:
  std::array<MessageId, 2 * max_size> saved_message_ids_;
  size_t end_pos_ = 0;
};

class AuthData {
 public:
  AuthData();
  AuthData(const AuthData &) = default;
  AuthData &operator=(const AuthData &) = delete;
  AuthData(AuthData &&) = delete;
  AuthData &operator=(AuthData &&) = delete;
  ~AuthData() = default;

  bool is_ready(double now);

  void set_main_auth_key(AuthKey auth_key) {
    main_auth_key_ = std::move(auth_key);
  }
  void break_main_auth_key() {
    main_auth_key_.break_key();
  }
  const AuthKey &get_main_auth_key() const {
    // CHECK(has_main_auth_key());
    return main_auth_key_;
  }
  bool has_main_auth_key() const {
    return !main_auth_key_.empty();
  }
  bool need_main_auth_key() const {
    return !has_main_auth_key();
  }

  void set_tmp_auth_key(AuthKey auth_key) {
    CHECK(!auth_key.empty());
    tmp_auth_key_ = std::move(auth_key);
  }
  const AuthKey &get_tmp_auth_key() const {
    return tmp_auth_key_;
  }
  bool was_tmp_auth_key() const {
    return use_pfs() && !tmp_auth_key_.empty();
  }
  bool need_tmp_auth_key(double now, double refresh_margin) const {
    if (!use_pfs()) {
      return false;
    }
    if (tmp_auth_key_.empty()) {
      return true;
    }
    if (now > tmp_auth_key_.expires_at() - refresh_margin) {
      return true;
    }
    return false;
  }
  void drop_main_auth_key() {
    main_auth_key_ = AuthKey();
  }
  void drop_tmp_auth_key() {
    tmp_auth_key_ = AuthKey();
  }
  bool has_tmp_auth_key(double now) const {
    if (!use_pfs()) {
      return false;
    }
    if (tmp_auth_key_.empty()) {
      return false;
    }
    if (now > tmp_auth_key_.expires_at()) {
      return false;
    }
    return true;
  }

  const AuthKey &get_auth_key() const {
    if (use_pfs()) {
      return get_tmp_auth_key();
    }
    return get_main_auth_key();
  }
  bool has_auth_key(double now) const {
    if (use_pfs()) {
      return has_tmp_auth_key(now);
    }
    return has_main_auth_key();
  }

  bool get_auth_flag() const {
    return main_auth_key_.auth_flag();
  }
  void set_auth_flag(bool auth_flag) {
    main_auth_key_.set_auth_flag(auth_flag);
    if (!auth_flag) {
      drop_tmp_auth_key();
    }
  }

  bool get_bind_flag() const {
    return !use_pfs() || tmp_auth_key_.auth_flag();
  }
  void on_bind() {
    CHECK(use_pfs());
    tmp_auth_key_.set_auth_flag(true);
  }

  Slice get_header() const {
    if (use_pfs()) {
      return tmp_auth_key_.need_header() ? Slice(header_) : Slice();
    } else {
      return main_auth_key_.need_header() ? Slice(header_) : Slice();
    }
  }

  void set_header(std::string header) {
    header_ = std::move(header);
  }

  void on_api_response() {
    if (use_pfs()) {
      tmp_auth_key_.remove_header();
    } else {
      main_auth_key_.remove_header();
    }
  }

  void on_connection_not_inited() {
    if (use_pfs()) {
      tmp_auth_key_.restore_header();
    } else {
      main_auth_key_.restore_header();
    }
  }

  void set_session_id(uint64 session_id) {
    session_id_ = session_id;
  }
  uint64 get_session_id() const {
    CHECK(session_id_ != 0);
    return session_id_;
  }

  double get_server_time(double now) const {
    return server_time_difference_ + now;
  }

  double get_server_time_difference() const {
    return server_time_difference_;
  }

  // diff == msg_id / 2^32 - now == old_server_now - now <= server_now - now
  // server_time_difference >= max{diff}
  bool update_server_time_difference(double diff);

  void reset_server_time_difference(double diff);

  uint64 get_server_salt(double now) {
    update_salt(now);
    return server_salt_.salt;
  }

  void set_server_salt(uint64 salt, double now) {
    server_salt_.salt = salt;
    double server_time = get_server_time(now);
    server_salt_.valid_since = server_time;
    server_salt_.valid_until = server_time + 60 * 10;
    future_salts_.clear();
  }

  bool is_server_salt_valid(double now) const {
    return server_salt_.valid_until > get_server_time(now) + 60;
  }

  bool has_salt(double now) {
    update_salt(now);
    return is_server_salt_valid(now);
  }

  bool need_future_salts(double now) {
    update_salt(now);
    return future_salts_.empty() || !is_server_salt_valid(now);
  }

  void set_future_salts(const std::vector<ServerSalt> &salts, double now);

  std::vector<ServerSalt> get_future_salts() const;

  MessageId next_message_id(double now);

  bool is_valid_outbound_msg_id(MessageId message_id, double now) const;

  bool is_valid_inbound_msg_id(MessageId message_id, double now) const;

  Status check_packet(uint64 session_id, MessageId message_id, double now, bool &time_difference_was_updated);

  Status check_update(MessageId message_id) {
    return updates_duplicate_checker_.check(message_id);
  }

  Status recheck_update(MessageId message_id) {
    return updates_duplicate_rechecker_.check(message_id);
  }

  int32 next_seq_no(bool is_content_related) {
    int32 res = seq_no_;
    if (is_content_related) {
      res |= 1;
      seq_no_ += 2;
    }
    return res;
  }

  void clear_seq_no() {
    seq_no_ = 0;
  }

  void set_use_pfs(bool use_pfs) {
    use_pfs_ = use_pfs;
  }
  bool use_pfs() const {
    return use_pfs_;
  }

 private:
  bool use_pfs_ = true;
  AuthKey main_auth_key_;
  AuthKey tmp_auth_key_;
  bool server_time_difference_was_updated_ = false;
  double server_time_difference_ = 0;
  ServerSalt server_salt_;
  MessageId last_message_id_;
  int32 seq_no_ = 0;
  string header_;
  uint64 session_id_ = 0;

  std::vector<ServerSalt> future_salts_;

  MessageIdDuplicateChecker<1000> duplicate_checker_;
  MessageIdDuplicateChecker<1000> updates_duplicate_checker_;
  MessageIdDuplicateChecker<100> updates_duplicate_rechecker_;

  void update_salt(double now);
};

}  // namespace mtproto
}  // namespace td