diff --git a/db/db_impl.cc b/db/db_impl.cc index bc3866816..391bfa6db 100644 --- a/db/db_impl.cc +++ b/db/db_impl.cc @@ -4095,7 +4095,6 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, } Status status; - bool callback_failed = false; bool xfunc_attempted_write = false; XFUNC_TEST("transaction", "transaction_xftest_write_impl", @@ -4113,7 +4112,7 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, w.sync = write_options.sync; w.disableWAL = write_options.disableWAL; w.in_batch_group = false; - w.has_callback = (callback != nullptr) ? true : false; + w.callback = callback; if (!write_options.disableWAL) { RecordTick(stats_, WRITE_WITH_WAL); @@ -4126,30 +4125,32 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, // we are a non-leader in a parallel group PERF_TIMER_GUARD(write_memtable_time); - ColumnFamilyMemTablesImpl column_family_memtables( - versions_->GetColumnFamilySet()); - WriteBatchInternal::SetSequence(w.batch, w.sequence); - w.status = WriteBatchInternal::InsertInto( - w.batch, &column_family_memtables, &flush_scheduler_, - write_options.ignore_missing_column_families, 0 /*log_number*/, this, - true /*dont_filter_deletes*/, true /*concurrent_memtable_writes*/); + if (!w.CallbackFailed()) { + ColumnFamilyMemTablesImpl column_family_memtables( + versions_->GetColumnFamilySet()); + WriteBatchInternal::SetSequence(w.batch, w.sequence); + w.status = WriteBatchInternal::InsertInto( + w.batch, &column_family_memtables, &flush_scheduler_, + write_options.ignore_missing_column_families, 0 /*log_number*/, this, + true /*dont_filter_deletes*/, true /*concurrent_memtable_writes*/); + } if (write_thread_.CompleteParallelWorker(&w)) { // we're responsible for early exit - auto last_sequence = - w.parallel_group->last_writer->sequence + - WriteBatchInternal::Count(w.parallel_group->last_writer->batch) - 1; + auto last_sequence = w.parallel_group->last_sequence; SetTickerCount(stats_, SEQUENCE_NUMBER, last_sequence); versions_->SetLastSequence(last_sequence); write_thread_.EarlyExitParallelGroup(&w); } assert(w.state == WriteThread::STATE_COMPLETED); // STATE_COMPLETED conditional below handles exit + + status = w.FinalStatus(); } if (w.state == WriteThread::STATE_COMPLETED) { // write is complete and leader has updated sequence RecordTick(stats_, WRITE_DONE_BY_OTHER); - return w.status; + return w.FinalStatus(); } // else we are the leader of the write batch group assert(w.state == WriteThread::STATE_GROUP_LEADER); @@ -4255,7 +4256,7 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, uint64_t last_sequence = versions_->LastSequence(); WriteThread::Writer* last_writer = &w; - autovector write_batch_group; + autovector write_group; bool need_log_sync = !write_options.disableWAL && write_options.sync; bool need_log_dir_sync = need_log_sync && !log_dir_synced_; @@ -4274,24 +4275,15 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, // during this phase since &w is currently responsible for logging // and protects against concurrent loggers and concurrent writes // into memtables - - mutex_.Unlock(); - - if (callback != nullptr) { - // If this write has a validation callback, check to see if this write - // is able to be written. Must be called on the write thread. - status = callback->Callback(this); - callback_failed = true; - } - } else { - mutex_.Unlock(); } + mutex_.Unlock(); + // At this point the mutex is unlocked bool exit_completed_early = false; - last_batch_group_size_ = write_thread_.EnterAsBatchGroupLeader( - &w, &last_writer, &write_batch_group); + last_batch_group_size_ = + write_thread_.EnterAsBatchGroupLeader(&w, &last_writer, &write_group); if (status.ok()) { // Rules for when we can update the memtable concurrently @@ -4307,15 +4299,22 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, // assumed to be true. Rule 4 is checked for each batch. We could // relax rules 2 and 3 if we could prevent write batches from referring // more than once to a particular key. - bool parallel = db_options_.allow_concurrent_memtable_write && - write_batch_group.size() > 1; + bool parallel = + db_options_.allow_concurrent_memtable_write && write_group.size() > 1; int total_count = 0; uint64_t total_byte_size = 0; - for (auto b : write_batch_group) { - total_count += WriteBatchInternal::Count(b); - total_byte_size = WriteBatchInternal::AppendedByteSize( - total_byte_size, WriteBatchInternal::ByteSize(b)); - parallel = parallel && !b->HasMerge(); + for (auto writer : write_group) { + if (writer->CheckCallback(this)) { + total_count += WriteBatchInternal::Count(writer->batch); + total_byte_size = WriteBatchInternal::AppendedByteSize( + total_byte_size, WriteBatchInternal::ByteSize(writer->batch)); + parallel = parallel && !writer->batch->HasMerge(); + } + } + + if (total_count == 0) { + write_thread_.ExitAsBatchGroupLeader(&w, last_writer, status); + return w.FinalStatus(); } const SequenceNumber current_sequence = last_sequence + 1; @@ -4336,15 +4335,17 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, PERF_TIMER_GUARD(write_wal_time); WriteBatch* merged_batch = nullptr; - if (write_batch_group.size() == 1) { - merged_batch = write_batch_group[0]; + if (write_group.size() == 1) { + merged_batch = write_group[0]->batch; } else { // WAL needs all of the batches flattened into a single batch. // We could avoid copying here with an iov-like AddRecord // interface merged_batch = &tmp_batch_; - for (auto b : write_batch_group) { - WriteBatchInternal::Append(merged_batch, b); + for (auto writer : write_group) { + if (!writer->CallbackFailed()) { + WriteBatchInternal::Append(merged_batch, writer->batch); + } } } WriteBatchInternal::SetSequence(merged_batch, current_sequence); @@ -4405,7 +4406,7 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, } stats->AddDBStats(InternalStats::WAL_FILE_BYTES, log_size); } - uint64_t for_other = write_batch_group.size() - 1; + uint64_t for_other = write_group.size() - 1; if (for_other > 0) { stats->AddDBStats(InternalStats::WRITE_DONE_BY_OTHER, for_other); if (!write_options.disableWAL) { @@ -4416,43 +4417,50 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, if (!parallel) { status = WriteBatchInternal::InsertInto( - write_batch_group, current_sequence, column_family_memtables_.get(), + write_group, current_sequence, column_family_memtables_.get(), &flush_scheduler_, write_options.ignore_missing_column_families, 0 /*log_number*/, this, false /*dont_filter_deletes*/); + + if (status.ok()) { + // There were no write failures. Set leader's status + // in case the write callback returned a non-ok status. + status = w.FinalStatus(); + } + } else { WriteThread::ParallelGroup pg; pg.leader = &w; pg.last_writer = last_writer; + pg.last_sequence = last_sequence; pg.early_exit_allowed = !need_log_sync; - pg.running.store(static_cast(write_batch_group.size()), + pg.running.store(static_cast(write_group.size()), std::memory_order_relaxed); write_thread_.LaunchParallelFollowers(&pg, current_sequence); - ColumnFamilyMemTablesImpl column_family_memtables( - versions_->GetColumnFamilySet()); - assert(w.sequence == current_sequence); - WriteBatchInternal::SetSequence(w.batch, w.sequence); - w.status = WriteBatchInternal::InsertInto( - w.batch, &column_family_memtables, &flush_scheduler_, - write_options.ignore_missing_column_families, 0 /*log_number*/, - this, true /*dont_filter_deletes*/, - true /*concurrent_memtable_writes*/); + if (!w.CallbackFailed()) { + // do leader write + ColumnFamilyMemTablesImpl column_family_memtables( + versions_->GetColumnFamilySet()); + assert(w.sequence == current_sequence); + WriteBatchInternal::SetSequence(w.batch, w.sequence); + w.status = WriteBatchInternal::InsertInto( + w.batch, &column_family_memtables, &flush_scheduler_, + write_options.ignore_missing_column_families, 0 /*log_number*/, + this, true /*dont_filter_deletes*/, + true /*concurrent_memtable_writes*/); + } - assert(last_writer->sequence + - WriteBatchInternal::Count(last_writer->batch) - 1 == - last_sequence); // CompleteParallelWorker returns true if this thread should // handle exit, false means somebody else did exit_completed_early = !write_thread_.CompleteParallelWorker(&w); - status = w.status; - assert(status.ok() || !exit_completed_early); + status = w.FinalStatus(); } - if (status.ok() && !exit_completed_early) { + if (!exit_completed_early && w.status.ok()) { SetTickerCount(stats_, SEQUENCE_NUMBER, last_sequence); versions_->SetLastSequence(last_sequence); if (!need_log_sync) { - write_thread_.ExitAsBatchGroupLeader(&w, last_writer, status); + write_thread_.ExitAsBatchGroupLeader(&w, last_writer, w.status); exit_completed_early = true; } } @@ -4465,14 +4473,14 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, // // Is setting bg_error_ enough here? This will at least stop // compaction and fail any further writes. - if (!status.ok() && bg_error_.ok()) { + if (!status.ok() && bg_error_.ok() && !w.CallbackFailed()) { bg_error_ = status; } } } PERF_TIMER_START(write_pre_and_post_process_time); - if (db_options_.paranoid_checks && !status.ok() && !callback_failed && + if (db_options_.paranoid_checks && !status.ok() && !w.CallbackFailed() && !status.IsBusy()) { mutex_.Lock(); if (bg_error_.ok()) { @@ -4488,7 +4496,7 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, } if (!exit_completed_early) { - write_thread_.ExitAsBatchGroupLeader(&w, last_writer, status); + write_thread_.ExitAsBatchGroupLeader(&w, last_writer, w.status); } return status; diff --git a/db/db_test.cc b/db/db_test.cc index d39cbfe29..c6be57a99 100644 --- a/db/db_test.cc +++ b/db/db_test.cc @@ -235,13 +235,11 @@ TEST_F(DBTest, WriteEmptyBatch) { CreateAndReopenWithCF({"pikachu"}, options); ASSERT_OK(Put(1, "foo", "bar")); - env_->sync_counter_.store(0); WriteOptions wo; wo.sync = true; wo.disableWAL = false; WriteBatch empty_batch; ASSERT_OK(dbfull()->Write(wo, &empty_batch)); - ASSERT_GE(env_->sync_counter_.load(), 1); // make sure we can re-open it. ASSERT_OK(TryReopenWithColumnFamilies({"default", "pikachu"}, options)); diff --git a/db/write_batch.cc b/db/write_batch.cc index 0565c0599..accc313e4 100644 --- a/db/write_batch.cc +++ b/db/write_batch.cc @@ -798,18 +798,23 @@ class MemTableInserter : public WriteBatch::Handler { // 3) During Write(), in a concurrent context where memtables has been cloned // The reason is that it calls memtables->Seek(), which has a stateful cache Status WriteBatchInternal::InsertInto( - const autovector& batches, SequenceNumber sequence, + const autovector& writers, SequenceNumber sequence, ColumnFamilyMemTables* memtables, FlushScheduler* flush_scheduler, bool ignore_missing_column_families, uint64_t log_number, DB* db, const bool dont_filter_deletes, bool concurrent_memtable_writes) { MemTableInserter inserter(sequence, memtables, flush_scheduler, ignore_missing_column_families, log_number, db, dont_filter_deletes, concurrent_memtable_writes); - Status rv = Status::OK(); - for (size_t i = 0; i < batches.size() && rv.ok(); ++i) { - rv = batches[i]->Iterate(&inserter); + + for (size_t i = 0; i < writers.size(); i++) { + if (!writers[i]->CallbackFailed()) { + writers[i]->status = writers[i]->batch->Iterate(&inserter); + if (!writers[i]->status.ok()) { + return writers[i]->status; + } + } } - return rv; + return Status::OK(); } Status WriteBatchInternal::InsertInto(const WriteBatch* batch, diff --git a/db/write_batch_internal.h b/db/write_batch_internal.h index d75d2ef65..1ee234b84 100644 --- a/db/write_batch_internal.h +++ b/db/write_batch_internal.h @@ -9,6 +9,7 @@ #pragma once #include +#include "db/write_thread.h" #include "rocksdb/types.h" #include "rocksdb/write_batch.h" #include "rocksdb/db.h" @@ -134,7 +135,7 @@ class WriteBatchInternal { // // Under concurrent use, the caller is responsible for making sure that // the memtables object itself is thread-local. - static Status InsertInto(const autovector& batches, + static Status InsertInto(const autovector& batches, SequenceNumber sequence, ColumnFamilyMemTables* memtables, FlushScheduler* flush_scheduler, diff --git a/db/write_callback.h b/db/write_callback.h index 7dcca96fe..a549f415a 100644 --- a/db/write_callback.h +++ b/db/write_callback.h @@ -19,6 +19,9 @@ class WriteCallback { // this function returns a non-OK status, the write will be aborted and this // status will be returned to the caller of DB::Write(). virtual Status Callback(DB* db) = 0; + + // return true if writes with this callback can be batched with other writes + virtual bool AllowWriteBatching() = 0; }; } // namespace rocksdb diff --git a/db/write_callback_test.cc b/db/write_callback_test.cc index 47b7cf72a..3b76fd2d1 100644 --- a/db/write_callback_test.cc +++ b/db/write_callback_test.cc @@ -6,12 +6,15 @@ #ifndef ROCKSDB_LITE #include +#include +#include #include "db/db_impl.h" #include "db/write_callback.h" #include "rocksdb/db.h" #include "rocksdb/write_batch.h" #include "util/logging.h" +#include "util/sync_point.h" #include "util/testharness.h" using std::string; @@ -42,6 +45,8 @@ class WriteCallbackTestWriteCallback1 : public WriteCallback { return Status::OK(); } + + bool AllowWriteBatching() override { return true; } }; class WriteCallbackTestWriteCallback2 : public WriteCallback { @@ -49,8 +54,223 @@ class WriteCallbackTestWriteCallback2 : public WriteCallback { Status Callback(DB *db) override { return Status::Busy(); } + bool AllowWriteBatching() override { return true; } }; +class MockWriteCallback : public WriteCallback { + public: + bool should_fail_ = false; + bool was_called_ = false; + bool allow_batching_ = false; + + Status Callback(DB* db) override { + was_called_ = true; + if (should_fail_) { + return Status::Busy(); + } else { + return Status::OK(); + } + } + + bool AllowWriteBatching() override { return allow_batching_; } +}; + +TEST_F(WriteCallbackTest, WriteWithCallbackTest) { + struct WriteOP { + WriteOP(bool should_fail = false) { callback_.should_fail_ = should_fail; } + + void Put(const string& key, const string& val) { + kvs_.push_back(std::make_pair(key, val)); + write_batch_.Put(key, val); + } + + void Clear() { + kvs_.clear(); + write_batch_.Clear(); + callback_.was_called_ = false; + } + + MockWriteCallback callback_; + WriteBatch write_batch_; + std::vector> kvs_; + }; + + std::vector> write_scenarios = { + {true}, + {false}, + {false, false}, + {true, true}, + {true, false}, + {false, true}, + {false, false, false}, + {true, true, true}, + {false, true, false}, + {true, false, true}, + {true, false, false, false, false}, + {false, false, false, false, true}, + {false, false, true, false, true}, + }; + + for (auto& allow_parallel : {true, false}) { + for (auto& allow_batching : {true, false}) { + for (auto& write_group : write_scenarios) { + Options options; + options.create_if_missing = true; + options.allow_concurrent_memtable_write = allow_parallel; + + WriteOptions write_options; + ReadOptions read_options; + DB* db; + DBImpl* db_impl; + + ASSERT_OK(DB::Open(options, dbname, &db)); + + db_impl = dynamic_cast(db); + ASSERT_TRUE(db_impl); + + std::atomic threads_waiting(0); + std::atomic seq(db_impl->GetLatestSequenceNumber()); + ASSERT_EQ(db_impl->GetLatestSequenceNumber(), 0); + + rocksdb::SyncPoint::GetInstance()->SetCallBack( + "WriteThread::JoinBatchGroup:Wait", [&](void* arg) { + uint64_t cur_threads_waiting = 0; + bool is_leader = false; + bool is_last = false; + + // who am i + do { + cur_threads_waiting = threads_waiting.load(); + is_leader = (cur_threads_waiting == 0); + is_last = (cur_threads_waiting == write_group.size() - 1); + } while (!threads_waiting.compare_exchange_strong( + cur_threads_waiting, cur_threads_waiting + 1)); + + // check my state + auto* writer = reinterpret_cast(arg); + + if (is_leader) { + ASSERT_TRUE(writer->state == + WriteThread::State::STATE_GROUP_LEADER); + } else { + ASSERT_TRUE(writer->state == WriteThread::State::STATE_INIT); + } + + // (meta test) the first WriteOP should indeed be the first + // and the last should be the last (all others can be out of + // order) + if (is_leader) { + ASSERT_TRUE(writer->callback->Callback(nullptr).ok() == + !write_group.front().callback_.should_fail_); + } else if (is_last) { + ASSERT_TRUE(writer->callback->Callback(nullptr).ok() == + !write_group.back().callback_.should_fail_); + } + + // wait for friends + while (threads_waiting.load() < write_group.size()) { + } + }); + + rocksdb::SyncPoint::GetInstance()->SetCallBack( + "WriteThread::JoinBatchGroup:DoneWaiting", [&](void* arg) { + // check my state + auto* writer = reinterpret_cast(arg); + + if (!allow_batching) { + // no batching so everyone should be a leader + ASSERT_TRUE(writer->state == + WriteThread::State::STATE_GROUP_LEADER); + } else if (!allow_parallel) { + ASSERT_TRUE(writer->state == + WriteThread::State::STATE_COMPLETED); + } + }); + + std::atomic thread_num(0); + std::atomic dummy_key(0); + std::function write_with_callback_func = [&]() { + uint32_t i = thread_num.fetch_add(1); + Random rnd(i); + + // leaders gotta lead + while (i > 0 && threads_waiting.load() < 1) { + } + + // loser has to lose + while (i == write_group.size() - 1 && + threads_waiting.load() < write_group.size() - 1) { + } + + auto& write_op = write_group.at(i); + write_op.Clear(); + write_op.callback_.allow_batching_ = allow_batching; + + // insert some keys + for (uint32_t j = 0; j < rnd.Next() % 50; j++) { + // grab unique key + char my_key = 0; + do { + my_key = dummy_key.load(); + } while (!dummy_key.compare_exchange_strong(my_key, my_key + 1)); + + string skey(5, my_key); + string sval(10, my_key); + write_op.Put(skey, sval); + + if (!write_op.callback_.should_fail_) { + seq.fetch_add(1); + } + } + + WriteOptions woptions; + Status s = db_impl->WriteWithCallback( + woptions, &write_op.write_batch_, &write_op.callback_); + + if (write_op.callback_.should_fail_) { + ASSERT_TRUE(s.IsBusy()); + } else { + ASSERT_OK(s); + } + }; + + rocksdb::SyncPoint::GetInstance()->EnableProcessing(); + + // do all the writes + std::vector threads; + for (uint32_t i = 0; i < write_group.size(); i++) { + threads.emplace_back(write_with_callback_func); + } + for (auto& t : threads) { + t.join(); + } + + rocksdb::SyncPoint::GetInstance()->DisableProcessing(); + + // check for keys + string value; + for (auto& w : write_group) { + ASSERT_TRUE(w.callback_.was_called_); + for (auto& kvp : w.kvs_) { + if (w.callback_.should_fail_) { + ASSERT_TRUE( + db->Get(read_options, kvp.first, &value).IsNotFound()); + } else { + ASSERT_OK(db->Get(read_options, kvp.first, &value)); + ASSERT_EQ(value, kvp.second); + } + } + } + + ASSERT_EQ(seq.load(), db_impl->GetLatestSequenceNumber()); + + delete db; + DestroyDB(dbname, options); + } + } + } +} + TEST_F(WriteCallbackTest, WriteCallBackTest) { Options options; WriteOptions write_options; diff --git a/db/write_thread.cc b/db/write_thread.cc index e153f319b..ce269f664 100644 --- a/db/write_thread.cc +++ b/db/write_thread.cc @@ -218,21 +218,25 @@ void WriteThread::JoinBatchGroup(Writer* w) { assert(w->batch != nullptr); bool linked_as_leader; LinkOne(w, &linked_as_leader); + + TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait", w); + if (!linked_as_leader) { AwaitState(w, STATE_GROUP_LEADER | STATE_PARALLEL_FOLLOWER | STATE_COMPLETED, &ctx); + TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:DoneWaiting", w); } } size_t WriteThread::EnterAsBatchGroupLeader( Writer* leader, WriteThread::Writer** last_writer, - autovector* write_batch_group) { + autovector* write_batch_group) { assert(leader->link_older == nullptr); assert(leader->batch != nullptr); size_t size = WriteBatchInternal::ByteSize(leader->batch); - write_batch_group->push_back(leader->batch); + write_batch_group->push_back(leader); // Allow the group to grow up to a maximum size, but if the // original write is small, limit the growth so we do not slow @@ -244,12 +248,6 @@ size_t WriteThread::EnterAsBatchGroupLeader( *last_writer = leader; - if (leader->has_callback) { - // TODO(agiardullo:) Batching not currently supported as this write may - // fail if the callback function decides to abort this write. - return size; - } - Writer* newest_writer = newest_writer_.load(std::memory_order_acquire); // This is safe regardless of any db mutex status of the caller. Previous @@ -276,18 +274,17 @@ size_t WriteThread::EnterAsBatchGroupLeader( break; } - if (w->has_callback) { - // Do not include writes which may be aborted if the callback does not - // succeed. - break; - } - if (w->batch == nullptr) { // Do not include those writes with nullptr batch. Those are not writes, // those are something else. They want to be alone break; } + if (w->callback != nullptr && !w->callback->AllowWriteBatching()) { + // dont batch writes that don't want to be batched + break; + } + auto batch_size = WriteBatchInternal::ByteSize(w->batch); if (size + batch_size > max_size) { // Do not make batch too big @@ -295,7 +292,7 @@ size_t WriteThread::EnterAsBatchGroupLeader( } size += batch_size; - write_batch_group->push_back(w->batch); + write_batch_group->push_back(w); w->in_batch_group = true; *last_writer = w; } @@ -313,7 +310,10 @@ void WriteThread::LaunchParallelFollowers(ParallelGroup* pg, w->sequence = sequence; while (w != pg->last_writer) { - sequence += WriteBatchInternal::Count(w->batch); + // Writers that won't write don't get sequence allotment + if (!w->CallbackFailed()) { + sequence += WriteBatchInternal::Count(w->batch); + } w = w->link_newer; w->sequence = sequence; @@ -330,6 +330,7 @@ bool WriteThread::CompleteParallelWorker(Writer* w) { std::lock_guard guard(w->StateMutex()); pg->status = w->status; } + auto leader = pg->leader; auto early_exit_allowed = pg->early_exit_allowed; @@ -364,8 +365,8 @@ void WriteThread::EarlyExitParallelGroup(Writer* w) { assert(w->state == STATE_PARALLEL_FOLLOWER); assert(pg->status.ok()); ExitAsBatchGroupLeader(pg->leader, pg->last_writer, pg->status); - assert(w->state == STATE_COMPLETED); assert(w->status.ok()); + assert(w->state == STATE_COMPLETED); SetState(pg->leader, STATE_COMPLETED); } @@ -407,7 +408,6 @@ void WriteThread::ExitAsBatchGroupLeader(Writer* leader, Writer* last_writer, while (last_writer != leader) { last_writer->status = status; - // we need to read link_older before calling SetState, because as soon // as it is marked committed the other thread's Await may return and // deallocate the Writer. diff --git a/db/write_thread.h b/db/write_thread.h index e31904ed1..b1dbaca32 100644 --- a/db/write_thread.h +++ b/db/write_thread.h @@ -13,8 +13,10 @@ #include #include #include -#include "db/write_batch_internal.h" +#include "db/write_callback.h" +#include "rocksdb/types.h" #include "rocksdb/status.h" +#include "rocksdb/write_batch.h" #include "util/autovector.h" #include "util/instrumented_mutex.h" @@ -65,6 +67,7 @@ class WriteThread { struct ParallelGroup { Writer* leader; Writer* last_writer; + SequenceNumber last_sequence; bool early_exit_allowed; // before running goes to zero, status needs leader->StateMutex() Status status; @@ -77,12 +80,13 @@ class WriteThread { bool sync; bool disableWAL; bool in_batch_group; - bool has_callback; + WriteCallback* callback; bool made_waitable; // records lazy construction of mutex and cv std::atomic state; // write under StateMutex() or pre-link ParallelGroup* parallel_group; SequenceNumber sequence; // the sequence number to use - Status status; + Status status; // status of memtable inserter + Status callback_status; // status returned by callback->Callback() std::aligned_storage::type state_mutex_bytes; std::aligned_storage::type state_cv_bytes; Writer* link_older; // read/write only before linking, or as leader @@ -93,9 +97,10 @@ class WriteThread { sync(false), disableWAL(false), in_batch_group(false), - has_callback(false), + callback(nullptr), made_waitable(false), state(STATE_INIT), + parallel_group(nullptr), link_older(nullptr), link_newer(nullptr) {} @@ -106,6 +111,13 @@ class WriteThread { } } + bool CheckCallback(DB* db) { + if (callback != nullptr) { + callback_status = callback->Callback(db); + } + return callback_status.ok(); + } + void CreateMutex() { if (!made_waitable) { // Note that made_waitable is tracked separately from state @@ -117,6 +129,30 @@ class WriteThread { } } + // returns the aggregate status of this Writer + Status FinalStatus() { + if (!status.ok()) { + // a non-ok memtable write status takes presidence + assert(callback == nullptr || callback_status.ok()); + return status; + } else if (!callback_status.ok()) { + // if the callback failed then that is the status we want + // because a memtable insert should not have been attempted + assert(callback != nullptr); + assert(status.ok()); + return callback_status; + } else { + // if there is no callback then we only care about + // the memtable insert status + assert(callback == nullptr || callback_status.ok()); + return status; + } + } + + bool CallbackFailed() { + return (callback != nullptr) && !callback_status.ok(); + } + // No other mutexes may be acquired while holding StateMutex(), it is // always last in the order std::mutex& StateMutex() { @@ -160,8 +196,9 @@ class WriteThread { // Writer** last_writer: Out-param that identifies the last follower // autovector* write_batch_group: Out-param of group members // returns: Total batch group byte size - size_t EnterAsBatchGroupLeader(Writer* leader, Writer** last_writer, - autovector* write_batch_group); + size_t EnterAsBatchGroupLeader( + Writer* leader, Writer** last_writer, + autovector* write_batch_group); // Causes JoinBatchGroup to return STATE_PARALLEL_FOLLOWER for all of the // non-leader members of this write batch group. Sets Writer::sequence diff --git a/utilities/transactions/optimistic_transaction_impl.h b/utilities/transactions/optimistic_transaction_impl.h index a18561efd..36db5e94c 100644 --- a/utilities/transactions/optimistic_transaction_impl.h +++ b/utilities/transactions/optimistic_transaction_impl.h @@ -71,6 +71,8 @@ class OptimisticTransactionCallback : public WriteCallback { return txn_->CheckTransactionForConflicts(db); } + bool AllowWriteBatching() override { return false; } + private: OptimisticTransactionImpl* txn_; }; diff --git a/utilities/transactions/transaction_impl.h b/utilities/transactions/transaction_impl.h index caed15d3a..37a556ef6 100644 --- a/utilities/transactions/transaction_impl.h +++ b/utilities/transactions/transaction_impl.h @@ -110,6 +110,26 @@ class TransactionImpl : public TransactionBaseImpl { void operator=(const TransactionImpl&); }; +// Used at commit time to check whether transaction is committing before its +// expiration time. +class TransactionCallback : public WriteCallback { + public: + explicit TransactionCallback(TransactionImpl* txn) : txn_(txn) {} + + Status Callback(DB* db) override { + if (txn_->IsExpired()) { + return Status::Expired(); + } else { + return Status::OK(); + } + } + + bool AllowWriteBatching() override { return true; } + + private: + TransactionImpl* txn_; +}; + } // namespace rocksdb #endif // ROCKSDB_LITE