// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. // This source code is licensed under both the GPLv2 (found in the // COPYING file in the root directory) and Apache 2.0 License // (found in the LICENSE.Apache file in the root directory). // #include <assert.h> #include <iostream> #include <memory> #include "db/db_impl/db_impl.h" #include "db/dbformat.h" #include "db/write_batch_internal.h" #include "port/stack_trace.h" #include "rocksdb/cache.h" #include "rocksdb/comparator.h" #include "rocksdb/db.h" #include "rocksdb/env.h" #include "rocksdb/merge_operator.h" #include "rocksdb/utilities/db_ttl.h" #include "test_util/testharness.h" #include "util/coding.h" #include "utilities/merge_operators.h" namespace ROCKSDB_NAMESPACE { bool use_compression; class MergeTest : public testing::Test {}; size_t num_merge_operator_calls; void resetNumMergeOperatorCalls() { num_merge_operator_calls = 0; } size_t num_partial_merge_calls; void resetNumPartialMergeCalls() { num_partial_merge_calls = 0; } class CountMergeOperator : public AssociativeMergeOperator { public: CountMergeOperator() { mergeOperator_ = MergeOperators::CreateUInt64AddOperator(); } bool Merge(const Slice& key, const Slice* existing_value, const Slice& value, std::string* new_value, Logger* logger) const override { assert(new_value->empty()); ++num_merge_operator_calls; if (existing_value == nullptr) { new_value->assign(value.data(), value.size()); return true; } return mergeOperator_->PartialMerge( key, *existing_value, value, new_value, logger); } bool PartialMergeMulti(const Slice& key, const std::deque<Slice>& operand_list, std::string* new_value, Logger* logger) const override { assert(new_value->empty()); ++num_partial_merge_calls; return mergeOperator_->PartialMergeMulti(key, operand_list, new_value, logger); } const char* Name() const override { return "UInt64AddOperator"; } private: std::shared_ptr<MergeOperator> mergeOperator_; }; std::shared_ptr<DB> OpenDb(const std::string& dbname, const bool ttl = false, const size_t max_successive_merges = 0) { DB* db; Options options; options.create_if_missing = true; options.merge_operator = std::make_shared<CountMergeOperator>(); options.max_successive_merges = max_successive_merges; EXPECT_OK(DestroyDB(dbname, Options())); Status s; // DBWithTTL is not supported in ROCKSDB_LITE #ifndef ROCKSDB_LITE if (ttl) { DBWithTTL* db_with_ttl; s = DBWithTTL::Open(options, dbname, &db_with_ttl); db = db_with_ttl; } else { s = DB::Open(options, dbname, &db); } #else assert(!ttl); s = DB::Open(options, dbname, &db); #endif // !ROCKSDB_LITE EXPECT_OK(s); assert(s.ok()); return std::shared_ptr<DB>(db); } // Imagine we are maintaining a set of uint64 counters. // Each counter has a distinct name. And we would like // to support four high level operations: // set, add, get and remove // This is a quick implementation without a Merge operation. class Counters { protected: std::shared_ptr<DB> db_; WriteOptions put_option_; ReadOptions get_option_; WriteOptions delete_option_; uint64_t default_; public: explicit Counters(std::shared_ptr<DB> db, uint64_t defaultCount = 0) : db_(db), put_option_(), get_option_(), delete_option_(), default_(defaultCount) { assert(db_); } virtual ~Counters() {} // public interface of Counters. // All four functions return false // if the underlying level db operation failed. // mapped to a levedb Put bool set(const std::string& key, uint64_t value) { // just treat the internal rep of int64 as the string char buf[sizeof(value)]; EncodeFixed64(buf, value); Slice slice(buf, sizeof(value)); auto s = db_->Put(put_option_, key, slice); if (s.ok()) { return true; } else { std::cerr << s.ToString() << std::endl; return false; } } // mapped to a rocksdb Delete bool remove(const std::string& key) { auto s = db_->Delete(delete_option_, key); if (s.ok()) { return true; } else { std::cerr << s.ToString() << std::endl; return false; } } // mapped to a rocksdb Get bool get(const std::string& key, uint64_t* value) { std::string str; auto s = db_->Get(get_option_, key, &str); if (s.IsNotFound()) { // return default value if not found; *value = default_; return true; } else if (s.ok()) { // deserialization if (str.size() != sizeof(uint64_t)) { std::cerr << "value corruption\n"; return false; } *value = DecodeFixed64(&str[0]); return true; } else { std::cerr << s.ToString() << std::endl; return false; } } // 'add' is implemented as get -> modify -> set // An alternative is a single merge operation, see MergeBasedCounters virtual bool add(const std::string& key, uint64_t value) { uint64_t base = default_; return get(key, &base) && set(key, base + value); } // convenience functions for testing void assert_set(const std::string& key, uint64_t value) { assert(set(key, value)); } void assert_remove(const std::string& key) { assert(remove(key)); } uint64_t assert_get(const std::string& key) { uint64_t value = default_; int result = get(key, &value); assert(result); if (result == 0) exit(1); // Disable unused variable warning. return value; } void assert_add(const std::string& key, uint64_t value) { int result = add(key, value); assert(result); if (result == 0) exit(1); // Disable unused variable warning. } }; // Implement 'add' directly with the new Merge operation class MergeBasedCounters : public Counters { private: WriteOptions merge_option_; // for merge public: explicit MergeBasedCounters(std::shared_ptr<DB> db, uint64_t defaultCount = 0) : Counters(db, defaultCount), merge_option_() { } // mapped to a rocksdb Merge operation bool add(const std::string& key, uint64_t value) override { char encoded[sizeof(uint64_t)]; EncodeFixed64(encoded, value); Slice slice(encoded, sizeof(uint64_t)); auto s = db_->Merge(merge_option_, key, slice); if (s.ok()) { return true; } else { std::cerr << s.ToString() << std::endl; return false; } } }; void dumpDb(DB* db) { auto it = std::unique_ptr<Iterator>(db->NewIterator(ReadOptions())); for (it->SeekToFirst(); it->Valid(); it->Next()) { //uint64_t value = DecodeFixed64(it->value().data()); //std::cout << it->key().ToString() << ": " << value << std::endl; } assert(it->status().ok()); // Check for any errors found during the scan } void testCounters(Counters& counters, DB* db, bool test_compaction) { FlushOptions o; o.wait = true; counters.assert_set("a", 1); if (test_compaction) { ASSERT_OK(db->Flush(o)); } ASSERT_EQ(counters.assert_get("a"), 1); counters.assert_remove("b"); // defaut value is 0 if non-existent ASSERT_EQ(counters.assert_get("b"), 0); counters.assert_add("a", 2); if (test_compaction) { ASSERT_OK(db->Flush(o)); } // 1+2 = 3 ASSERT_EQ(counters.assert_get("a"), 3); dumpDb(db); // 1+...+49 = ? uint64_t sum = 0; for (int i = 1; i < 50; i++) { counters.assert_add("b", i); sum += i; } ASSERT_EQ(counters.assert_get("b"), sum); dumpDb(db); if (test_compaction) { ASSERT_OK(db->Flush(o)); ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr)); dumpDb(db); ASSERT_EQ(counters.assert_get("a"), 3); ASSERT_EQ(counters.assert_get("b"), sum); } } void testCountersWithFlushAndCompaction(Counters& counters, DB* db) { ASSERT_OK(db->Put({}, "1", "1")); ASSERT_OK(db->Flush(FlushOptions())); std::atomic<int> cnt{0}; const auto get_thread_id = [&cnt]() { thread_local int thread_id{cnt++}; return thread_id; }; SyncPoint::GetInstance()->DisableProcessing(); SyncPoint::GetInstance()->ClearAllCallBacks(); SyncPoint::GetInstance()->SetCallBack( "VersionSet::LogAndApply:BeforeWriterWaiting", [&](void* /*arg*/) { int thread_id = get_thread_id(); if (1 == thread_id) { TEST_SYNC_POINT( "testCountersWithFlushAndCompaction::bg_compact_thread:0"); } else if (2 == thread_id) { TEST_SYNC_POINT( "testCountersWithFlushAndCompaction::bg_flush_thread:0"); } }); SyncPoint::GetInstance()->SetCallBack( "VersionSet::LogAndApply:WriteManifest", [&](void* /*arg*/) { int thread_id = get_thread_id(); if (0 == thread_id) { TEST_SYNC_POINT( "testCountersWithFlushAndCompaction::set_options_thread:0"); TEST_SYNC_POINT( "testCountersWithFlushAndCompaction::set_options_thread:1"); } }); SyncPoint::GetInstance()->SetCallBack( "VersionSet::LogAndApply:WakeUpAndDone", [&](void* arg) { auto* mutex = reinterpret_cast<InstrumentedMutex*>(arg); mutex->AssertHeld(); int thread_id = get_thread_id(); ASSERT_EQ(2, thread_id); mutex->Unlock(); TEST_SYNC_POINT( "testCountersWithFlushAndCompaction::bg_flush_thread:1"); TEST_SYNC_POINT( "testCountersWithFlushAndCompaction::bg_flush_thread:2"); mutex->Lock(); }); SyncPoint::GetInstance()->LoadDependency({ {"testCountersWithFlushAndCompaction::set_options_thread:0", "testCountersWithCompactionAndFlush:BeforeCompact"}, {"testCountersWithFlushAndCompaction::bg_compact_thread:0", "testCountersWithFlushAndCompaction:BeforeIncCounters"}, {"testCountersWithFlushAndCompaction::bg_flush_thread:0", "testCountersWithFlushAndCompaction::set_options_thread:1"}, {"testCountersWithFlushAndCompaction::bg_flush_thread:1", "testCountersWithFlushAndCompaction:BeforeVerification"}, {"testCountersWithFlushAndCompaction:AfterGet", "testCountersWithFlushAndCompaction::bg_flush_thread:2"}, }); SyncPoint::GetInstance()->EnableProcessing(); port::Thread set_options_thread([&]() { ASSERT_OK(reinterpret_cast<DBImpl*>(db)->SetOptions( {{"disable_auto_compactions", "false"}})); }); TEST_SYNC_POINT("testCountersWithCompactionAndFlush:BeforeCompact"); port::Thread compact_thread([&]() { ASSERT_OK(reinterpret_cast<DBImpl*>(db)->CompactRange( CompactRangeOptions(), db->DefaultColumnFamily(), nullptr, nullptr)); }); TEST_SYNC_POINT("testCountersWithFlushAndCompaction:BeforeIncCounters"); counters.add("test-key", 1); FlushOptions flush_opts; flush_opts.wait = false; ASSERT_OK(db->Flush(flush_opts)); TEST_SYNC_POINT("testCountersWithFlushAndCompaction:BeforeVerification"); std::string expected; PutFixed64(&expected, 1); std::string actual; Status s = db->Get(ReadOptions(), "test-key", &actual); TEST_SYNC_POINT("testCountersWithFlushAndCompaction:AfterGet"); set_options_thread.join(); compact_thread.join(); ASSERT_OK(s); ASSERT_EQ(expected, actual); SyncPoint::GetInstance()->DisableProcessing(); SyncPoint::GetInstance()->ClearAllCallBacks(); } void testSuccessiveMerge(Counters& counters, size_t max_num_merges, size_t num_merges) { counters.assert_remove("z"); uint64_t sum = 0; for (size_t i = 1; i <= num_merges; ++i) { resetNumMergeOperatorCalls(); counters.assert_add("z", i); sum += i; if (i % (max_num_merges + 1) == 0) { ASSERT_EQ(num_merge_operator_calls, max_num_merges + 1); } else { ASSERT_EQ(num_merge_operator_calls, 0); } resetNumMergeOperatorCalls(); ASSERT_EQ(counters.assert_get("z"), sum); ASSERT_EQ(num_merge_operator_calls, i % (max_num_merges + 1)); } } void testPartialMerge(Counters* counters, DB* db, size_t max_merge, size_t min_merge, size_t count) { FlushOptions o; o.wait = true; // Test case 1: partial merge should be called when the number of merge // operands exceeds the threshold. uint64_t tmp_sum = 0; resetNumPartialMergeCalls(); for (size_t i = 1; i <= count; i++) { counters->assert_add("b", i); tmp_sum += i; } ASSERT_OK(db->Flush(o)); ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr)); ASSERT_EQ(tmp_sum, counters->assert_get("b")); if (count > max_merge) { // in this case, FullMerge should be called instead. ASSERT_EQ(num_partial_merge_calls, 0U); } else { // if count >= min_merge, then partial merge should be called once. ASSERT_EQ((count >= min_merge), (num_partial_merge_calls == 1)); } // Test case 2: partial merge should not be called when a put is found. resetNumPartialMergeCalls(); tmp_sum = 0; ASSERT_OK(db->Put(ROCKSDB_NAMESPACE::WriteOptions(), "c", "10")); for (size_t i = 1; i <= count; i++) { counters->assert_add("c", i); tmp_sum += i; } ASSERT_OK(db->Flush(o)); ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr)); ASSERT_EQ(tmp_sum, counters->assert_get("c")); ASSERT_EQ(num_partial_merge_calls, 0U); } void testSingleBatchSuccessiveMerge(DB* db, size_t max_num_merges, size_t num_merges) { ASSERT_GT(num_merges, max_num_merges); Slice key("BatchSuccessiveMerge"); uint64_t merge_value = 1; char buf[sizeof(merge_value)]; EncodeFixed64(buf, merge_value); Slice merge_value_slice(buf, sizeof(merge_value)); // Create the batch WriteBatch batch; for (size_t i = 0; i < num_merges; ++i) { ASSERT_OK(batch.Merge(key, merge_value_slice)); } // Apply to memtable and count the number of merges resetNumMergeOperatorCalls(); ASSERT_OK(db->Write(WriteOptions(), &batch)); ASSERT_EQ( num_merge_operator_calls, static_cast<size_t>(num_merges - (num_merges % (max_num_merges + 1)))); // Get the value resetNumMergeOperatorCalls(); std::string get_value_str; ASSERT_OK(db->Get(ReadOptions(), key, &get_value_str)); assert(get_value_str.size() == sizeof(uint64_t)); uint64_t get_value = DecodeFixed64(&get_value_str[0]); ASSERT_EQ(get_value, num_merges * merge_value); ASSERT_EQ(num_merge_operator_calls, static_cast<size_t>((num_merges % (max_num_merges + 1)))); } void runTest(const std::string& dbname, const bool use_ttl = false) { { auto db = OpenDb(dbname, use_ttl); { Counters counters(db, 0); testCounters(counters, db.get(), true); } { MergeBasedCounters counters(db, 0); testCounters(counters, db.get(), use_compression); } } ASSERT_OK(DestroyDB(dbname, Options())); { size_t max_merge = 5; auto db = OpenDb(dbname, use_ttl, max_merge); MergeBasedCounters counters(db, 0); testCounters(counters, db.get(), use_compression); testSuccessiveMerge(counters, max_merge, max_merge * 2); testSingleBatchSuccessiveMerge(db.get(), 5, 7); ASSERT_OK(db->Close()); ASSERT_OK(DestroyDB(dbname, Options())); } { size_t max_merge = 100; // Min merge is hard-coded to 2. uint32_t min_merge = 2; for (uint32_t count = min_merge - 1; count <= min_merge + 1; count++) { auto db = OpenDb(dbname, use_ttl, max_merge); MergeBasedCounters counters(db, 0); testPartialMerge(&counters, db.get(), max_merge, min_merge, count); ASSERT_OK(db->Close()); ASSERT_OK(DestroyDB(dbname, Options())); } { auto db = OpenDb(dbname, use_ttl, max_merge); MergeBasedCounters counters(db, 0); testPartialMerge(&counters, db.get(), max_merge, min_merge, min_merge * 10); ASSERT_OK(db->Close()); ASSERT_OK(DestroyDB(dbname, Options())); } } { { auto db = OpenDb(dbname); MergeBasedCounters counters(db, 0); counters.add("test-key", 1); counters.add("test-key", 1); counters.add("test-key", 1); ASSERT_OK(db->CompactRange(CompactRangeOptions(), nullptr, nullptr)); } DB* reopen_db; ASSERT_OK(DB::Open(Options(), dbname, &reopen_db)); std::string value; ASSERT_NOK(reopen_db->Get(ReadOptions(), "test-key", &value)); delete reopen_db; ASSERT_OK(DestroyDB(dbname, Options())); } /* Temporary remove this test { std::cout << "Test merge-operator not set after reopen (recovery case)\n"; { auto db = OpenDb(dbname); MergeBasedCounters counters(db, 0); counters.add("test-key", 1); counters.add("test-key", 1); counters.add("test-key", 1); } DB* reopen_db; ASSERT_TRUE(DB::Open(Options(), dbname, &reopen_db).IsInvalidArgument()); } */ } TEST_F(MergeTest, MergeDbTest) { runTest(test::PerThreadDBPath("merge_testdb")); } #ifndef ROCKSDB_LITE TEST_F(MergeTest, MergeDbTtlTest) { runTest(test::PerThreadDBPath("merge_testdbttl"), true); // Run test on TTL database } TEST_F(MergeTest, MergeWithCompactionAndFlush) { const std::string dbname = test::PerThreadDBPath("merge_with_compaction_and_flush"); { auto db = OpenDb(dbname); { MergeBasedCounters counters(db, 0); testCountersWithFlushAndCompaction(counters, db.get()); } } ASSERT_OK(DestroyDB(dbname, Options())); } #endif // !ROCKSDB_LITE } // namespace ROCKSDB_NAMESPACE int main(int argc, char** argv) { ROCKSDB_NAMESPACE::use_compression = false; if (argc > 1) { ROCKSDB_NAMESPACE::use_compression = true; } ROCKSDB_NAMESPACE::port::InstallStackTraceHandler(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); }