//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2020
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/Wget.h"

#include "td/net/HttpHeaderCreator.h"
#include "td/net/HttpOutboundConnection.h"
#include "td/net/SslStream.h"

#include "td/utils/buffer.h"
#include "td/utils/HttpUrl.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"

#include <limits>

namespace td {

// Stores the parameters of the request; the constructor body is empty, so nothing else happens here.
//
// promise      - receives the final HttpQuery on success or the error on failure.
// url          - initial URL; it is parsed in try_init() and replaced by the "location" header on redirects.
// headers      - extra request headers, added to the request in the order given.
// timeout_in   - value passed to set_timeout_in() in start_up().
// ttl          - number of redirects that may still be followed; decremented on each followed redirect.
// prefer_ipv6  - forwarded to IPAddress::init_host_port().
// verify_peer  - forwarded to SslStream::create() for non-HTTP URLs.
// content      - request body; an empty body results in a GET request, a non-empty one in a POST request.
// content_type - content type set on POST requests when non-empty.
Wget::Wget(Promise<unique_ptr<HttpQuery>> promise, string url, std::vector<std::pair<string, string>> headers,
           int32 timeout_in, int32 ttl, bool prefer_ipv6, SslStream::VerifyPeer verify_peer, string content,
           string content_type)
    : promise_(std::move(promise))
    , input_url_(std::move(url))
    , headers_(std::move(headers))
    , timeout_in_(timeout_in)
    , ttl_(ttl)
    , prefer_ipv6_(prefer_ipv6)
    , verify_peer_(verify_peer)
    , content_(std::move(content))
    , content_type_(std::move(content_type)) {
}

// Builds the HTTP request for input_url_ and hands it to a freshly created connection.
//
// Steps, in order: parse the URL and convert its host with idn_to_ascii(); serialize the request
// (POST when content_ is non-empty, GET otherwise); initialize an IPAddress from the host and port;
// open a socket; create the HttpOutboundConnection actor (with an SslStream for non-HTTP URLs);
// queue the serialized request on it.
//
// Returns the first error produced by any of these steps. On success connection_ is set.
Status Wget::try_init() {
  TRY_RESULT(url, parse_url(input_url_));
  TRY_RESULT_ASSIGN(url.host_, idn_to_ascii(url.host_));

  HttpHeaderCreator hc;
  if (!content_.empty()) {
    hc.init_post(url.query_);
    hc.set_content_size(content_.size());
    if (!content_type_.empty()) {
      hc.set_content_type(content_type_);
    }
  } else {
    hc.init_get(url.query_);
  }

  // Caller-supplied headers are always forwarded; remember whether they already cover the
  // headers that would otherwise be added by default.
  bool has_host = false;
  bool has_accept_encoding = false;
  for (const auto &name_value : headers_) {
    const auto lowered_name = to_lower(name_value.first);
    has_host = has_host || lowered_name == "host";
    has_accept_encoding = has_accept_encoding || lowered_name == "accept-encoding";
    hc.add_header(name_value.first, name_value.second);
  }
  if (!has_host) {
    hc.add_header("Host", url.host_);
  }
  if (!has_accept_encoding) {
    hc.add_header("Accept-Encoding", "gzip, deflate");
  }
  TRY_RESULT(request, hc.finish(content_));

  IPAddress address;
  TRY_STATUS(address.init_host_port(url.host_, url.port_, prefer_ipv6_));

  TRY_RESULT(socket_fd, SocketFd::open(address));

  // Both protocols create the connection the same way and differ only in the SslStream they pass.
  auto spawn_connection = [&](SslStream ssl_stream) {
    connection_ = create_actor<HttpOutboundConnection>("Connect", std::move(socket_fd), std::move(ssl_stream),
                                                       std::numeric_limits<std::size_t>::max(), 0, 0,
                                                       ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
  };
  if (url.protocol_ == HttpUrl::Protocol::HTTP) {
    spawn_connection(SslStream{});
  } else {
    TRY_RESULT(ssl_stream, SslStream::create(url.host_, CSlice() /* certificate */, verify_peer_));
    spawn_connection(std::move(ssl_stream));
  }

  send_closure(connection_, &HttpOutboundConnection::write_next, BufferSlice(request));
  send_closure(connection_, &HttpOutboundConnection::write_ok);
  return Status::OK();
}

// Starts the request if there is no connection yet; any initialization failure is reported through on_error().
void Wget::loop() {
  if (!connection_.empty()) {
    // A connection already exists, so there is nothing to start.
    return;
  }

  auto init_status = try_init();
  if (init_status.is_error()) {
    on_error(std::move(init_status));
  }
}

// Receives an HTTP response and forwards it to on_ok().
// NOTE(review): presumably invoked by the HttpOutboundConnection created in try_init(), which is given
// this actor as its HttpOutboundConnection::Callback - confirm against the header.
void Wget::handle(unique_ptr<HttpQuery> result) {
  on_ok(std::move(result));
}

// Receives a connection failure and forwards it to on_error().
// NOTE(review): presumably invoked by the HttpOutboundConnection created in try_init(), which is given
// this actor as its HttpOutboundConnection::Callback - confirm against the header.
void Wget::on_connection_error(Status error) {
  on_error(std::move(error));
}

// Processes a received HTTP response.
//
// - 301/302/307/308 while ttl_ is positive: take the new URL from the "location" header, decrement
//   ttl_, drop the current connection and yield so that the request is started again.
// - Any 2xx code: fulfil the promise with the response and stop the actor.
// - Everything else (including a redirect when ttl_ is exhausted): fail with "HTTP error: <code>".
void Wget::on_ok(unique_ptr<HttpQuery> http_query_ptr) {
  CHECK(promise_);
  CHECK(http_query_ptr);

  const auto code = http_query_ptr->code_;
  const bool is_redirect = code == 301 || code == 302 || code == 307 || code == 308;

  if (is_redirect && ttl_ > 0) {
    LOG(DEBUG) << *http_query_ptr;
    input_url_ = http_query_ptr->get_header("location").str();
    LOG(DEBUG) << input_url_;
    ttl_--;
    // With the connection reset, loop() will call try_init() again for the updated input_url_.
    connection_.reset();
    yield();
    return;
  }

  if (code >= 200 && code < 300) {
    promise_.set_value(std::move(http_query_ptr));
    stop();
    return;
  }

  on_error(Status::Error(PSLICE() << "HTTP error: " << code));
}

// Fails the pending promise with the given error and stops the actor.
// CHECKs that `error` really is an error and that promise_ is still set.
void Wget::on_error(Status error) {
  CHECK(error.is_error());
  CHECK(promise_);
  promise_.set_error(std::move(error));
  stop();
}

// Arms the timeout with timeout_in_ and then runs loop(), which starts the request.
void Wget::start_up() {
  set_timeout_in(timeout_in_);
  loop();
}

// Reports a "Response timeout expired" error through on_error(), which fails the promise and stops the actor.
void Wget::timeout_expired() {
  on_error(Status::Error("Response timeout expired"));
}

// If the promise is still set when the actor is torn down, fails it with a "Cancelled" error.
void Wget::tear_down() {
  if (!promise_) {
    // The promise is no longer set, so there is nothing to report.
    return;
  }
  on_error(Status::Error("Cancelled"));
}

}  // namespace td