diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b976608b..2731a60ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -778,6 +778,8 @@ set(SOURCES utilities/simulator_cache/sim_cache.cc utilities/table_properties_collectors/compact_on_deletion_collector.cc utilities/trace/file_trace_reader_writer.cc + utilities/transactions/lock/lock_tracker.cc + utilities/transactions/lock/point_lock_tracker.cc utilities/transactions/optimistic_transaction_db_impl.cc utilities/transactions/optimistic_transaction.cc utilities/transactions/pessimistic_transaction.cc diff --git a/TARGETS b/TARGETS index f82ec8bcf..4fcb54fe0 100644 --- a/TARGETS +++ b/TARGETS @@ -358,6 +358,8 @@ cpp_library( "utilities/simulator_cache/sim_cache.cc", "utilities/table_properties_collectors/compact_on_deletion_collector.cc", "utilities/trace/file_trace_reader_writer.cc", + "utilities/transactions/lock/lock_tracker.cc", + "utilities/transactions/lock/point_lock_tracker.cc", "utilities/transactions/optimistic_transaction.cc", "utilities/transactions/optimistic_transaction_db_impl.cc", "utilities/transactions/pessimistic_transaction.cc", diff --git a/src.mk b/src.mk index be64083c1..a684311be 100644 --- a/src.mk +++ b/src.mk @@ -238,6 +238,8 @@ LIB_SOURCES = \ utilities/simulator_cache/sim_cache.cc \ utilities/table_properties_collectors/compact_on_deletion_collector.cc \ utilities/trace/file_trace_reader_writer.cc \ + utilities/transactions/lock/lock_tracker.cc \ + utilities/transactions/lock/point_lock_tracker.cc \ utilities/transactions/optimistic_transaction.cc \ utilities/transactions/optimistic_transaction_db_impl.cc \ utilities/transactions/pessimistic_transaction.cc \ diff --git a/utilities/transactions/lock/lock_tracker.cc b/utilities/transactions/lock/lock_tracker.cc new file mode 100644 index 000000000..c367c273d --- /dev/null +++ b/utilities/transactions/lock/lock_tracker.cc @@ -0,0 +1,17 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/transactions/lock/lock_tracker.h" + +#include "utilities/transactions/lock/point_lock_tracker.h" + +namespace ROCKSDB_NAMESPACE { + +LockTracker* NewLockTracker() { + // TODO: determine the lock tracker implementation based on configuration. + return new PointLockTracker(); +} + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/transactions/lock/lock_tracker.h b/utilities/transactions/lock/lock_tracker.h new file mode 100644 index 000000000..2129dd2a6 --- /dev/null +++ b/utilities/transactions/lock/lock_tracker.h @@ -0,0 +1,199 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include + +#include "rocksdb/rocksdb_namespace.h" +#include "rocksdb/status.h" +#include "rocksdb/types.h" + +namespace ROCKSDB_NAMESPACE { + +using ColumnFamilyId = uint32_t; + +// Request for locking a single key. +struct PointLockRequest { + // The id of the key's column family. + ColumnFamilyId column_family_id = 0; + // The key to lock. + std::string key; + // The sequence number from which there is no concurrent update to key. + SequenceNumber seq = 0; + // Whether the lock is acquired only for read. + bool read_only = false; + // Whether the lock is in exclusive mode. + bool exclusive = true; +}; + +// Request for locking a range of keys. +struct RangeLockRequest { + // TODO +}; + +struct PointLockStatus { + // Whether the key is locked. + bool locked = false; + // Whether the key is locked in exclusive mode. + bool exclusive = true; + // The sequence number in the tracked PointLockRequest. + SequenceNumber seq = 0; +}; + +// Return status when calling LockTracker::Untrack. +enum class UntrackStatus { + // The lock is not tracked at all, so no lock to untrack. + NOT_TRACKED, + // The lock is untracked but not removed from the tracker. + UNTRACKED, + // The lock is removed from the tracker. + REMOVED, +}; + +// Tracks the lock requests. +// In PessimisticTransaction, it tracks the locks acquired through LockMgr; +// In OptimisticTransaction, since there is no LockMgr, it tracks the lock +// intention. Not thread-safe. +class LockTracker { + public: + virtual ~LockTracker() {} + + // Whether supports locking a specific key. + virtual bool IsPointLockSupported() const = 0; + + // Whether supports locking a range of keys. + virtual bool IsRangeLockSupported() const = 0; + + // Tracks the acquirement of a lock on key. + // + // If this method is not supported, leave it as a no-op. + virtual void Track(const PointLockRequest& /*lock_request*/) = 0; + + // Untracks the lock on a key. + // seq and exclusive in lock_request are not used. + // + // If this method is not supported, leave it as a no-op and + // returns NOT_TRACKED. + virtual UntrackStatus Untrack(const PointLockRequest& /*lock_request*/) = 0; + + // Counterpart of Track(const PointLockRequest&) for RangeLockRequest. + virtual void Track(const RangeLockRequest& /*lock_request*/) = 0; + + // Counterpart of Untrack(const PointLockRequest&) for RangeLockRequest. + virtual UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) = 0; + + // Merges lock requests tracked in the specified tracker into the current + // tracker. + // + // E.g. for point lock, if a key in tracker is not yet tracked, + // track this new key; otherwise, merge the tracked information of the key + // such as lock's exclusiveness, read/write statistics. + // + // If this method is not supported, leave it as a no-op. + // + // REQUIRED: the specified tracker must be of the same concrete class type as + // the current tracker. + virtual void Merge(const LockTracker& /*tracker*/) = 0; + + // This is a reverse operation of Merge. + // + // E.g. for point lock, if a key exists in both current and the sepcified + // tracker, then subtract the information (such as read/write statistics) of + // the key in the specified tracker from the current tracker. + // + // If this method is not supported, leave it as a no-op. + // + // REQUIRED: + // The specified tracker must be of the same concrete class type as + // the current tracker. + // The tracked locks in the specified tracker must be a subset of those + // tracked by the current tracker. + virtual void Subtract(const LockTracker& /*tracker*/) = 0; + + // Clears all tracked locks. + virtual void Clear() = 0; + + // Gets the new locks (excluding the locks that have been tracked before the + // save point) tracked since the specified save point, the result is stored + // in an internally constructed LockTracker and returned. + // + // save_point_tracker is the tracker used by a SavePoint to track locks + // tracked after creating the SavePoint. + // + // The implementation should document whether point lock, or range lock, or + // both are considered in this method. + // If this method is not supported, returns nullptr. + // + // REQUIRED: + // The save_point_tracker must be of the same concrete class type as the + // current tracker. + // The tracked locks in the specified tracker must be a subset of those + // tracked by the current tracker. + virtual LockTracker* GetTrackedLocksSinceSavePoint( + const LockTracker& /*save_point_tracker*/) const = 0; + + // Gets lock related information of the key. + // + // If point lock is not supported, always returns LockStatus with + // locked=false. + virtual PointLockStatus GetPointLockStatus( + ColumnFamilyId /*column_family_id*/, + const std::string& /*key*/) const = 0; + + // Gets number of tracked point locks. + // + // If point lock is not supported, always returns 0. + virtual uint64_t GetNumPointLocks() const = 0; + + class ColumnFamilyIterator { + public: + virtual ~ColumnFamilyIterator() {} + + // Whether there are remaining column families. + virtual bool HasNext() const = 0; + + // Gets next column family id. + // + // If HasNext is false, calling this method has undefined behavior. + virtual ColumnFamilyId Next() = 0; + }; + + // Gets an iterator for column families. + // + // Returned iterator must not be nullptr. + // If there is no column family to iterate, + // returns an empty non-null iterator. + // Caller owns the returned pointer. + virtual ColumnFamilyIterator* GetColumnFamilyIterator() const = 0; + + class KeyIterator { + public: + virtual ~KeyIterator() {} + + // Whether there are remaining keys. + virtual bool HasNext() const = 0; + + // Gets the next key. + // + // If HasNext is false, calling this method has undefined behavior. + virtual const std::string& Next() = 0; + }; + + // Gets an iterator for keys with tracked point locks in the column family. + // + // The column family must exist. + // Returned iterator must not be nullptr. + // Caller owns the returned pointer. + virtual KeyIterator* GetKeyIterator( + ColumnFamilyId /*column_family_id*/) const = 0; +}; + +// LockTracker should always be constructed through this factory method, +// instead of constructing through concrete implementations' constructor. +// Caller owns the returned pointer. +LockTracker* NewLockTracker(); + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/transactions/lock/point_lock_tracker.cc b/utilities/transactions/lock/point_lock_tracker.cc new file mode 100644 index 000000000..d6f609ee4 --- /dev/null +++ b/utilities/transactions/lock/point_lock_tracker.cc @@ -0,0 +1,266 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#include "utilities/transactions/lock/point_lock_tracker.h" + +namespace ROCKSDB_NAMESPACE { + +namespace { + +class TrackedKeysColumnFamilyIterator + : public LockTracker::ColumnFamilyIterator { + public: + explicit TrackedKeysColumnFamilyIterator(const TrackedKeys& keys) + : tracked_keys_(keys), it_(keys.begin()) {} + + bool HasNext() const override { return it_ != tracked_keys_.end(); } + + ColumnFamilyId Next() override { return (it_++)->first; } + + private: + const TrackedKeys& tracked_keys_; + TrackedKeys::const_iterator it_; +}; + +class TrackedKeysIterator : public LockTracker::KeyIterator { + public: + TrackedKeysIterator(const TrackedKeys& keys, ColumnFamilyId id) + : key_infos_(keys.at(id)), it_(key_infos_.begin()) {} + + bool HasNext() const override { return it_ != key_infos_.end(); } + + const std::string& Next() override { return (it_++)->first; } + + private: + const TrackedKeyInfos& key_infos_; + TrackedKeyInfos::const_iterator it_; +}; + +} // namespace + +void PointLockTracker::Track(const PointLockRequest& r) { + auto& keys = tracked_keys_[r.column_family_id]; +#ifdef __cpp_lib_unordered_map_try_emplace + // use c++17's try_emplace if available, to avoid rehashing the key + // in case it is not already in the map + auto result = keys.try_emplace(r.key, r.seq); + auto it = result.first; + if (!result.second && r.seq < it->second.seq) { + // Now tracking this key with an earlier sequence number + it->second.seq = r.seq; + } +#else + auto it = keys.find(r.key); + if (it == keys.end()) { + auto result = keys.emplace(r.key, TrackedKeyInfo(r.seq)); + it = result.first; + } else if (r.seq < it->second.seq) { + // Now tracking this key with an earlier sequence number + it->second.seq = r.seq; + } +#endif + // else we do not update the seq. The smaller the tracked seq, the stronger it + // the guarantee since it implies from the seq onward there has not been a + // concurrent update to the key. So we update the seq if it implies stronger + // guarantees, i.e., if it is smaller than the existing tracked seq. + + if (r.read_only) { + it->second.num_reads++; + } else { + it->second.num_writes++; + } + + it->second.exclusive = it->second.exclusive || r.exclusive; +} + +UntrackStatus PointLockTracker::Untrack(const PointLockRequest& r) { + auto cf_keys = tracked_keys_.find(r.column_family_id); + if (cf_keys == tracked_keys_.end()) { + return UntrackStatus::NOT_TRACKED; + } + + auto& keys = cf_keys->second; + auto it = keys.find(r.key); + if (it == keys.end()) { + return UntrackStatus::NOT_TRACKED; + } + + bool untracked = false; + auto& info = it->second; + if (r.read_only) { + if (info.num_reads > 0) { + info.num_reads--; + untracked = true; + } + } else { + if (info.num_writes > 0) { + info.num_writes--; + untracked = true; + } + } + + bool removed = false; + if (info.num_reads == 0 && info.num_writes == 0) { + keys.erase(it); + if (keys.empty()) { + tracked_keys_.erase(cf_keys); + } + removed = true; + } + + if (removed) { + return UntrackStatus::REMOVED; + } + if (untracked) { + return UntrackStatus::UNTRACKED; + } + return UntrackStatus::NOT_TRACKED; +} + +void PointLockTracker::Merge(const LockTracker& tracker) { + const PointLockTracker& t = static_cast(tracker); + for (const auto& cf_keys : t.tracked_keys_) { + ColumnFamilyId cf = cf_keys.first; + const auto& keys = cf_keys.second; + + auto current_cf_keys = tracked_keys_.find(cf); + if (current_cf_keys == tracked_keys_.end()) { + tracked_keys_.emplace(cf_keys); + } else { + auto& current_keys = current_cf_keys->second; + for (const auto& key_info : keys) { + const std::string& key = key_info.first; + const TrackedKeyInfo& info = key_info.second; + // If key was not previously tracked, just copy the whole struct over. + // Otherwise, some merging needs to occur. + auto current_info = current_keys.find(key); + if (current_info == current_keys.end()) { + current_keys.emplace(key_info); + } else { + current_info->second.Merge(info); + } + } + } + } +} + +void PointLockTracker::Subtract(const LockTracker& tracker) { + const PointLockTracker& t = static_cast(tracker); + for (const auto& cf_keys : t.tracked_keys_) { + ColumnFamilyId cf = cf_keys.first; + const auto& keys = cf_keys.second; + + auto& current_keys = tracked_keys_.at(cf); + for (const auto& key_info : keys) { + const std::string& key = key_info.first; + const TrackedKeyInfo& info = key_info.second; + uint32_t num_reads = info.num_reads; + uint32_t num_writes = info.num_writes; + + auto current_key_info = current_keys.find(key); + assert(current_key_info != current_keys.end()); + + // Decrement the total reads/writes of this key by the number of + // reads/writes done since the last SavePoint. + if (num_reads > 0) { + assert(current_key_info->second.num_reads >= num_reads); + current_key_info->second.num_reads -= num_reads; + } + if (num_writes > 0) { + assert(current_key_info->second.num_writes >= num_writes); + current_key_info->second.num_writes -= num_writes; + } + if (current_key_info->second.num_reads == 0 && + current_key_info->second.num_writes == 0) { + current_keys.erase(current_key_info); + } + } + } +} + +LockTracker* PointLockTracker::GetTrackedLocksSinceSavePoint( + const LockTracker& save_point_tracker) const { + // Examine the number of reads/writes performed on all keys written + // since the last SavePoint and compare to the total number of reads/writes + // for each key. + LockTracker* t = new PointLockTracker(); + const PointLockTracker& save_point_t = + static_cast(save_point_tracker); + for (const auto& cf_keys : save_point_t.tracked_keys_) { + ColumnFamilyId cf = cf_keys.first; + const auto& keys = cf_keys.second; + + auto& current_keys = tracked_keys_.at(cf); + for (const auto& key_info : keys) { + const std::string& key = key_info.first; + const TrackedKeyInfo& info = key_info.second; + uint32_t num_reads = info.num_reads; + uint32_t num_writes = info.num_writes; + + auto current_key_info = current_keys.find(key); + assert(current_key_info != current_keys.end()); + assert(current_key_info->second.num_reads >= num_reads); + assert(current_key_info->second.num_writes >= num_writes); + + if (current_key_info->second.num_reads == num_reads && + current_key_info->second.num_writes == num_writes) { + // All the reads/writes to this key were done in the last savepoint. + PointLockRequest r; + r.column_family_id = cf; + r.key = key; + r.seq = info.seq; + r.read_only = (num_writes == 0); + r.exclusive = info.exclusive; + t->Track(r); + } + } + } + return t; +} + +PointLockStatus PointLockTracker::GetPointLockStatus( + ColumnFamilyId column_family_id, const std::string& key) const { + assert(IsPointLockSupported()); + PointLockStatus status; + auto it = tracked_keys_.find(column_family_id); + if (it == tracked_keys_.end()) { + return status; + } + + const auto& keys = it->second; + auto key_it = keys.find(key); + if (key_it == keys.end()) { + return status; + } + + const TrackedKeyInfo& key_info = key_it->second; + status.locked = true; + status.exclusive = key_info.exclusive; + status.seq = key_info.seq; + return status; +} + +uint64_t PointLockTracker::GetNumPointLocks() const { + uint64_t num_keys = 0; + for (const auto& cf_keys : tracked_keys_) { + num_keys += cf_keys.second.size(); + } + return num_keys; +} + +LockTracker::ColumnFamilyIterator* PointLockTracker::GetColumnFamilyIterator() + const { + return new TrackedKeysColumnFamilyIterator(tracked_keys_); +} + +LockTracker::KeyIterator* PointLockTracker::GetKeyIterator( + ColumnFamilyId column_family_id) const { + assert(tracked_keys_.find(column_family_id) != tracked_keys_.end()); + return new TrackedKeysIterator(tracked_keys_, column_family_id); +} + +void PointLockTracker::Clear() { tracked_keys_.clear(); } + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/transactions/lock/point_lock_tracker.h b/utilities/transactions/lock/point_lock_tracker.h new file mode 100644 index 000000000..f307d1892 --- /dev/null +++ b/utilities/transactions/lock/point_lock_tracker.h @@ -0,0 +1,84 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include +#include +#include + +#include "utilities/transactions/lock/lock_tracker.h" + +namespace ROCKSDB_NAMESPACE { + +struct TrackedKeyInfo { + // Earliest sequence number that is relevant to this transaction for this key + SequenceNumber seq; + + uint32_t num_writes; + uint32_t num_reads; + + bool exclusive; + + explicit TrackedKeyInfo(SequenceNumber seq_no) + : seq(seq_no), num_writes(0), num_reads(0), exclusive(false) {} + + void Merge(const TrackedKeyInfo& info) { + assert(seq <= info.seq); + num_reads += info.num_reads; + num_writes += info.num_writes; + exclusive = exclusive || info.exclusive; + } +}; + +using TrackedKeyInfos = std::unordered_map; + +using TrackedKeys = std::unordered_map; + +// Tracks point locks on single keys. +class PointLockTracker : public LockTracker { + public: + PointLockTracker() = default; + + PointLockTracker(const PointLockTracker&) = delete; + PointLockTracker& operator=(const PointLockTracker&) = delete; + + bool IsPointLockSupported() const override { return true; } + + bool IsRangeLockSupported() const override { return false; } + + void Track(const PointLockRequest& lock_request) override; + + UntrackStatus Untrack(const PointLockRequest& lock_request) override; + + void Track(const RangeLockRequest& /*lock_request*/) override {} + + UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) override { + return UntrackStatus::NOT_TRACKED; + } + + void Merge(const LockTracker& tracker) override; + + void Subtract(const LockTracker& tracker) override; + + void Clear() override; + + virtual LockTracker* GetTrackedLocksSinceSavePoint( + const LockTracker& save_point_tracker) const override; + + PointLockStatus GetPointLockStatus(ColumnFamilyId column_family_id, + const std::string& key) const override; + + uint64_t GetNumPointLocks() const override; + + ColumnFamilyIterator* GetColumnFamilyIterator() const override; + + KeyIterator* GetKeyIterator(ColumnFamilyId column_family_id) const override; + + private: + TrackedKeys tracked_keys_; +}; + +} // namespace ROCKSDB_NAMESPACE diff --git a/utilities/transactions/optimistic_transaction.cc b/utilities/transactions/optimistic_transaction.cc index f9fb8adfa..7da2cce10 100644 --- a/utilities/transactions/optimistic_transaction.cc +++ b/utilities/transactions/optimistic_transaction.cc @@ -97,9 +97,17 @@ Status OptimisticTransaction::CommitWithParallelValidate() { const size_t space = txn_db_impl->GetLockBucketsSize(); std::set lk_idxes; std::vector> lks; - for (auto& cfit : GetTrackedKeys()) { - for (auto& keyit : cfit.second) { - lk_idxes.insert(fastrange64(GetSliceNPHash64(keyit.first), space)); + std::unique_ptr cf_it( + tracked_locks_->GetColumnFamilyIterator()); + assert(cf_it != nullptr); + while (cf_it->HasNext()) { + ColumnFamilyId cf = cf_it->Next(); + std::unique_ptr key_it( + tracked_locks_->GetKeyIterator(cf)); + assert(key_it != nullptr); + while (key_it->HasNext()) { + const std::string& key = key_it->Next(); + lk_idxes.insert(fastrange64(GetSliceNPHash64(key), space)); } } // NOTE: in a single txn, all bucket-locks are taken in ascending order. @@ -109,7 +117,7 @@ Status OptimisticTransaction::CommitWithParallelValidate() { lks.emplace_back(txn_db_impl->LockBucket(v)); } - Status s = TransactionUtil::CheckKeysForConflicts(db_impl, GetTrackedKeys(), + Status s = TransactionUtil::CheckKeysForConflicts(db_impl, *tracked_locks_, true /* cache_only */); if (!s.ok()) { return s; @@ -174,7 +182,7 @@ Status OptimisticTransaction::CheckTransactionForConflicts(DB* db) { // we will do a cache-only conflict check. This can result in TryAgain // getting returned if there is not sufficient memtable history to check // for conflicts. - return TransactionUtil::CheckKeysForConflicts(db_impl, GetTrackedKeys(), + return TransactionUtil::CheckKeysForConflicts(db_impl, *tracked_locks_, true /* cache_only */); } diff --git a/utilities/transactions/pessimistic_transaction.cc b/utilities/transactions/pessimistic_transaction.cc index d9851f4cd..d92818528 100644 --- a/utilities/transactions/pessimistic_transaction.cc +++ b/utilities/transactions/pessimistic_transaction.cc @@ -91,7 +91,7 @@ void PessimisticTransaction::Initialize(const TransactionOptions& txn_options) { } PessimisticTransaction::~PessimisticTransaction() { - txn_db_impl_->UnLock(this, &GetTrackedKeys()); + txn_db_impl_->UnLock(this, *tracked_locks_); if (expiration_time_ > 0) { txn_db_impl_->RemoveExpirableTransaction(txn_id_); } @@ -101,7 +101,7 @@ PessimisticTransaction::~PessimisticTransaction() { } void PessimisticTransaction::Clear() { - txn_db_impl_->UnLock(this, &GetTrackedKeys()); + txn_db_impl_->UnLock(this, *tracked_locks_); TransactionBaseImpl::Clear(); } @@ -132,8 +132,8 @@ WriteCommittedTxn::WriteCommittedTxn(TransactionDB* txn_db, : PessimisticTransaction(txn_db, write_options, txn_options){}; Status PessimisticTransaction::CommitBatch(WriteBatch* batch) { - TransactionKeyMap keys_to_unlock; - Status s = LockBatch(batch, &keys_to_unlock); + std::unique_ptr keys_to_unlock(NewLockTracker()); + Status s = LockBatch(batch, keys_to_unlock.get()); if (!s.ok()) { return s; @@ -164,7 +164,7 @@ Status PessimisticTransaction::CommitBatch(WriteBatch* batch) { s = Status::InvalidArgument("Transaction is not in state for commit."); } - txn_db_impl_->UnLock(this, &keys_to_unlock); + txn_db_impl_->UnLock(this, *keys_to_unlock); return s; } @@ -446,12 +446,14 @@ Status PessimisticTransaction::RollbackToSavePoint() { return Status::InvalidArgument("Transaction is beyond state for rollback."); } - // Unlock any keys locked since last transaction - const std::unique_ptr& keys = - GetTrackedKeysSinceSavePoint(); - - if (keys) { - txn_db_impl_->UnLock(this, keys.get()); + if (save_points_ != nullptr && !save_points_->empty()) { + // Unlock any keys locked since last transaction + auto& save_point_tracker = *save_points_->top().new_locks_; + std::unique_ptr t( + tracked_locks_->GetTrackedLocksSinceSavePoint(save_point_tracker)); + if (t) { + txn_db_impl_->UnLock(this, *t); + } } return TransactionBaseImpl::RollbackToSavePoint(); @@ -460,7 +462,7 @@ Status PessimisticTransaction::RollbackToSavePoint() { // Lock all keys in this batch. // On success, caller should unlock keys_to_unlock Status PessimisticTransaction::LockBatch(WriteBatch* batch, - TransactionKeyMap* keys_to_unlock) { + LockTracker* keys_to_unlock) { class Handler : public WriteBatch::Handler { public: // Sorted map of column_family_id to sorted set of keys. @@ -516,8 +518,13 @@ Status PessimisticTransaction::LockBatch(WriteBatch* batch, if (!s.ok()) { break; } - TrackKey(keys_to_unlock, cfh_id, std::move(key), kMaxSequenceNumber, - false, true /* exclusive */); + PointLockRequest r; + r.column_family_id = cfh_id; + r.key = key; + r.seq = kMaxSequenceNumber; + r.read_only = false; + r.exclusive = true; + keys_to_unlock->Track(r); } if (!s.ok()) { @@ -526,7 +533,7 @@ Status PessimisticTransaction::LockBatch(WriteBatch* batch, } if (!s.ok()) { - txn_db_impl_->UnLock(this, keys_to_unlock); + txn_db_impl_->UnLock(this, *keys_to_unlock); } return s; @@ -548,28 +555,9 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family, } uint32_t cfh_id = GetColumnFamilyID(column_family); std::string key_str = key.ToString(); - bool previously_locked; - bool lock_upgrade = false; - - // lock this key if this transactions hasn't already locked it - SequenceNumber tracked_at_seq = kMaxSequenceNumber; - - const auto& tracked_keys = GetTrackedKeys(); - const auto tracked_keys_cf = tracked_keys.find(cfh_id); - if (tracked_keys_cf == tracked_keys.end()) { - previously_locked = false; - } else { - auto iter = tracked_keys_cf->second.find(key_str); - if (iter == tracked_keys_cf->second.end()) { - previously_locked = false; - } else { - if (!iter->second.exclusive && exclusive) { - lock_upgrade = true; - } - previously_locked = true; - tracked_at_seq = iter->second.seq; - } - } + PointLockStatus status = tracked_locks_->GetPointLockStatus(cfh_id, key_str); + bool previously_locked = status.locked; + bool lock_upgrade = previously_locked && exclusive && !status.exclusive; // Lock this key if this transactions hasn't already locked it or we require // an upgrade. @@ -585,6 +573,8 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family, // any writes since this transaction's snapshot. // TODO(agiardullo): could optimize by supporting shared txn locks in the // future + SequenceNumber tracked_at_seq = + status.locked ? status.seq : kMaxSequenceNumber; if (!do_validate || snapshot_ == nullptr) { if (assume_tracked && !previously_locked) { s = Status::InvalidArgument( @@ -614,15 +604,13 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family, if (!s.ok()) { // Failed to validate key - if (!previously_locked) { - // Unlock key we just locked - if (lock_upgrade) { - s = txn_db_impl_->TryLock(this, cfh_id, key_str, - false /* exclusive */); - assert(s.ok()); - } else { - txn_db_impl_->UnLock(this, cfh_id, key.ToString()); - } + // Unlock key we just locked + if (lock_upgrade) { + s = txn_db_impl_->TryLock(this, cfh_id, key_str, + false /* exclusive */); + assert(s.ok()); + } else if (!previously_locked) { + txn_db_impl_->UnLock(this, cfh_id, key.ToString()); } } } @@ -645,10 +633,11 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family, TrackKey(cfh_id, key_str, tracked_at_seq, read_only, exclusive); } else { #ifndef NDEBUG - assert(tracked_keys_cf->second.count(key_str) > 0); - const auto& info = tracked_keys_cf->second.find(key_str)->second; - assert(info.seq <= tracked_at_seq); - assert(info.exclusive == exclusive); + PointLockStatus lock_status = + tracked_locks_->GetPointLockStatus(cfh_id, key_str); + assert(lock_status.locked); + assert(lock_status.seq <= tracked_at_seq); + assert(lock_status.exclusive == exclusive); #endif } } diff --git a/utilities/transactions/pessimistic_transaction.h b/utilities/transactions/pessimistic_transaction.h index f81405bc3..308d7460f 100644 --- a/utilities/transactions/pessimistic_transaction.h +++ b/utilities/transactions/pessimistic_transaction.h @@ -139,7 +139,7 @@ class PessimisticTransaction : public TransactionBaseImpl { virtual void Initialize(const TransactionOptions& txn_options); - Status LockBatch(WriteBatch* batch, TransactionKeyMap* keys_to_unlock); + Status LockBatch(WriteBatch* batch, LockTracker* keys_to_unlock); Status TryLock(ColumnFamilyHandle* column_family, const Slice& key, bool read_only, bool exclusive, const bool do_validate = true, diff --git a/utilities/transactions/pessimistic_transaction_db.cc b/utilities/transactions/pessimistic_transaction_db.cc index f68ec94ad..a15df47e1 100644 --- a/utilities/transactions/pessimistic_transaction_db.cc +++ b/utilities/transactions/pessimistic_transaction_db.cc @@ -402,7 +402,7 @@ Status PessimisticTransactionDB::TryLock(PessimisticTransaction* txn, } void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn, - const TransactionKeyMap* keys) { + const LockTracker& keys) { lock_mgr_.UnLock(txn, keys, GetEnv()); } diff --git a/utilities/transactions/pessimistic_transaction_db.h b/utilities/transactions/pessimistic_transaction_db.h index 144458b53..e2b548121 100644 --- a/utilities/transactions/pessimistic_transaction_db.h +++ b/utilities/transactions/pessimistic_transaction_db.h @@ -99,7 +99,7 @@ class PessimisticTransactionDB : public TransactionDB { Status TryLock(PessimisticTransaction* txn, uint32_t cfh_id, const std::string& key, bool exclusive); - void UnLock(PessimisticTransaction* txn, const TransactionKeyMap* keys); + void UnLock(PessimisticTransaction* txn, const LockTracker& keys); void UnLock(PessimisticTransaction* txn, uint32_t cfh_id, const std::string& key); diff --git a/utilities/transactions/transaction_base.cc b/utilities/transactions/transaction_base.cc index 92b3956e1..4c4234027 100644 --- a/utilities/transactions/transaction_base.cc +++ b/utilities/transactions/transaction_base.cc @@ -16,6 +16,7 @@ #include "rocksdb/status.h" #include "util/cast_util.h" #include "util/string_util.h" +#include "utilities/transactions/lock/lock_tracker.h" namespace ROCKSDB_NAMESPACE { @@ -27,6 +28,7 @@ TransactionBaseImpl::TransactionBaseImpl(DB* db, cmp_(GetColumnFamilyUserComparator(db->DefaultColumnFamily())), start_time_(db_->GetEnv()->NowMicros()), write_batch_(cmp_, 0, true, 0), + tracked_locks_(NewLockTracker()), indexing_enabled_(true) { assert(dynamic_cast(db_) != nullptr); log_number_ = 0; @@ -44,7 +46,7 @@ void TransactionBaseImpl::Clear() { save_points_.reset(nullptr); write_batch_.Clear(); commit_time_batch_.Clear(); - tracked_keys_.clear(); + tracked_locks_->Clear(); num_puts_ = 0; num_deletes_ = 0; num_merges_ = 0; @@ -143,37 +145,7 @@ Status TransactionBaseImpl::RollbackToSavePoint() { assert(s.ok()); // Rollback any keys that were tracked since the last savepoint - const TransactionKeyMap& key_map = save_point.new_keys_; - for (const auto& key_map_iter : key_map) { - uint32_t column_family_id = key_map_iter.first; - auto& keys = key_map_iter.second; - - auto& cf_tracked_keys = tracked_keys_[column_family_id]; - - for (const auto& key_iter : keys) { - const std::string& key = key_iter.first; - uint32_t num_reads = key_iter.second.num_reads; - uint32_t num_writes = key_iter.second.num_writes; - - auto tracked_keys_iter = cf_tracked_keys.find(key); - assert(tracked_keys_iter != cf_tracked_keys.end()); - - // Decrement the total reads/writes of this key by the number of - // reads/writes done since the last SavePoint. - if (num_reads > 0) { - assert(tracked_keys_iter->second.num_reads >= num_reads); - tracked_keys_iter->second.num_reads -= num_reads; - } - if (num_writes > 0) { - assert(tracked_keys_iter->second.num_writes >= num_writes); - tracked_keys_iter->second.num_writes -= num_writes; - } - if (tracked_keys_iter->second.num_reads == 0 && - tracked_keys_iter->second.num_writes == 0) { - cf_tracked_keys.erase(tracked_keys_iter); - } - } - } + tracked_locks_->Subtract(*save_point.new_locks_); save_points_->pop(); @@ -204,35 +176,7 @@ Status TransactionBaseImpl::PopSavePoint() { std::swap(top, save_points_->top()); save_points_->pop(); - const TransactionKeyMap& curr_cf_key_map = top.new_keys_; - TransactionKeyMap& prev_cf_key_map = save_points_->top().new_keys_; - - for (const auto& curr_cf_key_iter : curr_cf_key_map) { - uint32_t column_family_id = curr_cf_key_iter.first; - const std::unordered_map& curr_keys = - curr_cf_key_iter.second; - - // If cfid was not previously tracked, just copy everything over. - auto prev_keys_iter = prev_cf_key_map.find(column_family_id); - if (prev_keys_iter == prev_cf_key_map.end()) { - prev_cf_key_map.emplace(curr_cf_key_iter); - } else { - std::unordered_map& prev_keys = - prev_keys_iter->second; - for (const auto& key_iter : curr_keys) { - const std::string& key = key_iter.first; - const TransactionKeyMapInfo& info = key_iter.second; - // If key was not previously tracked, just copy the whole struct over. - // Otherwise, some merging needs to occur. - auto prev_info = prev_keys.find(key); - if (prev_info == prev_keys.end()) { - prev_keys.emplace(key_iter); - } else { - prev_info->second.Merge(info); - } - } - } - } + save_points_->top().new_locks_->Merge(*top.new_locks_); } return write_batch_.PopSavePoint(); @@ -601,108 +545,28 @@ uint64_t TransactionBaseImpl::GetNumDeletes() const { return num_deletes_; } uint64_t TransactionBaseImpl::GetNumMerges() const { return num_merges_; } uint64_t TransactionBaseImpl::GetNumKeys() const { - uint64_t count = 0; - - // sum up locked keys in all column families - for (const auto& key_map_iter : tracked_keys_) { - const auto& keys = key_map_iter.second; - count += keys.size(); - } - - return count; + return tracked_locks_->GetNumPointLocks(); } void TransactionBaseImpl::TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seq, bool read_only, bool exclusive) { + PointLockRequest r; + r.column_family_id = cfh_id; + r.key = key; + r.seq = seq; + r.read_only = read_only; + r.exclusive = exclusive; + // Update map of all tracked keys for this transaction - TrackKey(&tracked_keys_, cfh_id, key, seq, read_only, exclusive); + tracked_locks_->Track(r); if (save_points_ != nullptr && !save_points_->empty()) { // Update map of tracked keys in this SavePoint - TrackKey(&save_points_->top().new_keys_, cfh_id, key, seq, read_only, - exclusive); + save_points_->top().new_locks_->Track(r); } } -// Add a key to the given TransactionKeyMap -// seq for pessimistic transactions is the sequence number from which we know -// there has not been a concurrent update to the key. -void TransactionBaseImpl::TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id, - const std::string& key, SequenceNumber seq, - bool read_only, bool exclusive) { - auto& cf_key_map = (*key_map)[cfh_id]; -#ifdef __cpp_lib_unordered_map_try_emplace - // use c++17's try_emplace if available, to avoid rehashing the key - // in case it is not already in the map - auto result = cf_key_map.try_emplace(key, seq); - auto iter = result.first; - if (!result.second && seq < iter->second.seq) { - // Now tracking this key with an earlier sequence number - iter->second.seq = seq; - } -#else - auto iter = cf_key_map.find(key); - if (iter == cf_key_map.end()) { - auto result = cf_key_map.emplace(key, TransactionKeyMapInfo(seq)); - iter = result.first; - } else if (seq < iter->second.seq) { - // Now tracking this key with an earlier sequence number - iter->second.seq = seq; - } -#endif - // else we do not update the seq. The smaller the tracked seq, the stronger it - // the guarantee since it implies from the seq onward there has not been a - // concurrent update to the key. So we update the seq if it implies stronger - // guarantees, i.e., if it is smaller than the existing tracked seq. - - if (read_only) { - iter->second.num_reads++; - } else { - iter->second.num_writes++; - } - iter->second.exclusive |= exclusive; -} - -std::unique_ptr -TransactionBaseImpl::GetTrackedKeysSinceSavePoint() { - if (save_points_ != nullptr && !save_points_->empty()) { - // Examine the number of reads/writes performed on all keys written - // since the last SavePoint and compare to the total number of reads/writes - // for each key. - TransactionKeyMap* result = new TransactionKeyMap(); - for (const auto& key_map_iter : save_points_->top().new_keys_) { - uint32_t column_family_id = key_map_iter.first; - auto& keys = key_map_iter.second; - - auto& cf_tracked_keys = tracked_keys_[column_family_id]; - - for (const auto& key_iter : keys) { - const std::string& key = key_iter.first; - uint32_t num_reads = key_iter.second.num_reads; - uint32_t num_writes = key_iter.second.num_writes; - - auto total_key_info = cf_tracked_keys.find(key); - assert(total_key_info != cf_tracked_keys.end()); - assert(total_key_info->second.num_reads >= num_reads); - assert(total_key_info->second.num_writes >= num_writes); - - if (total_key_info->second.num_reads == num_reads && - total_key_info->second.num_writes == num_writes) { - // All the reads/writes to this key were done in the last savepoint. - bool read_only = (num_writes == 0); - TrackKey(result, column_family_id, key, key_iter.second.seq, - read_only, key_iter.second.exclusive); - } - } - } - return std::unique_ptr(result); - } - - // No SavePoint - return nullptr; -} - // Gets the write batch that should be used for Put/Merge/Deletes. // // Returns either a WriteBatch or WriteBatchWithIndex depending on whether @@ -728,54 +592,28 @@ void TransactionBaseImpl::ReleaseSnapshot(const Snapshot* snapshot, DB* db) { void TransactionBaseImpl::UndoGetForUpdate(ColumnFamilyHandle* column_family, const Slice& key) { - uint32_t column_family_id = GetColumnFamilyID(column_family); - auto& cf_tracked_keys = tracked_keys_[column_family_id]; - std::string key_str = key.ToString(); - bool can_decrement = false; - bool can_unlock __attribute__((__unused__)) = false; + PointLockRequest r; + r.column_family_id = GetColumnFamilyID(column_family); + r.key = key.ToString(); + r.read_only = true; + bool can_untrack = false; if (save_points_ != nullptr && !save_points_->empty()) { - // Check if this key was fetched ForUpdate in this SavePoint - auto& cf_savepoint_keys = save_points_->top().new_keys_[column_family_id]; - - auto savepoint_iter = cf_savepoint_keys.find(key_str); - if (savepoint_iter != cf_savepoint_keys.end()) { - if (savepoint_iter->second.num_reads > 0) { - savepoint_iter->second.num_reads--; - can_decrement = true; - - if (savepoint_iter->second.num_reads == 0 && - savepoint_iter->second.num_writes == 0) { - // No other GetForUpdates or write on this key in this SavePoint - cf_savepoint_keys.erase(savepoint_iter); - can_unlock = true; - } - } - } + // If there is no GetForUpdate of the key in this save point, + // then cannot untrack from the global lock tracker. + UntrackStatus s = save_points_->top().new_locks_->Untrack(r); + can_untrack = (s != UntrackStatus::NOT_TRACKED); } else { - // No SavePoint set - can_decrement = true; - can_unlock = true; + // No save point, so can untrack from the global lock tracker. + can_untrack = true; } - // We can only decrement the read count for this key if we were able to - // decrement the read count in the current SavePoint, OR if there is no - // SavePoint set. - if (can_decrement) { - auto key_iter = cf_tracked_keys.find(key_str); - - if (key_iter != cf_tracked_keys.end()) { - if (key_iter->second.num_reads > 0) { - key_iter->second.num_reads--; - - if (key_iter->second.num_reads == 0 && - key_iter->second.num_writes == 0) { - // No other GetForUpdates or writes on this key - assert(can_unlock); - cf_tracked_keys.erase(key_iter); - UnlockGetForUpdate(column_family, key); - } - } + if (can_untrack) { + // If erased from the global tracker, then can unlock the key. + UntrackStatus s = tracked_locks_->Untrack(r); + bool can_unlock = (s == UntrackStatus::REMOVED); + if (can_unlock) { + UnlockGetForUpdate(column_family, key); } } } diff --git a/utilities/transactions/transaction_base.h b/utilities/transactions/transaction_base.h index f279676c6..c7832bdc8 100644 --- a/utilities/transactions/transaction_base.h +++ b/utilities/transactions/transaction_base.h @@ -1,7 +1,7 @@ // Copyright (c) 2011-present, Facebook, Inc. All rights reserved. -// This source code is licensed under both the GPLv2 (found in the -// COPYING file in the root directory) and Apache 2.0 License -// (found in the LICENSE.Apache file in the root directory). +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). #pragma once @@ -21,6 +21,7 @@ #include "rocksdb/utilities/transaction_db.h" #include "rocksdb/utilities/write_batch_with_index.h" #include "util/autovector.h" +#include "utilities/transactions/lock/lock_tracker.h" #include "utilities/transactions/transaction_util.h" namespace ROCKSDB_NAMESPACE { @@ -233,10 +234,6 @@ class TransactionBaseImpl : public Transaction { return UndoGetForUpdate(nullptr, key); }; - // Get list of keys in this transaction that must not have any conflicts - // with writes in other transactions. - const TransactionKeyMap& GetTrackedKeys() const { return tracked_keys_; } - WriteOptions* GetWriteOptions() override { return &write_options_; } void SetWriteOptions(const WriteOptions& write_options) override { @@ -260,17 +257,10 @@ class TransactionBaseImpl : public Transaction { void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno, bool readonly, bool exclusive); - // Helper function to add a key to the given TransactionKeyMap - static void TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id, - const std::string& key, SequenceNumber seqno, - bool readonly, bool exclusive); - // Called when UndoGetForUpdate determines that this key can be unlocked. virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family, const Slice& key) = 0; - std::unique_ptr GetTrackedKeysSinceSavePoint(); - // Sets a snapshot if SetSnapshotOnNextOperation() has been called. void SetSnapshotIfNeeded(); @@ -310,8 +300,8 @@ class TransactionBaseImpl : public Transaction { uint64_t num_deletes_ = 0; uint64_t num_merges_ = 0; - // Record all keys tracked since the last savepoint - TransactionKeyMap new_keys_; + // Record all locks tracked since the last savepoint + std::shared_ptr new_locks_; SavePoint(std::shared_ptr snapshot, bool snapshot_needed, std::shared_ptr snapshot_notifier, @@ -321,19 +311,20 @@ class TransactionBaseImpl : public Transaction { snapshot_notifier_(snapshot_notifier), num_puts_(num_puts), num_deletes_(num_deletes), - num_merges_(num_merges) {} + num_merges_(num_merges), + new_locks_(NewLockTracker()) {} - SavePoint() = default; + SavePoint() : new_locks_(NewLockTracker()) {} }; // Records writes pending in this transaction WriteBatchWithIndex write_batch_; - // Map from column_family_id to map of keys that are involved in this - // transaction. - // For Pessimistic Transactions this is the list of locked keys. - // Optimistic Transactions will wait till commit time to do conflict checking. - TransactionKeyMap tracked_keys_; + // For Pessimistic Transactions this is the set of acquired locks. + // Optimistic Transactions will keep note the requested locks (not actually + // locked), and do conflict checking until commit time based on the tracked + // lock requests. + std::unique_ptr tracked_locks_; // Stack of the Snapshot saved at each save point. Saved snapshots may be // nullptr if there was no snapshot at the time SetSavePoint() was called. diff --git a/utilities/transactions/transaction_lock_mgr.cc b/utilities/transactions/transaction_lock_mgr.cc index 64fe00aba..08e3b9dcf 100644 --- a/utilities/transactions/transaction_lock_mgr.cc +++ b/utilities/transactions/transaction_lock_mgr.cc @@ -643,26 +643,27 @@ void TransactionLockMgr::UnLock(PessimisticTransaction* txn, } void TransactionLockMgr::UnLock(const PessimisticTransaction* txn, - const TransactionKeyMap* key_map, Env* env) { - for (auto& key_map_iter : *key_map) { - uint32_t column_family_id = key_map_iter.first; - auto& keys = key_map_iter.second; - - std::shared_ptr lock_map_ptr = GetLockMap(column_family_id); + const LockTracker& tracker, Env* env) { + std::unique_ptr cf_it( + tracker.GetColumnFamilyIterator()); + assert(cf_it != nullptr); + while (cf_it->HasNext()) { + ColumnFamilyId cf = cf_it->Next(); + std::shared_ptr lock_map_ptr = GetLockMap(cf); LockMap* lock_map = lock_map_ptr.get(); - - if (lock_map == nullptr) { + if (!lock_map) { // Column Family must have been dropped. return; } // Bucket keys by lock_map_ stripe std::unordered_map> keys_by_stripe( - std::max(keys.size(), lock_map->num_stripes_)); - - for (auto& key_iter : keys) { - const std::string& key = key_iter.first; - + lock_map->num_stripes_); + std::unique_ptr key_it( + tracker.GetKeyIterator(cf)); + assert(key_it != nullptr); + while (key_it->HasNext()) { + const std::string& key = key_it->Next(); size_t stripe_num = lock_map->GetStripe(key); keys_by_stripe[stripe_num].push_back(&key); } diff --git a/utilities/transactions/transaction_lock_mgr.h b/utilities/transactions/transaction_lock_mgr.h index 78c3a3809..0a9474488 100644 --- a/utilities/transactions/transaction_lock_mgr.h +++ b/utilities/transactions/transaction_lock_mgr.h @@ -77,7 +77,7 @@ class TransactionLockMgr { // Unlock a key locked by TryLock(). txn must be the same Transaction that // locked this key. - void UnLock(const PessimisticTransaction* txn, const TransactionKeyMap* keys, + void UnLock(const PessimisticTransaction* txn, const LockTracker& tracker, Env* env); void UnLock(PessimisticTransaction* txn, uint32_t column_family_id, const std::string& key, Env* env); diff --git a/utilities/transactions/transaction_util.cc b/utilities/transactions/transaction_util.cc index 45b4cb89c..494f132e7 100644 --- a/utilities/transactions/transaction_util.cc +++ b/utilities/transactions/transaction_util.cc @@ -137,18 +137,20 @@ Status TransactionUtil::CheckKey(DBImpl* db_impl, SuperVersion* sv, } Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl, - const TransactionKeyMap& key_map, + const LockTracker& tracker, bool cache_only) { Status result; - for (auto& key_map_iter : key_map) { - uint32_t cf_id = key_map_iter.first; - const auto& keys = key_map_iter.second; + std::unique_ptr cf_it( + tracker.GetColumnFamilyIterator()); + assert(cf_it != nullptr); + while (cf_it->HasNext()) { + ColumnFamilyId cf = cf_it->Next(); - SuperVersion* sv = db_impl->GetAndRefSuperVersion(cf_id); + SuperVersion* sv = db_impl->GetAndRefSuperVersion(cf); if (sv == nullptr) { result = Status::InvalidArgument("Could not access column family " + - ToString(cf_id)); + ToString(cf)); break; } @@ -157,18 +159,21 @@ Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl, // For each of the keys in this transaction, check to see if someone has // written to this key since the start of the transaction. - for (const auto& key_iter : keys) { - const auto& key = key_iter.first; - const SequenceNumber key_seq = key_iter.second.seq; + std::unique_ptr key_it( + tracker.GetKeyIterator(cf)); + assert(key_it != nullptr); + while (key_it->HasNext()) { + const std::string& key = key_it->Next(); + PointLockStatus status = tracker.GetPointLockStatus(cf, key); + const SequenceNumber key_seq = status.seq; result = CheckKey(db_impl, sv, earliest_seq, key_seq, key, cache_only); - if (!result.ok()) { break; } } - db_impl->ReturnAndCleanupSuperVersion(cf_id, sv); + db_impl->ReturnAndCleanupSuperVersion(cf, sv); if (!result.ok()) { break; diff --git a/utilities/transactions/transaction_util.h b/utilities/transactions/transaction_util.h index 2e48f84a4..707d487c5 100644 --- a/utilities/transactions/transaction_util.h +++ b/utilities/transactions/transaction_util.h @@ -12,39 +12,14 @@ #include "db/dbformat.h" #include "db/read_callback.h" - #include "rocksdb/db.h" #include "rocksdb/slice.h" #include "rocksdb/status.h" #include "rocksdb/types.h" +#include "utilities/transactions/lock/lock_tracker.h" namespace ROCKSDB_NAMESPACE { -struct TransactionKeyMapInfo { - // Earliest sequence number that is relevant to this transaction for this key - SequenceNumber seq; - - uint32_t num_writes; - uint32_t num_reads; - - bool exclusive; - - explicit TransactionKeyMapInfo(SequenceNumber seq_no) - : seq(seq_no), num_writes(0), num_reads(0), exclusive(false) {} - - // Used in PopSavePoint to collapse two savepoints together. - void Merge(const TransactionKeyMapInfo& info) { - assert(seq <= info.seq); - num_reads += info.num_reads; - num_writes += info.num_writes; - exclusive |= info.exclusive; - } -}; - -using TransactionKeyMap = - std::unordered_map>; - class DBImpl; struct SuperVersion; class WriteBatchWithIndex; @@ -69,17 +44,19 @@ class TransactionUtil { ReadCallback* snap_checker = nullptr, SequenceNumber min_uncommitted = kMaxSequenceNumber); - // For each key,SequenceNumber pair in the TransactionKeyMap, this function + // For each key,SequenceNumber pair tracked by the LockTracker, this function // will verify there have been no writes to the key in the db since that // sequence number. // // Returns OK on success, BUSY if there is a conflicting write, or other error // status for any unexpected errors. // - // REQUIRED: this function should only be called on the write thread or if the + // REQUIRED: + // This function should only be called on the write thread or if the // mutex is held. + // tracker must support point lock. static Status CheckKeysForConflicts(DBImpl* db_impl, - const TransactionKeyMap& keys, + const LockTracker& tracker, bool cache_only); private: diff --git a/utilities/transactions/write_unprepared_txn.cc b/utilities/transactions/write_unprepared_txn.cc index 0f9993dca..ed2600026 100644 --- a/utilities/transactions/write_unprepared_txn.cc +++ b/utilities/transactions/write_unprepared_txn.cc @@ -72,10 +72,10 @@ WriteUnpreparedTxn::~WriteUnpreparedTxn() { } } - // Call tracked_keys_.clear() so that ~PessimisticTransaction does not + // Clear the tracked locks so that ~PessimisticTransaction does not // try to unlock keys for recovered transactions. if (recovered_txn_) { - tracked_keys_.clear(); + tracked_locks_->Clear(); } } @@ -296,7 +296,9 @@ Status WriteUnpreparedTxn::FlushWriteBatchToDBInternal(bool prepared) { Status AddUntrackedKey(uint32_t cf, const Slice& key) { auto str = key.ToString(); - if (txn_->tracked_keys_[cf].count(str) == 0) { + PointLockStatus lock_status = + txn_->tracked_locks_->GetPointLockStatus(cf, str); + if (!lock_status.locked) { txn_->untracked_keys_[cf].push_back(str); } return Status::OK(); @@ -639,8 +641,10 @@ Status WriteUnpreparedTxn::CommitInternal() { } Status WriteUnpreparedTxn::WriteRollbackKeys( - const TransactionKeyMap& tracked_keys, WriteBatchWithIndex* rollback_batch, + const LockTracker& lock_tracker, WriteBatchWithIndex* rollback_batch, ReadCallback* callback, const ReadOptions& roptions) { + // This assertion can be removed when range lock is supported. + assert(lock_tracker.IsPointLockSupported()); const auto& cf_map = *wupt_db_->GetCFHandleMap(); auto WriteRollbackKey = [&](const std::string& key, uint32_t cfid) { const auto& cf_handle = cf_map.at(cfid); @@ -666,11 +670,17 @@ Status WriteUnpreparedTxn::WriteRollbackKeys( return Status::OK(); }; - for (const auto& cfkey : tracked_keys) { - const auto cfid = cfkey.first; - const auto& keys = cfkey.second; - for (const auto& pair : keys) { - auto s = WriteRollbackKey(pair.first, cfid); + std::unique_ptr cf_it( + lock_tracker.GetColumnFamilyIterator()); + assert(cf_it != nullptr); + while (cf_it->HasNext()) { + ColumnFamilyId cf = cf_it->Next(); + std::unique_ptr key_it( + lock_tracker.GetKeyIterator(cf)); + assert(key_it != nullptr); + while (key_it->HasNext()) { + const std::string& key = key_it->Next(); + auto s = WriteRollbackKey(key, cf); if (!s.ok()) { return s; } @@ -709,7 +719,7 @@ Status WriteUnpreparedTxn::RollbackInternal() { // TODO(lth): We write rollback batch all in a single batch here, but this // should be subdivded into multiple batches as well. In phase 2, when key // sets are read from WAL, this will happen naturally. - WriteRollbackKeys(GetTrackedKeys(), &rollback_batch, &callback, roptions); + WriteRollbackKeys(*tracked_locks_, &rollback_batch, &callback, roptions); // The Rollback marker will be used as a batch separator WriteBatchInternal::MarkRollback(rollback_batch.GetWriteBatch(), name_); @@ -790,7 +800,7 @@ Status WriteUnpreparedTxn::RollbackInternal() { void WriteUnpreparedTxn::Clear() { if (!recovered_txn_) { - txn_db_impl_->UnLock(this, &GetTrackedKeys()); + txn_db_impl_->UnLock(this, *tracked_locks_); } unprep_seqs_.clear(); flushed_save_points_.reset(nullptr); @@ -842,7 +852,7 @@ Status WriteUnpreparedTxn::RollbackToSavePointInternal() { WriteUnpreparedTxn::SavePoint& top = flushed_save_points_->back(); assert(save_points_ != nullptr && save_points_->size() > 0); - const TransactionKeyMap& tracked_keys = save_points_->top().new_keys_; + const LockTracker& tracked_keys = *save_points_->top().new_locks_; ReadOptions roptions; roptions.snapshot = top.snapshot_->snapshot(); diff --git a/utilities/transactions/write_unprepared_txn.h b/utilities/transactions/write_unprepared_txn.h index 55730e15b..5a3227f4e 100644 --- a/utilities/transactions/write_unprepared_txn.h +++ b/utilities/transactions/write_unprepared_txn.h @@ -212,7 +212,7 @@ class WriteUnpreparedTxn : public WritePreparedTxn { friend class WriteUnpreparedTxnDB; const std::map& GetUnpreparedSequenceNumbers(); - Status WriteRollbackKeys(const TransactionKeyMap& tracked_keys, + Status WriteRollbackKeys(const LockTracker& tracked_keys, WriteBatchWithIndex* rollback_batch, ReadCallback* callback, const ReadOptions& roptions);