//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).

#ifndef ROCKSDB_LITE

#include <atomic>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "db/db_impl/db_impl.h"
#include "db/write_callback.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "util/random.h"

using std::string;

namespace ROCKSDB_NAMESPACE {

class WriteCallbackTest : public testing::Test {
 public:
  string dbname;

  WriteCallbackTest() {
    dbname = test::PerThreadDBPath("write_callback_testdb");
  }
};

// WriteCallback that records that it ran. It accepts the write only when the
// DB it is handed is a DBImpl, and it allows the write to be batched.
class WriteCallbackTestWriteCallback1 : public WriteCallback {
 public:
  // Set to true the first time Callback() is invoked.
  bool was_called = false;

  Status Callback(DB *db) override {
    was_called = true;
    // Reject anything that is not a DBImpl (including nullptr).
    const bool is_db_impl = dynamic_cast<DBImpl*>(db) != nullptr;
    return is_db_impl ? Status::OK() : Status::InvalidArgument("");
  }

  bool AllowWriteBatching() override { return true; }
};

// WriteCallback that vetoes every write with Status::Busy(), while still
// allowing the write to join a batch group.
class WriteCallbackTestWriteCallback2 : public WriteCallback {
 public:
  bool AllowWriteBatching() override { return true; }

  Status Callback(DB* /*db*/) override {
    // Unconditional rejection; the DB argument is not inspected.
    return Status::Busy();
  }
};

// Configurable WriteCallback used by the parameterized test.
//  - should_fail_:    when true, Callback() returns Status::Busy(), otherwise OK.
//  - allow_batching_: value reported by AllowWriteBatching().
//  - was_called_:     set whenever Callback() runs. It is atomic because it is
//                     written from writer threads and read by the test body.
class MockWriteCallback : public WriteCallback {
 public:
  bool should_fail_ = false;
  bool allow_batching_ = false;
  std::atomic<bool> was_called_{false};

  MockWriteCallback() = default;

  // std::atomic is not copyable, so copy the flag's current value explicitly.
  // This is needed to build the scenario vectors from initializer lists.
  MockWriteCallback(const MockWriteCallback& other)
      : should_fail_(other.should_fail_),
        allow_batching_(other.allow_batching_),
        was_called_(other.was_called_.load()) {}

  Status Callback(DB* /*db*/) override {
    was_called_.store(true);
    return should_fail_ ? Status::Busy() : Status::OK();
  }

  bool AllowWriteBatching() override { return allow_batching_; }
};

// Parameterized fixture. It unpacks the seven boolean knobs that
// WriteWithCallbackTest uses to configure the DB and each write.
class WriteCallbackPTest
    : public WriteCallbackTest,
      public ::testing::WithParamInterface<
          std::tuple<bool, bool, bool, bool, bool, bool, bool>> {
 public:
  WriteCallbackPTest() {
    const auto& params = GetParam();
    unordered_write_ = std::get<0>(params);
    seq_per_batch_ = std::get<1>(params);
    two_queues_ = std::get<2>(params);
    allow_parallel_ = std::get<3>(params);
    allow_batching_ = std::get<4>(params);
    enable_WAL_ = std::get<5>(params);
    enable_pipelined_write_ = std::get<6>(params);
  }

 protected:
  bool unordered_write_;         // -> Options::unordered_write
  bool seq_per_batch_;           // -> DBImpl::Open(..., seq_per_batch, ...)
  bool two_queues_;              // -> Options::two_write_queues
  bool allow_parallel_;          // -> Options::allow_concurrent_memtable_write
  bool allow_batching_;          // -> MockWriteCallback::allow_batching_
  bool enable_WAL_;              // -> WriteOptions::disableWAL / sync
  bool enable_pipelined_write_;  // -> Options::enable_pipelined_write
};

// For every scenario in `write_scenarios`, this test:
//  1. Opens a fresh DB configured from the test parameters. Option
//     combinations the DB does not support are skipped.
//  2. Launches one writer thread per WriteOP. Each writer submits a batch
//     guarded by a MockWriteCallback that either succeeds or returns Busy.
//
// Sync-point hooks in WriteThread::JoinBatchGroup make the writers link into
// the write queue one at a time. This forces the first writer to become the
// group leader and the last writer to link last, and the hooks assert the
// writers' WriteThread states along the way.
//
// Afterwards the test checks that:
//  - every callback ran;
//  - keys from accepted writes are readable, and keys from rejected writes
//    are absent;
//  - the locally tracked sequence number matches the DB's last visible
//    sequence.
TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
  // One writer's payload: the batch to submit, the callback attached to it,
  // and the key/value pairs placed in the batch (for later verification).
  struct WriteOP {
    // Intentionally implicit so scenarios can be written as `{true, false}`.
    WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; }

    void Put(const string& key, const string& val) {
      kvs_.push_back(std::make_pair(key, val));
      write_batch_.Put(key, val);
    }

    void Clear() {
      kvs_.clear();
      write_batch_.Clear();
      callback_.was_called_.store(false);
    }

    MockWriteCallback callback_;
    WriteBatch write_batch_;
    std::vector<std::pair<string, string>> kvs_;
  };

  // In each scenario we'll launch multiple threads to write.
  // The size of each array equals to number of threads, and
  // each boolean in it denote whether callback of corresponding
  // thread should succeed or fail.
  std::vector<std::vector<WriteOP>> write_scenarios = {
      {true},
      {false},
      {false, false},
      {true, true},
      {true, false},
      {false, true},
      {false, false, false},
      {true, true, true},
      {false, true, false},
      {true, false, true},
      {true, false, false, false, false},
      {false, false, false, false, true},
      {false, false, true, false, true},
  };

  for (auto& write_group : write_scenarios) {
    Options options;
    options.create_if_missing = true;
    options.unordered_write = unordered_write_;
    options.allow_concurrent_memtable_write = allow_parallel_;
    options.enable_pipelined_write = enable_pipelined_write_;
    options.two_write_queues = two_queues_;
    // Skip unsupported combinations
    if (options.enable_pipelined_write && seq_per_batch_) {
      continue;
    }
    if (options.enable_pipelined_write && options.two_write_queues) {
      continue;
    }
    if (options.unordered_write && !options.allow_concurrent_memtable_write) {
      continue;
    }
    if (options.unordered_write && options.enable_pipelined_write) {
      continue;
    }

    ReadOptions read_options;
    DB* db;
    DBImpl* db_impl;

    DestroyDB(dbname, options);

    DBOptions db_options(options);
    ColumnFamilyOptions cf_options(options);
    std::vector<ColumnFamilyDescriptor> column_families;
    column_families.push_back(
        ColumnFamilyDescriptor(kDefaultColumnFamilyName, cf_options));
    std::vector<ColumnFamilyHandle*> handles;
    auto open_s = DBImpl::Open(db_options, dbname, column_families, &handles,
                               &db, seq_per_batch_, true /* batch_per_txn */);
    ASSERT_OK(open_s);
    // Use a gtest assertion rather than assert(): assert() is compiled out
    // under NDEBUG, and handles[0] is dereferenced right below.
    ASSERT_EQ(1u, handles.size());
    delete handles[0];

    db_impl = dynamic_cast<DBImpl*>(db);
    ASSERT_TRUE(db_impl);

    // Writers that have called JoinBatchGroup.
    std::atomic<uint64_t> threads_joining(0);
    // Writers that have linked to the queue
    std::atomic<uint64_t> threads_linked(0);
    // Writers that pass WriteThread::JoinBatchGroup:Wait sync-point.
    std::atomic<uint64_t> threads_verified(0);

    std::atomic<uint64_t> seq(db_impl->GetLatestSequenceNumber());
    ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0);

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Start", [&](void*) {
          uint64_t cur_threads_joining = threads_joining.fetch_add(1);
          // Wait for the last joined writer to link to the queue.
          // In this way the writers link to the queue one by one.
          // This allows us to confidently detect the first writer
          // who increases threads_linked as the leader.
          while (threads_linked.load() < cur_threads_joining) {
          }
        });

    // Verification once writers call JoinBatchGroup.
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:Wait", [&](void* arg) {
          uint64_t cur_threads_linked = threads_linked.fetch_add(1);
          bool is_leader = false;
          bool is_last = false;

          // who am i
          is_leader = (cur_threads_linked == 0);
          is_last = (cur_threads_linked == write_group.size() - 1);

          // check my state
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (is_leader) {
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT);
          }

          // (meta test) the first WriteOP should indeed be the first
          // and the last should be the last (all others can be out of
          // order)
          if (is_leader) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.front().callback_.should_fail_);
          } else if (is_last) {
            ASSERT_TRUE(writer->callback->Callback(nullptr).ok() ==
                        !write_group.back().callback_.should_fail_);
          }

          threads_verified.fetch_add(1);
          // Wait here until all verification in this sync-point
          // callback finish for all writers.
          while (threads_verified.load() < write_group.size()) {
          }
        });

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->SetCallBack(
        "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) {
          // check my state
          auto* writer = reinterpret_cast<WriteThread::Writer*>(arg);

          if (!allow_batching_) {
            // no batching so everyone should be a leader
            ASSERT_TRUE(writer->state ==
                        WriteThread::State::STATE_GROUP_LEADER);
          } else if (!allow_parallel_) {
            ASSERT_TRUE(writer->state == WriteThread::State::STATE_COMPLETED ||
                        (enable_pipelined_write_ &&
                         writer->state ==
                             WriteThread::State::STATE_MEMTABLE_WRITER_LEADER));
          }
        });

    std::atomic<uint32_t> thread_num(0);
    std::atomic<char> dummy_key(0);

    // Each write thread create a random write batch and write to DB
    // with a write callback.
    std::function<void()> write_with_callback_func = [&]() {
      uint32_t i = thread_num.fetch_add(1);
      Random rnd(i);

      // leaders gotta lead
      while (i > 0 && threads_verified.load() < 1) {
      }

      // loser has to lose
      while (i == write_group.size() - 1 &&
             threads_verified.load() < write_group.size() - 1) {
      }

      auto& write_op = write_group.at(i);
      write_op.Clear();
      write_op.callback_.allow_batching_ = allow_batching_;

      // Insert a random number of keys, in [0, 49]. Draw the count once
      // before the loop: calling rnd.Next() in the loop condition would
      // redraw the bound on every iteration and skew batches towards very
      // few keys. At most 5 writers x 49 keys stay below 256, so every
      // single-char key remains unique.
      const uint32_t num_keys = rnd.Next() % 50;
      for (uint32_t j = 0; j < num_keys; j++) {
        // grab unique key
        char my_key = dummy_key.fetch_add(1);

        string skey(5, my_key);
        string sval(10, my_key);
        write_op.Put(skey, sval);

        if (!write_op.callback_.should_fail_ && !seq_per_batch_) {
          seq.fetch_add(1);
        }
      }
      if (!write_op.callback_.should_fail_ && seq_per_batch_) {
        seq.fetch_add(1);
      }

      WriteOptions woptions;
      woptions.disableWAL = !enable_WAL_;
      woptions.sync = enable_WAL_;
      Status s;
      if (seq_per_batch_) {
        class PublishSeqCallback : public PreReleaseCallback {
         public:
          PublishSeqCallback(DBImpl* db_impl_in) : db_impl_(db_impl_in) {}
          Status Callback(SequenceNumber last_seq, bool /*not used*/, uint64_t,
                          size_t /*index*/, size_t /*total*/) override {
            db_impl_->SetLastPublishedSequence(last_seq);
            return Status::OK();
          }
          DBImpl* db_impl_;
        } publish_seq_callback(db_impl);
        // seq_per_batch_ requires a natural batch separator or Noop
        WriteBatchInternal::InsertNoop(&write_op.write_batch_);
        const size_t ONE_BATCH = 1;
        s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
                               &write_op.callback_, nullptr, 0, false, nullptr,
                               ONE_BATCH,
                               two_queues_ ? &publish_seq_callback : nullptr);
      } else {
        s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
                                       &write_op.callback_);
      }

      if (write_op.callback_.should_fail_) {
        ASSERT_TRUE(s.IsBusy());
      } else {
        ASSERT_OK(s);
      }
    };

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->EnableProcessing();

    // do all the writes
    std::vector<port::Thread> threads;
    for (uint32_t i = 0; i < write_group.size(); i++) {
      threads.emplace_back(write_with_callback_func);
    }
    for (auto& t : threads) {
      t.join();
    }

    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->DisableProcessing();
    // The callbacks above capture this iteration's locals by reference.
    // Drop them so no closure outlives the objects it refers to.
    ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->ClearAllCallBacks();

    // check for keys
    string value;
    for (auto& w : write_group) {
      ASSERT_TRUE(w.callback_.was_called_.load());
      for (auto& kvp : w.kvs_) {
        if (w.callback_.should_fail_) {
          ASSERT_TRUE(db->Get(read_options, kvp.first, &value).IsNotFound());
        } else {
          ASSERT_OK(db->Get(read_options, kvp.first, &value));
          ASSERT_EQ(value, kvp.second);
        }
      }
    }

    ASSERT_EQ(seq.load(), db_impl->TEST_GetLastVisibleSequence());

    delete db;
    DestroyDB(dbname, options);
  }
}

// Run WriteWithCallbackTest over the full cartesian product of the seven
// boolean knobs. The order matches the tuple unpacked by
// WriteCallbackPTest's constructor.
INSTANTIATE_TEST_CASE_P(
    WriteCallbackPTest, WriteCallbackPTest,
    ::testing::Combine(::testing::Bool(),    // unordered_write_
                       ::testing::Bool(),    // seq_per_batch_
                       ::testing::Bool(),    // two_queues_
                       ::testing::Bool(),    // allow_parallel_
                       ::testing::Bool(),    // allow_batching_
                       ::testing::Bool(),    // enable_WAL_
                       ::testing::Bool()));  // enable_pipelined_write_

// Single-threaded check of DBImpl::WriteWithCallback, in three steps:
//  1. A plain Write is applied.
//  2. A callback-guarded write whose callback succeeds is applied, and the
//     callback is observed to run.
//  3. A write whose callback returns Status::Busy() is rejected with that
//     status and leaves the stored value unchanged.
TEST_F(WriteCallbackTest, WriteCallBackTest) {
  Options options;
  WriteOptions write_options;
  ReadOptions read_options;
  string value;
  DB* db;
  DBImpl* db_impl;

  DestroyDB(dbname, options);

  options.create_if_missing = true;
  Status s = DB::Open(options, dbname, &db);
  ASSERT_OK(s);

  db_impl = dynamic_cast<DBImpl*> (db);
  ASSERT_TRUE(db_impl);

  WriteBatch wb;

  wb.Put("a", "value.a");
  wb.Delete("x");

  // Test a simple Write
  s = db->Write(write_options, &wb);
  ASSERT_OK(s);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a", value);

  // Test WriteWithCallback
  WriteCallbackTestWriteCallback1 callback1;
  WriteBatch wb2;

  wb2.Put("a", "value.a2");

  s = db_impl->WriteWithCallback(write_options, &wb2, &callback1);
  ASSERT_OK(s);
  ASSERT_TRUE(callback1.was_called);

  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  // Test WriteWithCallback for a callback that fails
  WriteCallbackTestWriteCallback2 callback2;
  WriteBatch wb3;

  wb3.Put("a", "value.a3");

  s = db_impl->WriteWithCallback(write_options, &wb3, &callback2);
  // callback2 always returns Status::Busy(). The write must report exactly
  // that status, not merely some non-OK status.
  ASSERT_TRUE(s.IsBusy());

  // The rejected batch must not have been applied.
  s = db->Get(read_options, "a", &value);
  ASSERT_OK(s);
  ASSERT_EQ("value.a2", value);

  delete db;
  DestroyDB(dbname, options);
}

}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  // Let googletest consume its own command-line flags, then run every
  // registered test and use the aggregate result as the exit code.
  ::testing::InitGoogleTest(&argc, argv);
  const int test_result = RUN_ALL_TESTS();
  return test_result;
}

#else
#include <stdio.h>

int main(int /*argc*/, char** /*argv*/) {
  // ROCKSDB_LITE builds have no WriteWithCallback. Report the skip on stderr
  // and exit successfully.
  fputs("SKIPPED as WriteWithCallback is not supported in ROCKSDB_LITE\n",
        stderr);
  return 0;
}

#endif  // !ROCKSDB_LITE