// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. // This source code is licensed under both the GPLv2 (found in the // COPYING file in the root directory) and Apache 2.0 License // (found in the LICENSE.Apache file in the root directory). #include "utilities/transactions/lock/point/point_lock_tracker.h" namespace ROCKSDB_NAMESPACE { namespace { class TrackedKeysColumnFamilyIterator : public LockTracker::ColumnFamilyIterator { public: explicit TrackedKeysColumnFamilyIterator(const TrackedKeys& keys) : tracked_keys_(keys), it_(keys.begin()) {} bool HasNext() const override { return it_ != tracked_keys_.end(); } ColumnFamilyId Next() override { return (it_++)->first; } private: const TrackedKeys& tracked_keys_; TrackedKeys::const_iterator it_; }; class TrackedKeysIterator : public LockTracker::KeyIterator { public: TrackedKeysIterator(const TrackedKeys& keys, ColumnFamilyId id) : key_infos_(keys.at(id)), it_(key_infos_.begin()) {} bool HasNext() const override { return it_ != key_infos_.end(); } const std::string& Next() override { return (it_++)->first; } private: const TrackedKeyInfos& key_infos_; TrackedKeyInfos::const_iterator it_; }; } // namespace void PointLockTracker::Track(const PointLockRequest& r) { auto& keys = tracked_keys_[r.column_family_id]; #ifdef __cpp_lib_unordered_map_try_emplace // use c++17's try_emplace if available, to avoid rehashing the key // in case it is not already in the map auto result = keys.try_emplace(r.key, r.seq); auto it = result.first; if (!result.second && r.seq < it->second.seq) { // Now tracking this key with an earlier sequence number it->second.seq = r.seq; } #else auto it = keys.find(r.key); if (it == keys.end()) { auto result = keys.emplace(r.key, TrackedKeyInfo(r.seq)); it = result.first; } else if (r.seq < it->second.seq) { // Now tracking this key with an earlier sequence number it->second.seq = r.seq; } #endif // else we do not update the seq. The smaller the tracked seq, the stronger it // the guarantee since it implies from the seq onward there has not been a // concurrent update to the key. So we update the seq if it implies stronger // guarantees, i.e., if it is smaller than the existing tracked seq. if (r.read_only) { it->second.num_reads++; } else { it->second.num_writes++; } it->second.exclusive = it->second.exclusive || r.exclusive; } UntrackStatus PointLockTracker::Untrack(const PointLockRequest& r) { auto cf_keys = tracked_keys_.find(r.column_family_id); if (cf_keys == tracked_keys_.end()) { return UntrackStatus::NOT_TRACKED; } auto& keys = cf_keys->second; auto it = keys.find(r.key); if (it == keys.end()) { return UntrackStatus::NOT_TRACKED; } bool untracked = false; auto& info = it->second; if (r.read_only) { if (info.num_reads > 0) { info.num_reads--; untracked = true; } } else { if (info.num_writes > 0) { info.num_writes--; untracked = true; } } bool removed = false; if (info.num_reads == 0 && info.num_writes == 0) { keys.erase(it); if (keys.empty()) { tracked_keys_.erase(cf_keys); } removed = true; } if (removed) { return UntrackStatus::REMOVED; } if (untracked) { return UntrackStatus::UNTRACKED; } return UntrackStatus::NOT_TRACKED; } void PointLockTracker::Merge(const LockTracker& tracker) { const PointLockTracker& t = static_cast<const PointLockTracker&>(tracker); for (const auto& cf_keys : t.tracked_keys_) { ColumnFamilyId cf = cf_keys.first; const auto& keys = cf_keys.second; auto current_cf_keys = tracked_keys_.find(cf); if (current_cf_keys == tracked_keys_.end()) { tracked_keys_.emplace(cf_keys); } else { auto& current_keys = current_cf_keys->second; for (const auto& key_info : keys) { const std::string& key = key_info.first; const TrackedKeyInfo& info = key_info.second; // If key was not previously tracked, just copy the whole struct over. // Otherwise, some merging needs to occur. auto current_info = current_keys.find(key); if (current_info == current_keys.end()) { current_keys.emplace(key_info); } else { current_info->second.Merge(info); } } } } } void PointLockTracker::Subtract(const LockTracker& tracker) { const PointLockTracker& t = static_cast<const PointLockTracker&>(tracker); for (const auto& cf_keys : t.tracked_keys_) { ColumnFamilyId cf = cf_keys.first; const auto& keys = cf_keys.second; auto& current_keys = tracked_keys_.at(cf); for (const auto& key_info : keys) { const std::string& key = key_info.first; const TrackedKeyInfo& info = key_info.second; uint32_t num_reads = info.num_reads; uint32_t num_writes = info.num_writes; auto current_key_info = current_keys.find(key); assert(current_key_info != current_keys.end()); // Decrement the total reads/writes of this key by the number of // reads/writes done since the last SavePoint. if (num_reads > 0) { assert(current_key_info->second.num_reads >= num_reads); current_key_info->second.num_reads -= num_reads; } if (num_writes > 0) { assert(current_key_info->second.num_writes >= num_writes); current_key_info->second.num_writes -= num_writes; } if (current_key_info->second.num_reads == 0 && current_key_info->second.num_writes == 0) { current_keys.erase(current_key_info); } } } } LockTracker* PointLockTracker::GetTrackedLocksSinceSavePoint( const LockTracker& save_point_tracker) const { // Examine the number of reads/writes performed on all keys written // since the last SavePoint and compare to the total number of reads/writes // for each key. LockTracker* t = new PointLockTracker(); const PointLockTracker& save_point_t = static_cast<const PointLockTracker&>(save_point_tracker); for (const auto& cf_keys : save_point_t.tracked_keys_) { ColumnFamilyId cf = cf_keys.first; const auto& keys = cf_keys.second; auto& current_keys = tracked_keys_.at(cf); for (const auto& key_info : keys) { const std::string& key = key_info.first; const TrackedKeyInfo& info = key_info.second; uint32_t num_reads = info.num_reads; uint32_t num_writes = info.num_writes; auto current_key_info = current_keys.find(key); assert(current_key_info != current_keys.end()); assert(current_key_info->second.num_reads >= num_reads); assert(current_key_info->second.num_writes >= num_writes); if (current_key_info->second.num_reads == num_reads && current_key_info->second.num_writes == num_writes) { // All the reads/writes to this key were done in the last savepoint. PointLockRequest r; r.column_family_id = cf; r.key = key; r.seq = info.seq; r.read_only = (num_writes == 0); r.exclusive = info.exclusive; t->Track(r); } } } return t; } PointLockStatus PointLockTracker::GetPointLockStatus( ColumnFamilyId column_family_id, const std::string& key) const { assert(IsPointLockSupported()); PointLockStatus status; auto it = tracked_keys_.find(column_family_id); if (it == tracked_keys_.end()) { return status; } const auto& keys = it->second; auto key_it = keys.find(key); if (key_it == keys.end()) { return status; } const TrackedKeyInfo& key_info = key_it->second; status.locked = true; status.exclusive = key_info.exclusive; status.seq = key_info.seq; return status; } uint64_t PointLockTracker::GetNumPointLocks() const { uint64_t num_keys = 0; for (const auto& cf_keys : tracked_keys_) { num_keys += cf_keys.second.size(); } return num_keys; } LockTracker::ColumnFamilyIterator* PointLockTracker::GetColumnFamilyIterator() const { return new TrackedKeysColumnFamilyIterator(tracked_keys_); } LockTracker::KeyIterator* PointLockTracker::GetKeyIterator( ColumnFamilyId column_family_id) const { assert(tracked_keys_.find(column_family_id) != tracked_keys_.end()); return new TrackedKeysIterator(tracked_keys_, column_family_id); } void PointLockTracker::Clear() { tracked_keys_.clear(); } } // namespace ROCKSDB_NAMESPACE