tdlight/tdutils/test/StealingQueue.cpp
2024-01-01 03:07:21 +03:00

181 lines
4.6 KiB
C++

//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/AtomicRead.h"
#include "td/utils/benchmark.h"
#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/MpmcQueue.h"
#include "td/utils/port/thread.h"
#include "td/utils/Random.h"
#include "td/utils/SliceBuilder.h"
#include "td/utils/StealingQueue.h"
#include "td/utils/tests.h"
#include <atomic>
#include <cstring>
TEST(StealingQueue, very_simple) {
td::StealingQueue<int, 8> q;
q.local_push(1, [](auto x) { UNREACHABLE(); });
int x;
CHECK(q.local_pop(x));
ASSERT_EQ(1, x);
}
#if !TD_THREAD_UNSUPPORTED
TEST(AtomicRead, simple) {
td::Stage run;
td::Stage check;
std::size_t threads_n = 10;
td::vector<td::thread> threads;
int x{0};
std::atomic<int> version{0};
td::int64 res = 0;
for (std::size_t i = 0; i < threads_n; i++) {
threads.emplace_back([&, id = static_cast<td::uint32>(i)] {
for (td::uint64 round = 1; round < 10000; round++) {
run.wait(round * threads_n);
if (id == 0) {
version++;
x++;
version++;
} else {
int y = 0;
auto v1 = version.load();
y = x;
auto v2 = version.load();
if (v1 == v2 && v1 % 2 == 0) {
res += y;
}
}
check.wait(round * threads_n);
}
});
}
td::do_not_optimize_away(res);
for (auto &thread : threads) {
thread.join();
}
}
TEST(AtomicRead, simple2) {
td::Stage run;
td::Stage check;
std::size_t threads_n = 10;
td::vector<td::thread> threads;
struct Value {
td::uint64 value = 0;
char str[50] = "0 0 0 0";
};
td::AtomicRead<Value> value;
auto to_str = [](td::uint64 i) {
return PSTRING() << i << " " << i << " " << i << " " << i;
};
for (std::size_t i = 0; i < threads_n; i++) {
threads.emplace_back([&, id = static_cast<td::uint32>(i)] {
for (td::uint64 round = 1; round < 10000; round++) {
run.wait(round * threads_n);
if (id == 0) {
auto x = value.lock();
x->value = round;
auto str = to_str(round);
std::memcpy(x->str, str.c_str(), str.size() + 1);
} else {
Value x;
value.read(x);
LOG_CHECK(x.value == round || x.value == round - 1) << x.value << " " << round;
CHECK(x.str == to_str(x.value));
}
check.wait(round * threads_n);
}
});
}
for (auto &thread : threads) {
thread.join();
}
}
TEST(StealingQueue, simple) {
td::uint64 sum = 0;
std::atomic<td::uint64> got_sum{0};
td::Stage run;
td::Stage check;
std::size_t threads_n = 10;
td::vector<td::thread> threads;
td::vector<td::StealingQueue<int, 8>> lq(threads_n);
td::MpmcQueue<int> gq(threads_n);
constexpr td::uint64 XN = 20;
td::uint64 x_sum[XN];
x_sum[0] = 0;
x_sum[1] = 1;
for (td::uint64 i = 2; i < XN; i++) {
x_sum[i] = i + x_sum[i - 1] + x_sum[i - 2];
}
td::Random::Xorshift128plus rnd(123);
for (std::size_t i = 0; i < threads_n; i++) {
threads.emplace_back([&, id = static_cast<td::uint32>(i)] {
for (td::uint64 round = 1; round < 1000; round++) {
if (id == 0) {
sum = 0;
auto n = static_cast<int>(rnd() % 5);
for (int j = 0; j < n; j++) {
auto x = static_cast<int>(rnd() % XN);
sum += x_sum[x];
gq.push(x, id);
}
got_sum = 0;
}
run.wait(round * threads_n);
while (got_sum.load() != sum) {
auto x = [&] {
int res;
if (lq[id].local_pop(res)) {
return res;
}
if (gq.try_pop(res, id)) {
return res;
}
if (lq[id].steal(res, lq[static_cast<size_t>(rnd()) % threads_n])) {
//LOG(ERROR) << "STEAL";
return res;
}
return 0;
}();
if (x == 0) {
continue;
}
//LOG(ERROR) << x << " " << got_sum.load() << " " << sum;
got_sum.fetch_add(x, std::memory_order_relaxed);
lq[id].local_push(x - 1, [&](auto y) {
//LOG(ERROR) << "OVERFLOW";
gq.push(y, id);
});
if (x > 1) {
lq[id].local_push(x - 2, [&](auto y) { gq.push(y, id); });
}
}
check.wait(round * threads_n);
}
});
}
for (auto &thread : threads) {
thread.join();
}
}
#endif