Fold function for thread-local data

Summary:
This function allows the user to provide a custom function to fold all
threads' local data. It will be used in my next diff for aggregating statistics
stored in thread-local data. Note the test case uses atomics as thread-local
values due to the synchronization requirement (documented in code).

Test Plan: unit test

Reviewers: yhchiang, sdong, kradhakrishnan

Reviewed By: kradhakrishnan

Subscribers: andrewkr, dhruba, leveldb

Differential Revision: https://reviews.facebook.net/D62049
This commit is contained in:
Andrew Kryczka 2016-08-22 15:37:39 -07:00
parent 817eeb29b4
commit 6584cec8f2
3 changed files with 86 additions and 1 deletions

View File

@ -307,6 +307,18 @@ void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector<void*>* ptrs,
}
}
void ThreadLocalPtr::StaticMeta::Fold(uint32_t id, FoldFunc func, void* res) {
MutexLock l(Mutex());
for (ThreadData* t = head_.next; t != &head_; t = t->next) {
if (id < t->entries.size()) {
void* ptr = t->entries[id].ptr.load();
if (ptr != nullptr) {
func(ptr, res);
}
}
}
}
void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) {
MutexLock l(Mutex());
handler_map_[id] = handler;
@ -388,4 +400,8 @@ void ThreadLocalPtr::Scrape(autovector<void*>* ptrs, void* const replacement) {
Instance()->Scrape(id_, ptrs, replacement);
}
// Forwards to Instance()->Fold() using this object's id_, so that func is
// applied to the values stored under this ThreadLocalPtr's slot.
void ThreadLocalPtr::Fold(FoldFunc func, void* res) {
  auto* meta = Instance();
  meta->Fold(id_, func, res);
}
} // namespace rocksdb

View File

@ -63,6 +63,13 @@ class ThreadLocalPtr {
// data for all existing threads
void Scrape(autovector<void*>* ptrs, void* const replacement);
// Callback signature accepted by Fold(): it is called with a thread-local
// value as the first argument and the caller-supplied `res` as the second.
using FoldFunc = std::function<void(void*, void*)>;
// Accumulates into res by invoking func once per thread-local value. A lock
// is held for the duration of the call, which keeps the unref handler from
// running concurrently. That lock does NOT cover the owning thread, which can
// still touch its value without internal locking (e.g., through Get() and
// Reset()), so callers remain responsible for external synchronization.
void Fold(FoldFunc func, void* res);
// Initialize the static singletons of the ThreadLocalPtr.
//
// If this function is not called, then the singletons will be
@ -119,7 +126,6 @@ class ThreadLocalPtr {
// Return the pointer value for the given id for the current thread.
void* Get(uint32_t id) const;
// Reset the pointer value for the given id for the current thread.
// It triggers UnrefHandler if the id has existing pointer value.
void Reset(uint32_t id, void* ptr);
// Atomically swap the supplied ptr and return the previous value
void* Swap(uint32_t id, void* ptr);
@ -129,6 +135,11 @@ class ThreadLocalPtr {
// Reset all thread local data to replacement, and return non-nullptr
// data for all existing threads
void Scrape(uint32_t id, autovector<void*>* ptrs, void* const replacement);
// Update res by applying func on each thread's value stored under the given
// id; threads with no value (nullptr) for that id are skipped. Holds a lock
// that prevents unref handler from running during this call, but clients must
// still provide external synchronization since the owning thread can
// access the values without internal locking, e.g., via Get() and Reset().
void Fold(uint32_t id, FoldFunc func, void* res);
// Register the UnrefHandler for id
void SetHandler(uint32_t id, UnrefHandler handler);

View File

@ -457,6 +457,64 @@ TEST_F(ThreadLocalTest, Scrape) {
}
}
// Exercises ThreadLocalPtr::Fold(): kNumThreads workers each install an
// atomic counter in tls1 and increment it kItersPerThread times; the main
// thread then folds all counters into a single sum and checks the total.
// Atomics are used for the thread-local values because Fold() reads them from
// a different thread than the one that owns them.
TEST_F(ThreadLocalTest, Fold) {
  // Deleter for the per-thread counters; handed to Params below.
  auto unref = [](void* ptr) {
    delete static_cast<std::atomic<int64_t>*>(ptr);
  };
  // Static storage duration so the capture-less worker lambda below can name
  // these constants without relying on the odr-use exemption for automatic
  // const variables, which not every compiler applies consistently.
  static const int kNumThreads = 16;
  static const int kItersPerThread = 10;
  port::Mutex mu;
  port::CondVar cv(&mu);
  Params params(&mu, &cv, nullptr, kNumThreads, unref);

  // Worker body: must stay capture-less since it is passed to StartThread().
  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    p.tls1.Reset(new std::atomic<int64_t>(0));

    for (int i = 0; i < kItersPerThread; ++i) {
      static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
    }

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 0; th < params.total; ++th) {
    env_->StartThread(func, static_cast<void*>(&params));
  }

  // Wait for all threads to finish using Params
  mu.Lock();
  while (params.completed != params.total) {
    cv.Wait();
  }
  mu.Unlock();

  // Verify Fold() behavior
  int64_t sum = 0;
  params.tls1.Fold(
      [](void* ptr, void* res) {
        auto sum_ptr = static_cast<int64_t*>(res);
        *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
      },
      &sum);
  // EXPECT rather than ASSERT: a fatal assertion would return from the test
  // right here, skipping the exit signal and join below and leaving the
  // workers blocked on cv while params, mu and cv go out of scope.
  EXPECT_EQ(sum, kNumThreads * kItersPerThread);

  // Signal to exit
  mu.Lock();
  params.completed = 0;
  cv.SignalAll();
  mu.Unlock();
  env_->WaitForJoin();
}
TEST_F(ThreadLocalTest, CompareAndSwap) {
ThreadLocalPtr tls;
ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);