diff --git a/include/rocksdb/utilities/write_batch_with_index.h b/include/rocksdb/utilities/write_batch_with_index.h index f31c86ea1..ee5ec198e 100644 --- a/include/rocksdb/utilities/write_batch_with_index.h +++ b/include/rocksdb/utilities/write_batch_with_index.h @@ -12,6 +12,7 @@ #pragma once #include "rocksdb/comparator.h" +#include "rocksdb/iterator.h" #include "rocksdb/slice.h" #include "rocksdb/status.h" #include "rocksdb/write_batch.h" @@ -101,6 +102,11 @@ class WriteBatchWithIndex { // Create an iterator of the default column family. WBWIIterator* NewIterator(); + // Will create a new Iterator that will use WBWIIterator as a delta and + // base_iterator as base + Iterator* NewIteratorWithBase(ColumnFamilyHandle* column_family, + Iterator* base_iterator); + private: struct Rep; Rep* rep; diff --git a/utilities/write_batch_with_index/write_batch_with_index.cc b/utilities/write_batch_with_index/write_batch_with_index.cc index 0b460cd15..adfa5b324 100644 --- a/utilities/write_batch_with_index/write_batch_with_index.cc +++ b/utilities/write_batch_with_index/write_batch_with_index.cc @@ -4,13 +4,289 @@ // of patent rights can be found in the PATENTS file in the same directory. #include "rocksdb/utilities/write_batch_with_index.h" + +#include + #include "rocksdb/comparator.h" +#include "rocksdb/iterator.h" #include "db/column_family.h" #include "db/skiplist.h" #include "util/arena.h" namespace rocksdb { +// when direction == forward +// * current_at_base_ <=> base_iterator > delta_iterator +// when direction == backwards +// * current_at_base_ <=> base_iterator < delta_iterator +// always: +// * equal_keys_ <=> base_iterator == delta_iterator +class BaseDeltaIterator : public Iterator { + public: + BaseDeltaIterator(Iterator* base_iterator, WBWIIterator* delta_iterator, + const Comparator* comparator) + : forward_(true), + current_at_base_(true), + equal_keys_(false), + status_(Status::OK()), + base_iterator_(base_iterator), + delta_iterator_(delta_iterator), + comparator_(comparator) {} + + virtual ~BaseDeltaIterator() {} + + bool Valid() const override { + return current_at_base_ ? BaseValid() : DeltaValid(); + } + + void SeekToFirst() override { + forward_ = true; + base_iterator_->SeekToFirst(); + delta_iterator_->SeekToFirst(); + UpdateCurrent(); + } + + void SeekToLast() override { + forward_ = false; + base_iterator_->SeekToLast(); + delta_iterator_->SeekToLast(); + UpdateCurrent(); + } + + void Seek(const Slice& key) override { + forward_ = true; + base_iterator_->Seek(key); + delta_iterator_->Seek(key); + UpdateCurrent(); + } + + void Next() override { + if (!Valid()) { + status_ = Status::NotSupported("Next() on invalid iterator"); + } + + if (!forward_) { + // Need to change direction + // if our direction was backward and we're not equal, we have two states: + // * both iterators are valid: we're already in a good state (current + // shows to smaller) + // * only one iterator is valid: we need to advance that iterator + forward_ = true; + equal_keys_ = false; + if (!BaseValid()) { + assert(DeltaValid()); + base_iterator_->SeekToFirst(); + } else if (!DeltaValid()) { + delta_iterator_->SeekToFirst(); + } else if (current_at_base_) { + // Change delta from larger than base to smaller + AdvanceDelta(); + } else { + // Change base from larger than delta to smaller + AdvanceBase(); + } + if (DeltaValid() && BaseValid()) { + if (Compare() == 0) { + equal_keys_ = true; + } + } + } + Advance(); + } + + void Prev() override { + if (!Valid()) { + status_ = Status::NotSupported("Prev() on invalid iterator"); + } + + if (forward_) { + // Need to change direction + // if our direction was backward and we're not equal, we have two states: + // * both iterators are valid: we're already in a good state (current + // shows to smaller) + // * only one iterator is valid: we need to advance that iterator + forward_ = false; + equal_keys_ = false; + if (!BaseValid()) { + assert(DeltaValid()); + base_iterator_->SeekToLast(); + } else if (!DeltaValid()) { + delta_iterator_->SeekToLast(); + } else if (current_at_base_) { + // Change delta from less advanced than base to more advanced + AdvanceDelta(); + } else { + // Change base from less advanced than delta to more advanced + AdvanceBase(); + } + if (DeltaValid() && BaseValid()) { + if (Compare() == 0) { + equal_keys_ = true; + } + } + } + + Advance(); + } + + Slice key() const override { + return current_at_base_ ? base_iterator_->key() + : delta_iterator_->Entry().key; + } + + Slice value() const override { + return current_at_base_ ? base_iterator_->value() + : delta_iterator_->Entry().value; + } + + Status status() const { + if (!status_.ok()) { + return status_; + } + if (!base_iterator_->status().ok()) { + return base_iterator_->status(); + } + return delta_iterator_->status(); + } + + private: + // -1 -- delta less advanced than base + // 0 -- delta == base + // 1 -- delta more advanced than base + int Compare() const { + assert(delta_iterator_->Valid() && base_iterator_->Valid()); + int cmp = comparator_->Compare(delta_iterator_->Entry().key, + base_iterator_->key()); + if (forward_) { + return cmp; + } else { + return -cmp; + } + } + bool IsDeltaDelete() { + assert(DeltaValid()); + return delta_iterator_->Entry().type == kDeleteRecord; + } + void AssertInvariants() { +#ifndef NDEBUG + if (!Valid()) { + return; + } + if (!BaseValid()) { + assert(!current_at_base_ && delta_iterator_->Valid()); + return; + } + if (!DeltaValid()) { + assert(current_at_base_ && base_iterator_->Valid()); + return; + } + // we don't support those yet + assert(delta_iterator_->Entry().type != kMergeRecord && + delta_iterator_->Entry().type != kLogDataRecord); + int compare = comparator_->Compare(delta_iterator_->Entry().key, + base_iterator_->key()); + if (forward_) { + // current_at_base -> compare < 0 + assert(!current_at_base_ || compare < 0); + // !current_at_base -> compare <= 0 + assert(current_at_base_ && compare >= 0); + } else { + // current_at_base -> compare > 0 + assert(!current_at_base_ || compare > 0); + // !current_at_base -> compare <= 0 + assert(current_at_base_ && compare <= 0); + } + // equal_keys_ <=> compare == 0 + assert((equal_keys_ || compare != 0) && (!equal_keys_ || compare == 0)); +#endif + } + + void Advance() { + if (equal_keys_) { + assert(BaseValid() && DeltaValid()); + AdvanceBase(); + AdvanceDelta(); + } else { + if (current_at_base_) { + assert(BaseValid()); + AdvanceBase(); + } else { + assert(DeltaValid()); + AdvanceDelta(); + } + } + UpdateCurrent(); + } + + void AdvanceDelta() { + if (forward_) { + delta_iterator_->Next(); + } else { + delta_iterator_->Prev(); + } + } + void AdvanceBase() { + if (forward_) { + base_iterator_->Next(); + } else { + base_iterator_->Prev(); + } + } + bool BaseValid() const { return base_iterator_->Valid(); } + bool DeltaValid() const { return delta_iterator_->Valid(); } + void UpdateCurrent() { + while (true) { + equal_keys_ = false; + if (!BaseValid()) { + // Base has finished. + if (!DeltaValid()) { + // Finished + return; + } + if (IsDeltaDelete()) { + AdvanceDelta(); + } else { + current_at_base_ = false; + return; + } + } else if (!DeltaValid()) { + // Delta has finished. + current_at_base_ = true; + return; + } else { + int compare = Compare(); + if (compare <= 0) { // delta bigger or equal + if (compare == 0) { + equal_keys_ = true; + } + if (!IsDeltaDelete()) { + current_at_base_ = false; + return; + } + // Delta is less advanced and is delete. + AdvanceDelta(); + if (equal_keys_) { + AdvanceBase(); + } + } else { + current_at_base_ = true; + return; + } + } + } + + AssertInvariants(); + } + + bool forward_; + bool current_at_base_; + bool equal_keys_; + Status status_; + std::unique_ptr base_iterator_; + std::unique_ptr delta_iterator_; + const Comparator* comparator_; // not owned +}; + class ReadableWriteBatch : public WriteBatch { public: explicit ReadableWriteBatch(size_t reserved_bytes = 0) @@ -298,6 +574,16 @@ WBWIIterator* WriteBatchWithIndex::NewIterator( &(rep->skip_list), &rep->write_batch); } +Iterator* WriteBatchWithIndex::NewIteratorWithBase( + ColumnFamilyHandle* column_family, Iterator* base_iterator) { + if (rep->overwrite_key == false) { + assert(false); + return nullptr; + } + return new BaseDeltaIterator(base_iterator, NewIterator(column_family), + GetColumnFamilyUserComparator(column_family)); +} + void WriteBatchWithIndex::Put(ColumnFamilyHandle* column_family, const Slice& key, const Slice& value) { rep->SetLastEntryOffset(); diff --git a/utilities/write_batch_with_index/write_batch_with_index_test.cc b/utilities/write_batch_with_index/write_batch_with_index_test.cc index d34380fd7..32b45e339 100644 --- a/utilities/write_batch_with_index/write_batch_with_index_test.cc +++ b/utilities/write_batch_with_index/write_batch_with_index_test.cc @@ -464,6 +464,305 @@ TEST(WriteBatchWithIndexTest, TestOverwriteKey) { } } +namespace { +typedef std::map KVMap; + +class KVIter : public Iterator { + public: + explicit KVIter(const KVMap* map) : map_(map), iter_(map_->end()) {} + virtual bool Valid() const { return iter_ != map_->end(); } + virtual void SeekToFirst() { iter_ = map_->begin(); } + virtual void SeekToLast() { + if (map_->empty()) { + iter_ = map_->end(); + } else { + iter_ = map_->find(map_->rbegin()->first); + } + } + virtual void Seek(const Slice& k) { iter_ = map_->lower_bound(k.ToString()); } + virtual void Next() { ++iter_; } + virtual void Prev() { + if (iter_ == map_->begin()) { + iter_ = map_->end(); + return; + } + --iter_; + } + + virtual Slice key() const { return iter_->first; } + virtual Slice value() const { return iter_->second; } + virtual Status status() const { return Status::OK(); } + + private: + const KVMap* const map_; + KVMap::const_iterator iter_; +}; + +void AssertIter(Iterator* iter, const std::string& key, + const std::string& value) { + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ(key, iter->key().ToString()); + ASSERT_EQ(value, iter->value().ToString()); +} + +void AssertItersEqual(Iterator* iter1, Iterator* iter2) { + ASSERT_EQ(iter1->Valid(), iter2->Valid()); + if (iter1->Valid()) { + ASSERT_EQ(iter1->key().ToString(), iter2->key().ToString()); + ASSERT_EQ(iter1->value().ToString(), iter2->value().ToString()); + } +} +} // namespace + +TEST(WriteBatchWithIndexTest, TestRandomIteraratorWithBase) { + std::vector source_strings = {"a", "b", "c", "d", "e", + "f", "g", "h", "i", "j"}; + for (int rand_seed = 301; rand_seed < 366; rand_seed++) { + Random rnd(rand_seed); + + ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator()); + WriteBatchWithIndex batch(BytewiseComparator(), 20, true); + KVMap map; + KVMap merged_map; + for (auto key : source_strings) { + std::string value = key + key; + int type = rnd.Uniform(6); + switch (type) { + case 0: + // only base has it + map[key] = value; + merged_map[key] = value; + break; + case 1: + // only delta has it + batch.Put(&cf1, key, value); + map[key] = value; + merged_map[key] = value; + break; + case 2: + // both has it. Delta should win + batch.Put(&cf1, key, value); + map[key] = "wrong_value"; + merged_map[key] = value; + break; + case 3: + // both has it. Delta is delete + batch.Delete(&cf1, key); + map[key] = "wrong_value"; + break; + case 4: + // only delta has it. Delta is delete + batch.Delete(&cf1, key); + map[key] = "wrong_value"; + break; + default: + // Neither iterator has it. + break; + } + } + + std::unique_ptr iter( + batch.NewIteratorWithBase(&cf1, new KVIter(&map))); + std::unique_ptr result_iter(new KVIter(&merged_map)); + + bool is_valid = false; + for (int i = 0; i < 128; i++) { + // Random walk and make sure iter and result_iter returns the + // same key and value + int type = rnd.Uniform(5); + ASSERT_OK(iter->status()); + switch (type) { + case 0: + // Seek to First + iter->SeekToFirst(); + result_iter->SeekToFirst(); + break; + case 1: + // Seek to last + iter->SeekToLast(); + result_iter->SeekToLast(); + break; + case 2: { + // Seek to random key + auto key_idx = rnd.Uniform(source_strings.size()); + auto key = source_strings[key_idx]; + iter->Seek(key); + result_iter->Seek(key); + break; + } + case 3: + // Next + if (is_valid) { + iter->Next(); + result_iter->Next(); + } else { + continue; + } + break; + default: + assert(type == 4); + // Prev + if (is_valid) { + iter->Prev(); + result_iter->Prev(); + } else { + continue; + } + break; + } + AssertItersEqual(iter.get(), result_iter.get()); + is_valid = iter->Valid(); + } + } +} + +TEST(WriteBatchWithIndexTest, TestIteraratorWithBase) { + ColumnFamilyHandleImplDummy cf1(6, BytewiseComparator()); + WriteBatchWithIndex batch(BytewiseComparator(), 20, true); + + { + KVMap map; + map["a"] = "aa"; + map["c"] = "cc"; + map["e"] = "ee"; + std::unique_ptr iter( + batch.NewIteratorWithBase(&cf1, new KVIter(&map))); + + iter->SeekToFirst(); + AssertIter(iter.get(), "a", "aa"); + iter->Next(); + AssertIter(iter.get(), "c", "cc"); + iter->Next(); + AssertIter(iter.get(), "e", "ee"); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->SeekToLast(); + AssertIter(iter.get(), "e", "ee"); + iter->Prev(); + AssertIter(iter.get(), "c", "cc"); + iter->Prev(); + AssertIter(iter.get(), "a", "aa"); + iter->Prev(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->Seek("b"); + AssertIter(iter.get(), "c", "cc"); + + iter->Prev(); + AssertIter(iter.get(), "a", "aa"); + + iter->Seek("a"); + AssertIter(iter.get(), "a", "aa"); + } + + batch.Put(&cf1, "a", "aa"); + batch.Delete(&cf1, "b"); + batch.Put(&cf1, "c", "cc"); + batch.Put(&cf1, "d", "dd"); + batch.Delete(&cf1, "e"); + + { + KVMap map; + map["b"] = ""; + map["cc"] = "cccc"; + map["f"] = "ff"; + std::unique_ptr iter( + batch.NewIteratorWithBase(&cf1, new KVIter(&map))); + + iter->SeekToFirst(); + AssertIter(iter.get(), "a", "aa"); + iter->Next(); + AssertIter(iter.get(), "c", "cc"); + iter->Next(); + AssertIter(iter.get(), "cc", "cccc"); + iter->Next(); + AssertIter(iter.get(), "d", "dd"); + iter->Next(); + AssertIter(iter.get(), "f", "ff"); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->SeekToLast(); + AssertIter(iter.get(), "f", "ff"); + iter->Prev(); + AssertIter(iter.get(), "d", "dd"); + iter->Prev(); + AssertIter(iter.get(), "cc", "cccc"); + iter->Prev(); + AssertIter(iter.get(), "c", "cc"); + iter->Next(); + AssertIter(iter.get(), "cc", "cccc"); + iter->Prev(); + AssertIter(iter.get(), "c", "cc"); + iter->Prev(); + AssertIter(iter.get(), "a", "aa"); + iter->Prev(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->Seek("c"); + AssertIter(iter.get(), "c", "cc"); + + iter->Seek("cb"); + AssertIter(iter.get(), "cc", "cccc"); + + iter->Seek("cc"); + AssertIter(iter.get(), "cc", "cccc"); + iter->Next(); + AssertIter(iter.get(), "d", "dd"); + + iter->Seek("e"); + AssertIter(iter.get(), "f", "ff"); + + iter->Prev(); + AssertIter(iter.get(), "d", "dd"); + + iter->Next(); + AssertIter(iter.get(), "f", "ff"); + } + { + KVMap empty_map; + std::unique_ptr iter( + batch.NewIteratorWithBase(&cf1, new KVIter(&empty_map))); + + iter->SeekToFirst(); + AssertIter(iter.get(), "a", "aa"); + iter->Next(); + AssertIter(iter.get(), "c", "cc"); + iter->Next(); + AssertIter(iter.get(), "d", "dd"); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->SeekToLast(); + AssertIter(iter.get(), "d", "dd"); + iter->Prev(); + AssertIter(iter.get(), "c", "cc"); + iter->Prev(); + AssertIter(iter.get(), "a", "aa"); + + iter->Prev(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->Seek("aa"); + AssertIter(iter.get(), "c", "cc"); + iter->Next(); + AssertIter(iter.get(), "d", "dd"); + + iter->Seek("ca"); + AssertIter(iter.get(), "d", "dd"); + + iter->Prev(); + AssertIter(iter.get(), "c", "cc"); + } +} } // namespace int main(int argc, char** argv) { return rocksdb::test::RunAllTests(); }