Add StartThread type checking wrapper (#8303)
Summary: - Add class `FunctorWrapper` to invoke the function with given parameters - Implement `StartThreadTyped` which wraps `StartThread` with type checking cover - Demonstrate `StartThreadTyped` in test `util/thread_local_test.cc` https://github.com/facebook/rocksdb/issues/8285 Pull Request resolved: https://github.com/facebook/rocksdb/pull/8303 Reviewed By: ajkr Differential Revision: D28539318 Pulled By: pdillinger fbshipit-source-id: 624789c236bde31163deda95c1e1471aee68933e
This commit is contained in:
parent
13232e11d4
commit
748e3acc11
@ -17,12 +17,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cstdarg>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "rocksdb/functor_wrapper.h"
|
||||
#include "rocksdb/status.h"
|
||||
#include "rocksdb/thread_status.h"
|
||||
|
||||
@ -422,6 +425,21 @@ class Env {
|
||||
// When "function(arg)" returns, the thread will be destroyed.
|
||||
virtual void StartThread(void (*function)(void* arg), void* arg) = 0;
|
||||
|
||||
// Start a new thread, invoking "function(args...)" within the new thread.
|
||||
// When "function(args...)" returns, the thread will be destroyed.
|
||||
template <typename FunctionT, typename... Args>
|
||||
void StartThreadTyped(FunctionT function, Args&&... args) {
|
||||
using FWType = FunctorWrapper<Args...>;
|
||||
StartThread(
|
||||
[](void* arg) {
|
||||
auto* functor = static_cast<FWType*>(arg);
|
||||
functor->invoke();
|
||||
delete functor;
|
||||
},
|
||||
new FWType(std::function<void(Args...)>(function),
|
||||
std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
// Wait for all threads started by StartThread to terminate.
|
||||
virtual void WaitForJoin() {}
|
||||
|
||||
|
55
include/rocksdb/functor_wrapper.h
Normal file
55
include/rocksdb/functor_wrapper.h
Normal file
@ -0,0 +1,55 @@
|
||||
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
|
||||
// This source code is licensed under both the GPLv2 (found in the
|
||||
// COPYING file in the root directory) and Apache 2.0 License
|
||||
// (found in the LICENSE.Apache file in the root directory).
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "rocksdb/rocksdb_namespace.h"
|
||||
|
||||
namespace ROCKSDB_NAMESPACE {
|
||||
|
||||
namespace detail {
|
||||
template <std::size_t...>
|
||||
struct IndexSequence {};
|
||||
|
||||
template <std::size_t N, std::size_t... Next>
|
||||
struct IndexSequenceHelper
|
||||
: public IndexSequenceHelper<N - 1U, N - 1U, Next...> {};
|
||||
|
||||
template <std::size_t... Next>
|
||||
struct IndexSequenceHelper<0U, Next...> {
|
||||
using type = IndexSequence<Next...>;
|
||||
};
|
||||
|
||||
template <std::size_t N>
|
||||
using make_index_sequence = typename IndexSequenceHelper<N>::type;
|
||||
|
||||
template <typename Function, typename Tuple, size_t... I>
|
||||
void call(Function f, Tuple t, IndexSequence<I...>) {
|
||||
f(std::get<I>(t)...);
|
||||
}
|
||||
|
||||
template <typename Function, typename Tuple>
|
||||
void call(Function f, Tuple t) {
|
||||
static constexpr auto size = std::tuple_size<Tuple>::value;
|
||||
call(f, t, make_index_sequence<size>{});
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <typename... Args>
|
||||
class FunctorWrapper {
|
||||
public:
|
||||
explicit FunctorWrapper(std::function<void(Args...)> functor, Args &&...args)
|
||||
: functor_(std::move(functor)), args_(std::forward<Args>(args)...) {}
|
||||
|
||||
void invoke() { detail::call(functor_, args_); }
|
||||
|
||||
private:
|
||||
std::function<void(Args...)> functor_;
|
||||
std::tuple<Args...> args_;
|
||||
};
|
||||
} // namespace ROCKSDB_NAMESPACE
|
@ -3,9 +3,11 @@
|
||||
// COPYING file in the root directory) and Apache 2.0 License
|
||||
// (found in the LICENSE.Apache file in the root directory).
|
||||
|
||||
#include <thread>
|
||||
#include "util/thread_local.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include "port/port.h"
|
||||
#include "rocksdb/env.h"
|
||||
@ -13,7 +15,6 @@
|
||||
#include "test_util/testharness.h"
|
||||
#include "test_util/testutil.h"
|
||||
#include "util/autovector.h"
|
||||
#include "util/thread_local.h"
|
||||
|
||||
namespace ROCKSDB_NAMESPACE {
|
||||
|
||||
@ -51,10 +52,8 @@ struct Params {
|
||||
};
|
||||
|
||||
class IDChecker : public ThreadLocalPtr {
|
||||
public:
|
||||
static uint32_t PeekId() {
|
||||
return TEST_PeekId();
|
||||
}
|
||||
public:
|
||||
static uint32_t PeekId() { return TEST_PeekId(); }
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
@ -122,9 +121,8 @@ TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
|
||||
ASSERT_GT(IDChecker::PeekId(), base_id);
|
||||
base_id = IDChecker::PeekId();
|
||||
|
||||
auto func = [](void* ptr) {
|
||||
auto& params = *static_cast<Params*>(ptr);
|
||||
|
||||
auto func = [](Params* ptr) {
|
||||
Params& params = *ptr;
|
||||
ASSERT_TRUE(params.tls1.Get() == nullptr);
|
||||
params.tls1.Reset(reinterpret_cast<int*>(1));
|
||||
ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
|
||||
@ -146,7 +144,8 @@ TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
|
||||
for (int iter = 0; iter < 1024; ++iter) {
|
||||
ASSERT_EQ(IDChecker::PeekId(), base_id);
|
||||
// Another new thread, read/write should not see value from previous thread
|
||||
env_->StartThread(func, static_cast<void*>(&p));
|
||||
env_->StartThreadTyped(func, &p);
|
||||
|
||||
mu.Lock();
|
||||
while (p.completed != iter + 1) {
|
||||
cv.Wait();
|
||||
@ -221,10 +220,10 @@ TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
|
||||
// Each thread local copy of the value are also different from each
|
||||
// other.
|
||||
for (int th = 0; th < p1.total; ++th) {
|
||||
env_->StartThread(func, static_cast<void*>(&p1));
|
||||
env_->StartThreadTyped(func, &p1);
|
||||
}
|
||||
for (int th = 0; th < p2.total; ++th) {
|
||||
env_->StartThread(func, static_cast<void*>(&p2));
|
||||
env_->StartThreadTyped(func, &p2);
|
||||
}
|
||||
|
||||
mu1.Lock();
|
||||
@ -251,9 +250,8 @@ TEST_F(ThreadLocalTest, Unref) {
|
||||
};
|
||||
|
||||
// Case 0: no unref triggered if ThreadLocalPtr is never accessed
|
||||
auto func0 = [](void* ptr) {
|
||||
auto& p = *static_cast<Params*>(ptr);
|
||||
|
||||
auto func0 = [](Params* ptr) {
|
||||
auto& p = *ptr;
|
||||
p.mu->Lock();
|
||||
++(p.started);
|
||||
p.cv->SignalAll();
|
||||
@ -270,15 +268,15 @@ TEST_F(ThreadLocalTest, Unref) {
|
||||
Params p(&mu, &cv, &unref_count, th, unref);
|
||||
|
||||
for (int i = 0; i < p.total; ++i) {
|
||||
env_->StartThread(func0, static_cast<void*>(&p));
|
||||
env_->StartThreadTyped(func0, &p);
|
||||
}
|
||||
env_->WaitForJoin();
|
||||
ASSERT_EQ(unref_count, 0);
|
||||
}
|
||||
|
||||
// Case 1: unref triggered by thread exit
|
||||
auto func1 = [](void* ptr) {
|
||||
auto& p = *static_cast<Params*>(ptr);
|
||||
auto func1 = [](Params* ptr) {
|
||||
auto& p = *ptr;
|
||||
|
||||
p.mu->Lock();
|
||||
++(p.started);
|
||||
@ -307,7 +305,7 @@ TEST_F(ThreadLocalTest, Unref) {
|
||||
p.tls2 = &tls2;
|
||||
|
||||
for (int i = 0; i < p.total; ++i) {
|
||||
env_->StartThread(func1, static_cast<void*>(&p));
|
||||
env_->StartThreadTyped(func1, &p);
|
||||
}
|
||||
|
||||
env_->WaitForJoin();
|
||||
@ -317,8 +315,8 @@ TEST_F(ThreadLocalTest, Unref) {
|
||||
}
|
||||
|
||||
// Case 2: unref triggered by ThreadLocal instance destruction
|
||||
auto func2 = [](void* ptr) {
|
||||
auto& p = *static_cast<Params*>(ptr);
|
||||
auto func2 = [](Params* ptr) {
|
||||
auto& p = *ptr;
|
||||
|
||||
p.mu->Lock();
|
||||
++(p.started);
|
||||
@ -356,7 +354,7 @@ TEST_F(ThreadLocalTest, Unref) {
|
||||
p.tls2 = new ThreadLocalPtr(unref);
|
||||
|
||||
for (int i = 0; i < p.total; ++i) {
|
||||
env_->StartThread(func2, static_cast<void*>(&p));
|
||||
env_->StartThreadTyped(func2, &p);
|
||||
}
|
||||
|
||||
// Wait for all threads to finish using Params
|
||||
@ -431,7 +429,7 @@ TEST_F(ThreadLocalTest, Scrape) {
|
||||
p.tls2 = new ThreadLocalPtr(unref);
|
||||
|
||||
for (int i = 0; i < p.total; ++i) {
|
||||
env_->StartThread(func, static_cast<void*>(&p));
|
||||
env_->StartThreadTyped(func, &p);
|
||||
}
|
||||
|
||||
// Wait for all threads to finish using Params
|
||||
@ -490,7 +488,7 @@ TEST_F(ThreadLocalTest, Fold) {
|
||||
};
|
||||
|
||||
for (int th = 0; th < params.total; ++th) {
|
||||
env_->StartThread(func, static_cast<void*>(¶ms));
|
||||
env_->StartThread(func, ¶ms);
|
||||
}
|
||||
|
||||
// Wait for all threads to finish using Params
|
||||
|
Loading…
Reference in New Issue
Block a user