Fold function for thread-local data
Summary:
This function allows the user to provide a custom function to fold all threads' local data. It will be used in my next diff for aggregating statistics stored in thread-local data. Note the test case uses atomics as thread-local values due to the synchronization requirement (documented in code).

Test Plan: unit test

Reviewers: yhchiang, sdong, kradhakrishnan

Reviewed By: kradhakrishnan

Subscribers: andrewkr, dhruba, leveldb

Differential Revision: https://reviews.facebook.net/D62049
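As a rough illustration of the intended statistics use case, here is a minimal sketch of summing per-thread counters with Fold(), modeled on the unit test in this diff. The ThreadLocalCounter type and its Add()/Total() methods are hypothetical, the internal header path "util/thread_local.h" is assumed, and it relies on ThreadLocalPtr being constructible with an UnrefHandler (as the test's Params helper uses); it is not part of this change.

// Hypothetical helper (not part of this diff): a per-thread counter that can
// be summed across threads via ThreadLocalPtr::Fold(). Assumes the internal
// header path "util/thread_local.h" and a ThreadLocalPtr constructor that
// accepts an UnrefHandler.
#include <atomic>
#include <cstdint>

#include "util/thread_local.h"

namespace rocksdb {

struct ThreadLocalCounter {
  ThreadLocalCounter()
      // The unref handler frees a thread's counter when that thread exits.
      : tls([](void* ptr) { delete static_cast<std::atomic<int64_t>*>(ptr); }) {
  }

  // Called on any worker thread; bumps that thread's private counter.
  void Add(int64_t n) {
    auto* c = static_cast<std::atomic<int64_t>*>(tls.Get());
    if (c == nullptr) {
      c = new std::atomic<int64_t>(0);
      tls.Reset(c);
    }
    c->fetch_add(n, std::memory_order_relaxed);
  }

  // Called on a stats thread; folds every live thread's counter into a sum.
  // The values are atomics because owning threads may still be updating them
  // while Fold() iterates (the synchronization caveat documented in the code).
  int64_t Total() {
    int64_t sum = 0;
    tls.Fold(
        [](void* ptr, void* res) {
          *static_cast<int64_t*>(res) +=
              static_cast<std::atomic<int64_t>*>(ptr)->load();
        },
        &sum);
    return sum;
  }

  ThreadLocalPtr tls;
};

}  // namespace rocksdb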
commit 6584cec8f2
parent 817eeb29b4
@@ -307,6 +307,18 @@ void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector<void*>* ptrs,
  }
}

void ThreadLocalPtr::StaticMeta::Fold(uint32_t id, FoldFunc func, void* res) {
  MutexLock l(Mutex());
  for (ThreadData* t = head_.next; t != &head_; t = t->next) {
    if (id < t->entries.size()) {
      void* ptr = t->entries[id].ptr.load();
      if (ptr != nullptr) {
        func(ptr, res);
      }
    }
  }
}

void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) {
  MutexLock l(Mutex());
  handler_map_[id] = handler;

@@ -388,4 +400,8 @@ void ThreadLocalPtr::Scrape(autovector<void*>* ptrs, void* const replacement) {
  Instance()->Scrape(id_, ptrs, replacement);
}

void ThreadLocalPtr::Fold(FoldFunc func, void* res) {
  Instance()->Fold(id_, func, res);
}

}  // namespace rocksdb
@@ -63,6 +63,13 @@ class ThreadLocalPtr {
  // data for all existing threads
  void Scrape(autovector<void*>* ptrs, void* const replacement);

  typedef std::function<void(void*, void*)> FoldFunc;
  // Update res by applying func on each thread-local value. Holds a lock that
  // prevents unref handler from running during this call, but clients must
  // still provide external synchronization since the owning thread can
  // access the values without internal locking, e.g., via Get() and Reset().
  void Fold(FoldFunc func, void* res);

  // Initialize the static singletons of the ThreadLocalPtr.
  //
  // If this function is not called, then the singletons will be

@@ -119,7 +126,6 @@ class ThreadLocalPtr {
  // Return the pointer value for the given id for the current thread.
  void* Get(uint32_t id) const;
  // Reset the pointer value for the given id for the current thread.
  // It triggers UnrefHanlder if the id has existing pointer value.
  void Reset(uint32_t id, void* ptr);
  // Atomically swap the supplied ptr and return the previous value
  void* Swap(uint32_t id, void* ptr);

@@ -129,6 +135,11 @@ class ThreadLocalPtr {
  // Reset all thread local data to replacement, and return non-nullptr
  // data for all existing threads
  void Scrape(uint32_t id, autovector<void*>* ptrs, void* const replacement);
  // Update res by applying func on each thread-local value. Holds a lock that
  // prevents unref handler from running during this call, but clients must
  // still provide external synchronization since the owning thread can
  // access the values without internal locking, e.g., via Get() and Reset().
  void Fold(uint32_t id, FoldFunc func, void* res);

  // Register the UnrefHandler for id
  void SetHandler(uint32_t id, UnrefHandler handler);
@@ -457,6 +457,64 @@ TEST_F(ThreadLocalTest, Scrape) {
  }
}

TEST_F(ThreadLocalTest, Fold) {
  auto unref = [](void* ptr) {
    delete static_cast<std::atomic<int64_t>*>(ptr);
  };
  const int kNumThreads = 16;
  const int kItersPerThread = 10;
  port::Mutex mu;
  port::CondVar cv(&mu);
  Params params(&mu, &cv, nullptr, kNumThreads, unref);
  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    p.tls1.Reset(new std::atomic<int64_t>(0));

    for (int i = 0; i < kItersPerThread; ++i) {
      static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
    }

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 0; th < params.total; ++th) {
    env_->StartThread(func, static_cast<void*>(&params));
  }

  // Wait for all threads to finish using Params
  mu.Lock();
  while (params.completed != params.total) {
    cv.Wait();
  }
  mu.Unlock();

  // Verify Fold() behavior
  int64_t sum = 0;
  params.tls1.Fold(
      [](void* ptr, void* res) {
        auto sum_ptr = static_cast<int64_t*>(res);
        *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
      },
      &sum);
  ASSERT_EQ(sum, kNumThreads * kItersPerThread);

  // Signal to exit
  mu.Lock();
  params.completed = 0;
  cv.SignalAll();
  mu.Unlock();
  env_->WaitForJoin();
}

TEST_F(ThreadLocalTest, CompareAndSwap) {
  ThreadLocalPtr tls;
  ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);