diff --git a/db/db_impl/db_impl.h b/db/db_impl/db_impl.h
index 339399bbb..a4aded732 100644
--- a/db/db_impl/db_impl.h
+++ b/db/db_impl/db_impl.h
@@ -32,6 +32,7 @@
 #include "db/log_writer.h"
 #include "db/logs_with_prep_tracker.h"
 #include "db/memtable_list.h"
+#include "db/post_memtable_callback.h"
 #include "db/pre_release_callback.h"
 #include "db/range_del_aggregator.h"
 #include "db/read_callback.h"
@@ -1309,7 +1310,8 @@ class DBImpl : public DB {
                    uint64_t* log_used = nullptr, uint64_t log_ref = 0,
                    bool disable_memtable = false, uint64_t* seq_used = nullptr,
                    size_t batch_cnt = 0,
-                   PreReleaseCallback* pre_release_callback = nullptr);
+                   PreReleaseCallback* pre_release_callback = nullptr,
+                   PostMemTableCallback* post_memtable_callback = nullptr);
 
   Status PipelinedWriteImpl(const WriteOptions& options, WriteBatch* updates,
                             WriteCallback* callback = nullptr,
diff --git a/db/db_impl/db_impl_write.cc b/db/db_impl/db_impl_write.cc
index 39657d462..c98f5e246 100644
--- a/db/db_impl/db_impl_write.cc
+++ b/db/db_impl/db_impl_write.cc
@@ -126,7 +126,8 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
                          uint64_t* log_used, uint64_t log_ref,
                          bool disable_memtable, uint64_t* seq_used,
                          size_t batch_cnt,
-                         PreReleaseCallback* pre_release_callback) {
+                         PreReleaseCallback* pre_release_callback,
+                         PostMemTableCallback* post_memtable_callback) {
   assert(!seq_per_batch_ || batch_cnt != 0);
   if (my_batch == nullptr) {
     return Status::InvalidArgument("Batch is nullptr!");
@@ -241,7 +242,8 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
 
   PERF_TIMER_GUARD(write_pre_and_post_process_time);
   WriteThread::Writer w(write_options, my_batch, callback, log_ref,
-                        disable_memtable, batch_cnt, pre_release_callback);
+                        disable_memtable, batch_cnt, pre_release_callback,
+                        post_memtable_callback);
   StopWatch write_sw(immutable_db_options_.clock, stats_, DB_WRITE);
 
   write_thread_.JoinBatchGroup(&w);
@@ -268,6 +270,20 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
       // we're responsible for exit batch group
       // TODO(myabandeh): propagate status to write_group
       auto last_sequence = w.write_group->last_sequence;
+      // Run each group member's post-memtable callback (if it has one) before
+      // publishing last_sequence below.
+      // NOTE(review): `disable_memtable` is this call's own argument and is
+      // passed for every writer in the group -- confirm that is intended.
+      for (auto* tmp_w : *(w.write_group)) {
+        assert(tmp_w);
+        if (tmp_w->post_memtable_callback) {
+          Status tmp_s =
+              (*tmp_w->post_memtable_callback)(last_sequence, disable_memtable);
+          // TODO: propagate the execution status of post_memtable_callback to
+          // caller.
+          assert(tmp_s.ok());
+        }
+      }
       versions_->SetLastSequence(last_sequence);
       MemTableInsertStatusCheck(w.status);
       write_thread_.ExitAsBatchGroupFollower(&w);
@@ -545,6 +561,20 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
   }
   if (should_exit_batch_group) {
     if (status.ok()) {
+      // Run each group member's post-memtable callback (if it has one) before
+      // publishing last_sequence below.
+      // NOTE(review): `disable_memtable` is this call's own argument and is
+      // passed for every writer in the group -- confirm that is intended.
+      for (auto* tmp_w : write_group) {
+        assert(tmp_w);
+        if (tmp_w->post_memtable_callback) {
+          Status tmp_s =
+              (*tmp_w->post_memtable_callback)(last_sequence, disable_memtable);
+          // TODO: propagate the execution status of post_memtable_callback to
+          // caller.
+          assert(tmp_s.ok());
+        }
+      }
       // Note: if we are to resume after non-OK statuses we need to revisit how
       // we reacts to non-OK statuses here.
       versions_->SetLastSequence(last_sequence);
diff --git a/db/post_memtable_callback.h b/db/post_memtable_callback.h
new file mode 100644
index 000000000..a877980b0
--- /dev/null
+++ b/db/post_memtable_callback.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
+// This source code is licensed under both the GPLv2 (found in the
+// COPYING file in the root directory) and Apache 2.0 License
+// (found in the LICENSE.Apache file in the root directory).
+
+#pragma once
+
+#include "rocksdb/status.h"
+#include "rocksdb/types.h"
+
+namespace ROCKSDB_NAMESPACE {
+
+// Interface for a per-writer hook that DBImpl::WriteImpl() runs for each
+// writer in a write group that supplied one, right before it publishes
+// the group's last sequence number via SetLastSequence().
+class PostMemTableCallback {
+ public:
+  virtual ~PostMemTableCallback() = default;
+
+  // `seq` is the write group's last sequence number; `disable_memtable` is
+  // the flag given to the WriteImpl() call that invokes the callback.
+  // WriteImpl() currently only asserts that the returned Status is OK; it
+  // does not propagate it to the caller.
+  virtual Status operator()(SequenceNumber seq, bool disable_memtable) = 0;
+};
+
+}  // namespace ROCKSDB_NAMESPACE
diff --git a/db/write_thread.h b/db/write_thread.h
index af4d0967e..f78b01cd9 100644
--- a/db/write_thread.h
+++ b/db/write_thread.h
@@ -15,6 +15,7 @@
 #include
 
 #include "db/dbformat.h"
+#include "db/post_memtable_callback.h"
 #include "db/pre_release_callback.h"
 #include "db/write_callback.h"
 #include "monitoring/instrumented_mutex.h"
@@ -122,6 +123,7 @@ class WriteThread {
     size_t batch_cnt;  // if non-zero, number of sub-batches in the write batch
     size_t protection_bytes_per_key;
     PreReleaseCallback* pre_release_callback;
+    PostMemTableCallback* post_memtable_callback;
     uint64_t log_used;  // log number that this batch was inserted into
     uint64_t log_ref;   // log number that memtable insert should reference
     WriteCallback* callback;
@@ -147,6 +149,7 @@ class WriteThread {
           batch_cnt(0),
           protection_bytes_per_key(0),
           pre_release_callback(nullptr),
+          post_memtable_callback(nullptr),
           log_used(0),
           log_ref(0),
           callback(nullptr),
@@ -160,7 +163,8 @@ class WriteThread {
     Writer(const WriteOptions& write_options, WriteBatch* _batch,
            WriteCallback* _callback, uint64_t _log_ref, bool _disable_memtable,
            size_t _batch_cnt = 0,
-           PreReleaseCallback* _pre_release_callback = nullptr)
+           PreReleaseCallback* _pre_release_callback = nullptr,
+           PostMemTableCallback* _post_memtable_callback = nullptr)
         : batch(_batch),
           sync(write_options.sync),
           no_slowdown(write_options.no_slowdown),
@@ -170,6 +174,7 @@ class WriteThread {
           batch_cnt(_batch_cnt),
           protection_bytes_per_key(_batch->GetProtectionBytesPerKey()),
           pre_release_callback(_pre_release_callback),
+          post_memtable_callback(_post_memtable_callback),
           log_used(0),
           log_ref(_log_ref),
           callback(_callback),