diff --git a/tdutils/td/utils/port/detail/ThreadStl.h b/tdutils/td/utils/port/detail/ThreadStl.h index ce6ac27fe..dd55a21ec 100644 --- a/tdutils/td/utils/port/detail/ThreadStl.h +++ b/tdutils/td/utils/port/detail/ThreadStl.h @@ -12,6 +12,9 @@ #include "td/utils/common.h" #include "td/utils/invoke.h" +#if TD_WINDOWS +#include "td/utils/port/detail/NativeFd.h" +#endif #include "td/utils/port/detail/ThreadIdGuard.h" #include "td/utils/port/thread_local.h" #include "td/utils/Slice.h" @@ -67,7 +70,7 @@ class ThreadStl { } #if TD_WINDOWS - using id = HANDLE; + using id = DWORD; #else using id = std::thread::id; #endif @@ -81,7 +84,12 @@ class ThreadStl { if (static_cast(mask) != mask) { return Status::Error("Invalid thread affinity mask specified"); } - if (SetThreadAffinityMask(thread_id, static_cast(mask))) { + auto handle = OpenThread(THREAD_SET_LIMITED_INFORMATION | THREAD_QUERY_LIMITED_INFORMATION, FALSE, thread_id); + if (handle == nullptr) { + return Status::Error("Failed to access thread"); + } + NativeFd thread_handle(handle); + if (SetThreadAffinityMask(thread_handle.fd(), static_cast(mask))) { return Status::OK(); } return OS_ERROR("Failed to set thread affinity mask"); @@ -95,9 +103,14 @@ class ThreadStl { DWORD_PTR process_mask = 0; DWORD_PTR system_mask = 0; if (GetProcessAffinityMask(GetCurrentProcess(), &process_mask, &system_mask)) { - auto result = SetThreadAffinityMask(thread_id, process_mask); + auto handle = OpenThread(THREAD_SET_LIMITED_INFORMATION | THREAD_QUERY_LIMITED_INFORMATION, FALSE, thread_id); + if (handle == nullptr) { + return 0; + } + NativeFd thread_handle(handle); + auto result = SetThreadAffinityMask(thread_handle.fd(), process_mask); if (result != 0 && result != process_mask) { - SetThreadAffinityMask(thread_id, result); + SetThreadAffinityMask(thread_handle.fd(), result); } return result; } @@ -117,7 +130,7 @@ class ThreadStl { namespace this_thread_stl { #if TD_WINDOWS inline ThreadStl::id get_id() { - return GetCurrentThread(); + return GetCurrentThreadId(); } #else using std::this_thread::get_id; diff --git a/tdutils/test/port.cpp b/tdutils/test/port.cpp index 79a41596c..9287dae0b 100644 --- a/tdutils/test/port.cpp +++ b/tdutils/test/port.cpp @@ -289,19 +289,20 @@ TEST(Port, EventFdAndSignals) { TEST(Port, ThreadAffinityMask) { auto thread_id = td::this_thread::get_id(); auto old_mask = td::thread::get_affinity_mask(thread_id); - LOG(INFO) << "Initial thread affinity mask: " << old_mask; + LOG(INFO) << "Initial thread " << thread_id << " affinity mask: " << old_mask; for (size_t i = 0; i < 64; i++) { auto mask = td::thread::get_affinity_mask(thread_id); LOG(INFO) << mask; auto result = td::thread::set_affinity_mask(thread_id, static_cast(1) << i); - LOG(INFO) << i << ": " << result; - mask = td::thread::get_affinity_mask(thread_id); - LOG(INFO) << mask; + LOG(INFO) << i << ": " << result << ' ' << td::thread::get_affinity_mask(thread_id); if (i <= 1) { td::thread thread([] { - auto mask = td::thread::get_affinity_mask(td::this_thread::get_id()); - LOG(INFO) << "New thread affinity mask: " << mask; + auto thread_id = td::this_thread::get_id(); + auto mask = td::thread::get_affinity_mask(thread_id); + LOG(INFO) << "New thread " << thread_id << " affinity mask: " << mask; + auto result = td::thread::set_affinity_mask(thread_id, 1); + LOG(INFO) << "Thread " << thread_id << ": " << result << ' ' << td::thread::get_affinity_mask(thread_id); }); } }