From 4f9c0fd083d00bacd86b723f7789658ccfd3f108 Mon Sep 17 00:00:00 2001 From: sdong Date: Fri, 15 Apr 2022 23:24:05 -0700 Subject: [PATCH] Add Aggregation Merge Operator (#9780) Summary: Add a merge operator that allows users to register specific aggregation function so that they can does aggregation based per key using different aggregation types. See comments of function CreateAggMergeOperator() for actual usage. Pull Request resolved: https://github.com/facebook/rocksdb/pull/9780 Test Plan: Add a unit test to coverage various cases. Reviewed By: ltamasi Differential Revision: D35267444 fbshipit-source-id: 5b02f31c4f3e17e96dd4025cdc49fca8c2868628 --- CMakeLists.txt | 3 + HISTORY.md | 1 + Makefile | 3 + TARGETS | 9 + include/rocksdb/db.h | 2 + include/rocksdb/utilities/agg_merge.h | 138 +++++++++++++++ src.mk | 3 + utilities/agg_merge/agg_merge.cc | 238 ++++++++++++++++++++++++++ utilities/agg_merge/agg_merge.h | 49 ++++++ utilities/agg_merge/agg_merge_test.cc | 134 +++++++++++++++ utilities/agg_merge/test_agg_merge.cc | 104 +++++++++++ utilities/agg_merge/test_agg_merge.h | 47 +++++ 12 files changed, 731 insertions(+) create mode 100644 include/rocksdb/utilities/agg_merge.h create mode 100644 utilities/agg_merge/agg_merge.cc create mode 100644 utilities/agg_merge/agg_merge.h create mode 100644 utilities/agg_merge/agg_merge_test.cc create mode 100644 utilities/agg_merge/test_agg_merge.cc create mode 100644 utilities/agg_merge/test_agg_merge.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1923d1707..0d400462e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -816,6 +816,7 @@ set(SOURCES util/thread_local.cc util/threadpool_imp.cc util/xxhash.cc + utilities/agg_merge/agg_merge.cc utilities/backup/backup_engine.cc utilities/blob_db/blob_compaction_filter.cc utilities/blob_db/blob_db.cc @@ -1335,6 +1336,7 @@ if(WITH_TESTS) util/thread_list_test.cc util/thread_local_test.cc util/work_queue_test.cc + utilities/agg_merge/agg_merge_test.cc utilities/backup/backup_engine_test.cc utilities/blob_db/blob_db_test.cc utilities/cassandra/cassandra_functional_test.cc @@ -1370,6 +1372,7 @@ if(WITH_TESTS) db/db_test_util.cc monitoring/thread_status_updater_debug.cc table/mock_table.cc + utilities/agg_merge/test_agg_merge.cc utilities/cassandra/test_utils.cc ) enable_testing() diff --git a/HISTORY.md b/HISTORY.md index 28a18bbd8..b294d5478 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,7 @@ * Add event listener support on remote compaction compactor side. * Added a dedicated integer DB property `rocksdb.live-blob-file-garbage-size` that exposes the total amount of garbage in the blob files in the current version. * RocksDB does internal auto prefetching if it notices sequential reads. It starts with readahead size `initial_auto_readahead_size` which now can be configured through BlockBasedTableOptions. +* Add a merge operator that allows users to register specific aggregation function so that they can does aggregation using different aggregation types for different keys. See comments in include/rocksdb/utilities/agg_merge.h for actual usage. The feature is experimental and the format is subject to change and we won't provide a migration tool. ### Behavior changes * Disallow usage of commit-time-write-batch for write-prepared/write-unprepared transactions if TransactionOptions::use_only_the_last_commit_time_batch_for_recovery is false to prevent two (or more) uncommitted versions of the same key in the database. Otherwise, bottommost compaction may violate the internal key uniqueness invariant of SSTs if the sequence numbers of both internal keys are zeroed out (#9794). diff --git a/Makefile b/Makefile index b81e5cf58..41b8a98d6 100644 --- a/Makefile +++ b/Makefile @@ -1380,6 +1380,9 @@ ribbon_test: $(OBJ_DIR)/util/ribbon_test.o $(TEST_LIBRARY) $(LIBRARY) option_change_migration_test: $(OBJ_DIR)/utilities/option_change_migration/option_change_migration_test.o $(TEST_LIBRARY) $(LIBRARY) $(AM_LINK) +agg_merge_test: $(OBJ_DIR)/utilities/agg_merge/agg_merge_test.o $(TEST_LIBRARY) $(LIBRARY) + $(AM_LINK) + stringappend_test: $(OBJ_DIR)/utilities/merge_operators/string_append/stringappend_test.o $(TEST_LIBRARY) $(LIBRARY) $(AM_LINK) diff --git a/TARGETS b/TARGETS index 1f2bd2bfb..f6403accd 100644 --- a/TARGETS +++ b/TARGETS @@ -245,6 +245,7 @@ cpp_library_wrapper(name="rocksdb_lib", srcs=[ "util/thread_local.cc", "util/threadpool_imp.cc", "util/xxhash.cc", + "utilities/agg_merge/agg_merge.cc", "utilities/backup/backup_engine.cc", "utilities/blob_db/blob_compaction_filter.cc", "utilities/blob_db/blob_db.cc", @@ -563,6 +564,7 @@ cpp_library_wrapper(name="rocksdb_whole_archive_lib", srcs=[ "util/thread_local.cc", "util/threadpool_imp.cc", "util/xxhash.cc", + "utilities/agg_merge/agg_merge.cc", "utilities/backup/backup_engine.cc", "utilities/blob_db/blob_compaction_filter.cc", "utilities/blob_db/blob_db.cc", @@ -652,6 +654,7 @@ cpp_library_wrapper(name="rocksdb_test_lib", srcs=[ "test_util/testutil.cc", "tools/block_cache_analyzer/block_cache_trace_analyzer.cc", "tools/trace_analyzer_tool.cc", + "utilities/agg_merge/test_agg_merge.cc", "utilities/cassandra/test_utils.cc", ], deps=[":rocksdb_lib"], headers=None, link_whole=False, extra_test_libs=True) @@ -4698,6 +4701,12 @@ fancy_bench_wrapper(suite_name="rocksdb_microbench_suite_14_slow", binary_to_ben # Do not build the tests in opt mode, since SyncPoint and other test code # will not be included. +cpp_unittest_wrapper(name="agg_merge_test", + srcs=["utilities/agg_merge/agg_merge_test.cc"], + deps=[":rocksdb_test_lib"], + extra_compiler_flags=[]) + + cpp_unittest_wrapper(name="arena_test", srcs=["memory/arena_test.cc"], deps=[":rocksdb_test_lib"], diff --git a/include/rocksdb/db.h b/include/rocksdb/db.h index 4d8e80013..74754da5a 100644 --- a/include/rocksdb/db.h +++ b/include/rocksdb/db.h @@ -10,11 +10,13 @@ #include #include + #include #include #include #include #include + #include "rocksdb/iterator.h" #include "rocksdb/listener.h" #include "rocksdb/metadata.h" diff --git a/include/rocksdb/utilities/agg_merge.h b/include/rocksdb/utilities/agg_merge.h new file mode 100644 index 000000000..4e21082db --- /dev/null +++ b/include/rocksdb/utilities/agg_merge.h @@ -0,0 +1,138 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include +#include + +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" + +namespace ROCKSDB_NAMESPACE { +// The feature is still in development so the encoding format is subject +// to change. +// +// Aggregation Merge Operator is a merge operator that allows users to +// aggregate merge operands of different keys with different registered +// aggregation functions. The aggregation can also change for the same +// key if the functions store the data in the same format. +// The target application highly overlaps with merge operator in general +// but we try to provide a better interface so that users are more likely +// to use pre-implemented plug-in functions and connect with existing +// third-party aggregation functions (such as those from SQL engines). +// In this case, the need for users to write customized C++ plug-in code +// is reduced. +// If the idea proves to useful, we might consider to move it to be +// a core functionality of RocksDB, and reduce the support of merge +// operators. +// +// Users can implement aggregation functions by implementing abstract +// class Aggregator, and register it using AddAggregator(). +// The merge operator can be retrieved from GetAggMergeOperator() and +// it is a singleton. +// +// Users can push values to be updated with a merge operand encoded with +// registered function name and payload using EncodeAggFuncAndPayload(), +// and the merge operator will invoke the aggregation function. +// An example: +// +// // Assume class ExampleSumAggregator is implemented to do simple sum. +// AddAggregator("sum", std::make_unique()); +// std::shared_ptr mp_guard = CreateAggMergeOperator(); +// options.merge_operator = mp_guard.get(); +// ...... // Creating DB +// +// +// std::string encoded_value; +// s = EncodeAggFuncAndPayload(kUnamedFuncName, "200", encoded_value); +// assert(s.ok()); +// db->Put(WriteOptions(), "foo", encoded_value); +// s = EncodeAggFuncAndPayload("sum", "200", encoded_value); +// assert(s.ok()); +// db->Merge(WriteOptions(), "foo", encoded_value); +// s = EncodeAggFuncAndPayload("sum", "200", encoded_value); +// assert(s.ok()); +// db->Merge(WriteOptions(), "foo", encoded_value); +// +// std::string value; +// Status s = db->Get(ReadOptions, "foo", &value); +// assert(s.ok()); +// Slice func, aggregated_value; +// assert(ExtractAggFuncAndValue(value, func, aggregated_value)); +// assert(func == "sum"); +// assert(aggregated_value == "600"); +// +// +// DB::Put() can also be used to add a payloadin the same way as Merge(). +// +// kUnamedFuncName can be used as a placeholder function name. This will +// be aggregated with merge operands inserted later based on function +// name given there. +// +// If the aggregation function is not registered or there is an error +// returned by aggregation function, the result will be encoded with a fake +// aggregation function kErrorFuncName, with each merge operands to be encoded +// into a list that can be extracted using ExtractList(); +// +// If users add a merge operand using a different aggregation function from +// the previous one, the merge operands for the previous one is aggregated +// and the payload part of the result is treated as the first payload of +// the items for the new aggregation function. For example, users can +// Merge("plus, 1"), merge("plus 2"), merge("minus 3") and the aggregation +// result would be "minus 0". +// + +// A class used to aggregate data per key/value. The plug-in function is +// implemented and registered using AddAggregator(). And then use it +// with merge operator created using CreateAggMergeOperator(). +class Aggregator { + public: + virtual ~Aggregator() {} + // The input list is in reverse insertion order, with values[0] to be + // the one inserted last and values.back() to be the one inserted first. + // The oldest one might be from Get(). + // Return whether aggregation succeeded. False for aggregation error. + virtual bool Aggregate(const std::vector& values, + std::string& result) const = 0; + + // True if a partial aggregation should be invoked. Some aggregators + // might opt to skip partial aggregation if possible. + virtual bool DoPartialAggregate() const { return true; } +}; + +// The function adds aggregation plugin by function name. It is used +// by all the aggregation operator created using CreateAggMergeOperator(). +// It's currently not thread safe to run concurrently with the aggregation +// merge operator. It is recommended that all the aggregation function +// is added before calling CreateAggMergeOperator(). +Status AddAggregator(const std::string& function_name, + std::unique_ptr&& agg); + +// Get the singleton instance of merge operator for aggregation. +// Always the same one is returned with a shared_ptr is hold as a +// static variable by the function. +// This is done so because options.merge_operator is shared_ptr. +std::shared_ptr GetAggMergeOperator(); + +// Encode aggregation function and payload that can be consumed by aggregation +// merge operator. +Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload, + std::string& output); +// Helper function to extract aggregation function name and payload. +// Return false if it fails to decode. +bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value); + +// Extract encoded list. This can be used to extract error merge operands when +// the returned function name is kErrorFuncName. +bool ExtractList(const Slice& encoded_list, std::vector& decoded_list); + +// Special function name that allows it to be merged to subsequent type. +extern const std::string kUnnamedFuncName; + +// Special error function name reserved for merging or aggregation error. +extern const std::string kErrorFuncName; + +} // namespace ROCKSDB_NAMESPACE diff --git a/src.mk b/src.mk index 9e1fa5e0f..72c4d5f54 100644 --- a/src.mk +++ b/src.mk @@ -232,6 +232,7 @@ LIB_SOURCES = \ util/thread_local.cc \ util/threadpool_imp.cc \ util/xxhash.cc \ + utilities/agg_merge/agg_merge.cc \ utilities/backup/backup_engine.cc \ utilities/blob_db/blob_compaction_filter.cc \ utilities/blob_db/blob_db.cc \ @@ -364,6 +365,7 @@ TEST_LIB_SOURCES = \ test_util/mock_time_env.cc \ test_util/testharness.cc \ test_util/testutil.cc \ + utilities/agg_merge/test_agg_merge.cc \ utilities/cassandra/test_utils.cc \ FOLLY_SOURCES = \ @@ -559,6 +561,7 @@ TEST_MAIN_SOURCES = \ util/thread_list_test.cc \ util/thread_local_test.cc \ util/work_queue_test.cc \ + utilities/agg_merge/agg_merge_test.cc \ utilities/backup/backup_engine_test.cc \ utilities/blob_db/blob_db_test.cc \ utilities/cassandra/cassandra_format_test.cc \ diff --git a/utilities/agg_merge/agg_merge.cc b/utilities/agg_merge/agg_merge.cc new file mode 100644 index 000000000..a7eab1f12 --- /dev/null +++ b/utilities/agg_merge/agg_merge.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/agg_merge/agg_merge.h" + +#include + +#include +#include +#include +#include +#include + +#include "port/lang.h" +#include "port/likely.h" +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" +#include "rocksdb/utilities/agg_merge.h" +#include "rocksdb/utilities/options_type.h" +#include "util/coding.h" +#include "utilities/merge_operators.h" + +namespace ROCKSDB_NAMESPACE { +static std::unordered_map> func_map; +const std::string kUnnamedFuncName = ""; +const std::string kErrorFuncName = "kErrorFuncName"; + +Status AddAggregator(const std::string& function_name, + std::unique_ptr&& agg) { + if (function_name == kErrorFuncName) { + return Status::InvalidArgument( + "Cannot register function name kErrorFuncName"); + } + func_map.emplace(function_name, std::move(agg)); + return Status::OK(); +} + +AggMergeOperator::AggMergeOperator() {} + +std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name, + const Slice& value) { + std::string result; + PutLengthPrefixedSlice(&result, function_name); + result += value.ToString(); + return result; +} + +Status EncodeAggFuncAndPayload(const Slice& function_name, const Slice& payload, + std::string& output) { + if (function_name == kErrorFuncName) { + return Status::InvalidArgument("Cannot use error function name"); + } + if (function_name != kUnnamedFuncName && + func_map.find(function_name.ToString()) == func_map.end()) { + return Status::InvalidArgument("Function name not registered"); + } + output = EncodeAggFuncAndPayloadNoCheck(function_name, payload); + return Status::OK(); +} + +bool ExtractAggFuncAndValue(const Slice& op, Slice& func, Slice& value) { + value = op; + return GetLengthPrefixedSlice(&value, &func); +} + +bool ExtractList(const Slice& encoded_list, std::vector& decoded_list) { + decoded_list.clear(); + Slice list_slice = encoded_list; + Slice item; + while (GetLengthPrefixedSlice(&list_slice, &item)) { + decoded_list.push_back(item); + } + return list_slice.empty(); +} + +class AggMergeOperator::Accumulator { + public: + bool Add(const Slice& op, bool is_partial_aggregation) { + if (ignore_operands_) { + return true; + } + Slice my_func; + Slice my_value; + bool ret = ExtractAggFuncAndValue(op, my_func, my_value); + if (!ret) { + ignore_operands_ = true; + return true; + } + + // Determine whether we need to do partial merge. + if (is_partial_aggregation && !my_func.empty()) { + auto f = func_map.find(my_func.ToString()); + if (f == func_map.end() || !f->second->DoPartialAggregate()) { + return false; + } + } + + if (!func_valid_) { + if (my_func != kUnnamedFuncName) { + func_ = my_func; + func_valid_ = true; + } + } else if (func_ != my_func) { + // User switched aggregation function. Need to aggregate the older + // one first. + + // Previous aggreagion can't be done in partial merge + if (is_partial_aggregation) { + func_valid_ = false; + ignore_operands_ = true; + return false; + } + + // We could consider stashing an iterator into the hash of aggregators + // to avoid repeated lookups when the aggregator doesn't change. + auto f = func_map.find(func_.ToString()); + if (f == func_map.end() || !f->second->Aggregate(values_, scratch_)) { + func_valid_ = false; + ignore_operands_ = true; + return true; + } + std::swap(scratch_, aggregated_); + values_.clear(); + values_.push_back(aggregated_); + func_ = my_func; + } + values_.push_back(my_value); + return true; + } + + // Return false if aggregation fails. + // One possible reason + bool GetResult(std::string& result) { + if (!func_valid_) { + return false; + } + auto f = func_map.find(func_.ToString()); + if (f == func_map.end()) { + return false; + } + if (!f->second->Aggregate(values_, scratch_)) { + return false; + } + result = EncodeAggFuncAndPayloadNoCheck(func_, scratch_); + return true; + } + + void Clear() { + func_.clear(); + values_.clear(); + aggregated_.clear(); + scratch_.clear(); + ignore_operands_ = false; + func_valid_ = false; + } + + private: + Slice func_; + std::vector values_; + std::string aggregated_; + std::string scratch_; + bool ignore_operands_ = false; + bool func_valid_ = false; +}; + +// Creating and using a new Accumulator might invoke multiple malloc and is +// expensive if it needs to be done when processing each merge operation. +// AggMergeOperator's merge operators can be invoked concurrently by multiple +// threads so we cannot simply create one Aggregator and reuse. +// We use thread local instances instead. +AggMergeOperator::Accumulator& AggMergeOperator::GetTLSAccumulator() { + static thread_local Accumulator tls_acc; + tls_acc.Clear(); + return tls_acc; +} + +void AggMergeOperator::PackAllMergeOperands(const MergeOperationInput& merge_in, + MergeOperationOutput& merge_out) { + merge_out.new_value = ""; + PutLengthPrefixedSlice(&merge_out.new_value, kErrorFuncName); + if (merge_in.existing_value != nullptr) { + PutLengthPrefixedSlice(&merge_out.new_value, *merge_in.existing_value); + } + for (const Slice& op : merge_in.operand_list) { + PutLengthPrefixedSlice(&merge_out.new_value, op); + } +} + +bool AggMergeOperator::FullMergeV2(const MergeOperationInput& merge_in, + MergeOperationOutput* merge_out) const { + Accumulator& agg = GetTLSAccumulator(); + if (merge_in.existing_value != nullptr) { + agg.Add(*merge_in.existing_value, /*is_partial_aggregation=*/false); + } + for (const Slice& e : merge_in.operand_list) { + agg.Add(e, /*is_partial_aggregation=*/false); + } + + bool succ = agg.GetResult(merge_out->new_value); + if (!succ) { + // If aggregation can't happen, pack all merge operands. In contrast to + // merge operator, we don't want to fail the DB. If users insert wrong + // format or call unregistered an aggregation function, we still hope + // the DB can continue functioning with other keys. + PackAllMergeOperands(merge_in, *merge_out); + } + agg.Clear(); + return true; +} + +bool AggMergeOperator::PartialMergeMulti(const Slice& /*key*/, + const std::deque& operand_list, + std::string* new_value, + Logger* /*logger*/) const { + Accumulator& agg = GetTLSAccumulator(); + bool do_aggregation = true; + for (const Slice& item : operand_list) { + do_aggregation = agg.Add(item, /*is_partial_aggregation=*/true); + if (!do_aggregation) { + break; + } + } + if (do_aggregation) { + do_aggregation = agg.GetResult(*new_value); + } + agg.Clear(); + return do_aggregation; +} + +std::shared_ptr GetAggMergeOperator() { + STATIC_AVOID_DESTRUCTION(std::shared_ptr, instance) + (std::make_shared()); + assert(instance); + return instance; +} +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/agg_merge/agg_merge.h b/utilities/agg_merge/agg_merge.h new file mode 100644 index 000000000..00e58de08 --- /dev/null +++ b/utilities/agg_merge/agg_merge.h @@ -0,0 +1,49 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once +#include +#include +#include +#include + +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" +#include "rocksdb/utilities/agg_merge.h" +#include "utilities/cassandra/cassandra_options.h" + +namespace ROCKSDB_NAMESPACE { +class AggMergeOperator : public MergeOperator { + public: + explicit AggMergeOperator(); + + bool FullMergeV2(const MergeOperationInput& merge_in, + MergeOperationOutput* merge_out) const override; + + bool PartialMergeMulti(const Slice& key, + const std::deque& operand_list, + std::string* new_value, Logger* logger) const override; + + const char* Name() const override { return kClassName(); } + static const char* kClassName() { return "AggMergeOperator.v1"; } + + bool AllowSingleOperand() const override { return true; } + + bool ShouldMerge(const std::vector&) const override { return false; } + + private: + class Accumulator; + + // Pack all merge operands into one value. This is called when aggregation + // fails. The existing values are preserved and returned so that users can + // debug the problem. + static void PackAllMergeOperands(const MergeOperationInput& merge_in, + MergeOperationOutput& merge_out); + static Accumulator& GetTLSAccumulator(); +}; + +extern std::string EncodeAggFuncAndPayloadNoCheck(const Slice& function_name, + const Slice& value); +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/agg_merge/agg_merge_test.cc b/utilities/agg_merge/agg_merge_test.cc new file mode 100644 index 000000000..6502daa36 --- /dev/null +++ b/utilities/agg_merge/agg_merge_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "rocksdb/utilities/agg_merge.h" + +#include + +#include + +#include "db/db_test_util.h" +#include "rocksdb/options.h" +#include "test_util/testharness.h" +#include "utilities/agg_merge/agg_merge.h" +#include "utilities/agg_merge/test_agg_merge.h" + +namespace ROCKSDB_NAMESPACE { + +class AggMergeTest : public DBTestBase { + public: + AggMergeTest() : DBTestBase("agg_merge_db_test", /*env_do_fsync=*/true) {} +}; + +TEST_F(AggMergeTest, TestUsingMergeOperator) { + ASSERT_OK(AddAggregator("sum", std::make_unique())); + ASSERT_OK(AddAggregator("last3", std::make_unique())); + ASSERT_OK(AddAggregator("mul", std::make_unique())); + + Options options = CurrentOptions(); + options.merge_operator = GetAggMergeOperator(); + Reopen(options); + std::string v = EncodeHelper::EncodeFuncAndInt("sum", 10); + ASSERT_OK(Merge("foo", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 20); + ASSERT_OK(Merge("foo", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 15); + ASSERT_OK(Merge("foo", v)); + + v = EncodeHelper::EncodeFuncAndList("last3", {"a", "b"}); + ASSERT_OK(Merge("bar", v)); + v = EncodeHelper::EncodeFuncAndList("last3", {"c", "d", "e"}); + ASSERT_OK(Merge("bar", v)); + ASSERT_OK(Flush()); + v = EncodeHelper::EncodeFuncAndList("last3", {"f"}); + ASSERT_OK(Merge("bar", v)); + + // Test Put() without aggregation type. + v = EncodeHelper::EncodeFuncAndInt(kUnnamedFuncName, 30); + ASSERT_OK(Put("foo2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 10); + ASSERT_OK(Merge("foo2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 20); + ASSERT_OK(Merge("foo2", v)); + + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 45), Get("foo")); + EXPECT_EQ(EncodeHelper::EncodeFuncAndList("last3", {"f", "c", "d"}), + Get("bar")); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 60), Get("foo2")); + + // Test changing aggregation type + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Put("bar2", v)); + v = EncodeHelper::EncodeFuncAndInt("mul", 20); + ASSERT_OK(Merge("bar2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 30); + ASSERT_OK(Merge("bar2", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 40); + ASSERT_OK(Merge("bar2", v)); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 20 + 30 + 40), + Get("bar2")); + + // Changing aggregation type with partial merge + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Merge("foo3", v)); + ASSERT_OK(Flush()); + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Merge("foo3", v)); + v = EncodeHelper::EncodeFuncAndInt("mul", 10); + ASSERT_OK(Merge("foo3", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 10); + ASSERT_OK(Merge("foo3", v)); + ASSERT_OK(Flush()); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 10 * 10 * 10 + 10), + Get("foo3")); + + // Merge after full merge + v = EncodeHelper::EncodeFuncAndInt("sum", 1); + ASSERT_OK(Merge("foo4", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 2); + ASSERT_OK(Merge("foo4", v)); + ASSERT_OK(Flush()); + v = EncodeHelper::EncodeFuncAndInt("sum", 3); + ASSERT_OK(Merge("foo4", v)); + v = EncodeHelper::EncodeFuncAndInt("sum", 4); + ASSERT_OK(Merge("foo4", v)); + ASSERT_OK(Flush()); + ASSERT_OK(db_->CompactRange(CompactRangeOptions(), nullptr, nullptr)); + v = EncodeHelper::EncodeFuncAndInt("sum", 5); + ASSERT_OK(Merge("foo4", v)); + EXPECT_EQ(EncodeHelper::EncodeFuncAndInt("sum", 15), Get("foo4")); + + // Test unregistered function name + v = EncodeAggFuncAndPayloadNoCheck("non_existing", "1"); + ASSERT_OK(Merge("bar3", v)); + std::string v1; + v1 = EncodeAggFuncAndPayloadNoCheck("non_existing", "invalid"); + ; + ASSERT_OK(Merge("bar3", v1)); + EXPECT_EQ(EncodeAggFuncAndPayloadNoCheck(kErrorFuncName, + EncodeHelper::EncodeList({v, v1})), + Get("bar3")); + + // invalidate input + ASSERT_OK(EncodeAggFuncAndPayload("sum", "invalid", v)); + ASSERT_OK(Merge("bar4", v)); + v1 = EncodeHelper::EncodeFuncAndInt("sum", 20); + ASSERT_OK(Merge("bar4", v1)); + std::string aggregated_value = Get("bar4"); + Slice func, payload; + ASSERT_TRUE(ExtractAggFuncAndValue(aggregated_value, func, payload)); + EXPECT_EQ(kErrorFuncName, func); + std::vector decoded_list; + ASSERT_TRUE(ExtractList(payload, decoded_list)); + ASSERT_EQ(2, decoded_list.size()); + ASSERT_EQ(v, decoded_list[0]); + ASSERT_EQ(v1, decoded_list[1]); +} +} // namespace ROCKSDB_NAMESPACE + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/utilities/agg_merge/test_agg_merge.cc b/utilities/agg_merge/test_agg_merge.cc new file mode 100644 index 000000000..06e5b5697 --- /dev/null +++ b/utilities/agg_merge/test_agg_merge.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "test_agg_merge.h" + +#include + +#include +#include + +#include "util/coding.h" +#include "utilities/agg_merge/agg_merge.h" + +namespace ROCKSDB_NAMESPACE { + +std::string EncodeHelper::EncodeFuncAndInt(const Slice& function_name, + int64_t value) { + std::string encoded_value; + PutVarsignedint64(&encoded_value, value); + std::string ret; + Status s = EncodeAggFuncAndPayload(function_name, encoded_value, ret); + assert(s.ok()); + return ret; +} + +std::string EncodeHelper::EncodeInt(int64_t value) { + std::string encoded_value; + PutVarsignedint64(&encoded_value, value); + return encoded_value; +} + +std::string EncodeHelper::EncodeFuncAndList(const Slice& function_name, + const std::vector& list) { + std::string ret; + Status s = EncodeAggFuncAndPayload(function_name, EncodeList(list), ret); + assert(s.ok()); + return ret; +} + +std::string EncodeHelper::EncodeList(const std::vector& list) { + std::string result; + for (const Slice& entity : list) { + PutLengthPrefixedSlice(&result, entity); + } + return result; +} + +bool SumAggregator::Aggregate(const std::vector& item_list, + std::string& result) const { + int64_t sum = 0; + for (const Slice& item : item_list) { + int64_t ivalue; + Slice v = item; + if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) { + return false; + } + sum += ivalue; + } + result = EncodeHelper::EncodeInt(sum); + return true; +} + +bool MultipleAggregator::Aggregate(const std::vector& item_list, + std::string& result) const { + int64_t mresult = 1; + for (const Slice& item : item_list) { + int64_t ivalue; + Slice v = item; + if (!GetVarsignedint64(&v, &ivalue) || !v.empty()) { + return false; + } + mresult *= ivalue; + } + result = EncodeHelper::EncodeInt(mresult); + return true; +} + +bool Last3Aggregator::Aggregate(const std::vector& item_list, + std::string& result) const { + std::vector last3; + last3.reserve(3); + for (auto it = item_list.rbegin(); it != item_list.rend(); ++it) { + Slice input = *it; + Slice entity; + bool ret; + while ((ret = GetLengthPrefixedSlice(&input, &entity)) == true) { + last3.push_back(entity); + if (last3.size() >= 3) { + break; + } + } + if (last3.size() >= 3) { + break; + } + if (!ret) { + continue; + } + } + result = EncodeHelper::EncodeList(last3); + return true; +} +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/agg_merge/test_agg_merge.h b/utilities/agg_merge/test_agg_merge.h new file mode 100644 index 000000000..5bdf8b9cc --- /dev/null +++ b/utilities/agg_merge/test_agg_merge.h @@ -0,0 +1,47 @@ +// Copyright (c) 2017-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once +#include +#include +#include +#include + +#include "rocksdb/merge_operator.h" +#include "rocksdb/slice.h" +#include "rocksdb/utilities/agg_merge.h" +#include "utilities/cassandra/cassandra_options.h" + +namespace ROCKSDB_NAMESPACE { +class SumAggregator : public Aggregator { + public: + ~SumAggregator() override {} + bool Aggregate(const std::vector&, std::string& result) const override; + bool DoPartialAggregate() const override { return true; } +}; + +class MultipleAggregator : public Aggregator { + public: + ~MultipleAggregator() override {} + bool Aggregate(const std::vector&, std::string& result) const override; + bool DoPartialAggregate() const override { return true; } +}; + +class Last3Aggregator : public Aggregator { + public: + ~Last3Aggregator() override {} + bool Aggregate(const std::vector&, std::string& result) const override; +}; + +class EncodeHelper { + public: + static std::string EncodeFuncAndInt(const Slice& function_name, + int64_t value); + static std::string EncodeInt(int64_t value); + static std::string EncodeList(const std::vector& list); + static std::string EncodeFuncAndList(const Slice& function_name, + const std::vector& list); +}; +} // namespace ROCKSDB_NAMESPACE