diff --git a/util/thread_local.cc b/util/thread_local.cc index b69f41a77..ba6b545dd 100644 --- a/util/thread_local.cc +++ b/util/thread_local.cc @@ -307,6 +307,18 @@ void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector* ptrs, } } +void ThreadLocalPtr::StaticMeta::Fold(uint32_t id, FoldFunc func, void* res) { + MutexLock l(Mutex()); + for (ThreadData* t = head_.next; t != &head_; t = t->next) { + if (id < t->entries.size()) { + void* ptr = t->entries[id].ptr.load(); + if (ptr != nullptr) { + func(ptr, res); + } + } + } +} + void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) { MutexLock l(Mutex()); handler_map_[id] = handler; @@ -388,4 +400,8 @@ void ThreadLocalPtr::Scrape(autovector* ptrs, void* const replacement) { Instance()->Scrape(id_, ptrs, replacement); } +void ThreadLocalPtr::Fold(FoldFunc func, void* res) { + Instance()->Fold(id_, func, res); +} + } // namespace rocksdb diff --git a/util/thread_local.h b/util/thread_local.h index 08eabd06e..5806b544e 100644 --- a/util/thread_local.h +++ b/util/thread_local.h @@ -63,6 +63,13 @@ class ThreadLocalPtr { // data for all existing threads void Scrape(autovector* ptrs, void* const replacement); + typedef std::function FoldFunc; + // Update res by applying func on each thread-local value. Holds a lock that + // prevents unref handler from running during this call, but clients must + // still provide external synchronization since the owning thread can + // access the values without internal locking, e.g., via Get() and Reset(). + void Fold(FoldFunc func, void* res); + // Initialize the static singletons of the ThreadLocalPtr. // // If this function is not called, then the singletons will be @@ -119,7 +126,6 @@ class ThreadLocalPtr { // Return the pointer value for the given id for the current thread. void* Get(uint32_t id) const; // Reset the pointer value for the given id for the current thread. - // It triggers UnrefHanlder if the id has existing pointer value. void Reset(uint32_t id, void* ptr); // Atomically swap the supplied ptr and return the previous value void* Swap(uint32_t id, void* ptr); @@ -129,6 +135,11 @@ class ThreadLocalPtr { // Reset all thread local data to replacement, and return non-nullptr // data for all existing threads void Scrape(uint32_t id, autovector* ptrs, void* const replacement); + // Update res by applying func on each thread-local value. Holds a lock that + // prevents unref handler from running during this call, but clients must + // still provide external synchronization since the owning thread can + // access the values without internal locking, e.g., via Get() and Reset(). + void Fold(uint32_t id, FoldFunc func, void* res); // Register the UnrefHandler for id void SetHandler(uint32_t id, UnrefHandler handler); diff --git a/util/thread_local_test.cc b/util/thread_local_test.cc index 262b6a557..3f148b874 100644 --- a/util/thread_local_test.cc +++ b/util/thread_local_test.cc @@ -457,6 +457,64 @@ TEST_F(ThreadLocalTest, Scrape) { } } +TEST_F(ThreadLocalTest, Fold) { + auto unref = [](void* ptr) { + delete static_cast*>(ptr); + }; + const int kNumThreads = 16; + const int kItersPerThread = 10; + port::Mutex mu; + port::CondVar cv(&mu); + Params params(&mu, &cv, nullptr, kNumThreads, unref); + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + ASSERT_TRUE(p.tls1.Get() == nullptr); + p.tls1.Reset(new std::atomic(0)); + + for (int i = 0; i < kItersPerThread; ++i) { + static_cast*>(p.tls1.Get())->fetch_add(1); + } + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 0; th < params.total; ++th) { + env_->StartThread(func, static_cast(¶ms)); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (params.completed != params.total) { + cv.Wait(); + } + mu.Unlock(); + + // Verify Fold() behavior + int64_t sum = 0; + params.tls1.Fold( + [](void* ptr, void* res) { + auto sum_ptr = static_cast(res); + *sum_ptr += static_cast*>(ptr)->load(); + }, + &sum); + ASSERT_EQ(sum, kNumThreads * kItersPerThread); + + // Signal to exit + mu.Lock(); + params.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); +} + TEST_F(ThreadLocalTest, CompareAndSwap) { ThreadLocalPtr tls; ASSERT_TRUE(tls.Swap(reinterpret_cast(1)) == nullptr);