tdutils: add skip_eintr_timeout and use it for EventFd

GitOrigin-RevId: 716218731f914e89e6f1e542054298380441b522
This commit is contained in:
Arseny Smirnov 2020-08-07 18:50:33 +03:00
parent 857f981847
commit c48ef93e1e
4 changed files with 83 additions and 9 deletions

View File

@ -12,9 +12,11 @@ char disable_linker_warning_about_empty_file_event_fd_bsd_cpp TD_UNUSED;
#include "td/utils/logging.h"
#include "td/utils/port/detail/NativeFd.h"
#include "td/utils/port/detail/skip_eintr.h"
#include "td/utils/port/PollFlags.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"
#include "td/utils/Time.h"
#include <cerrno>
@ -93,10 +95,14 @@ void EventFdBsd::acquire() {
}
void EventFdBsd::wait(int timeout_ms) {
pollfd fd;
fd.fd = get_poll_info().native_fd().fd();
fd.events = POLLIN;
poll(&fd, 1, timeout_ms);
detail::skip_eintr_timeout(
[this](int timeout_ms) {
pollfd fd;
fd.fd = get_poll_info().native_fd().fd();
fd.events = POLLIN;
return poll(&fd, 1, timeout_ms);
},
timeout_ms);
}
} // namespace detail

View File

@ -117,10 +117,14 @@ void EventFdLinux::acquire() {
}
void EventFdLinux::wait(int timeout_ms) {
pollfd fd;
fd.fd = get_poll_info().native_fd().fd();
fd.events = POLLIN;
poll(&fd, 1, timeout_ms);
detail::skip_eintr_timeout(
[this](int timeout_ms) {
pollfd fd;
fd.fd = get_poll_info().native_fd().fd();
fd.events = POLLIN;
return poll(&fd, 1, timeout_ms);
},
timeout_ms);
}
} // namespace detail

View File

@ -9,6 +9,7 @@
#if TD_PORT_POSIX
#include <cerrno>
#include <type_traits>
#include "td/utils/Time.h"
#endif
namespace td {
@ -35,6 +36,24 @@ auto skip_eintr_cstr(F &&f) {
} while (res == nullptr && errno == EINTR);
return res;
}
template <class F>
auto skip_eintr_timeout(F &&f, int32 timeout_ms) {
decltype(f(timeout_ms)) res;
static_assert(std::is_integral<decltype(res)>::value, "integral type expected");
auto start = Timestamp::now();
auto left_timeout_ms = timeout_ms;
while (true) {
errno = 0; // just in case
res = f(left_timeout_ms);
if (res >= 0 || errno != EINTR) {
break;
}
left_timeout_ms = max(static_cast<int32>((start.at() - Timestamp::now().at()) * 1000 + timeout_ms + 1 - 1e-9), 0);
}
return res;
}
} // namespace detail
#endif

View File

@ -7,15 +7,19 @@
#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/EventFd.h"
#include "td/utils/port/FileFd.h"
#include "td/utils/port/IoSlice.h"
#include "td/utils/port/path.h"
#include "td/utils/port/signals.h"
#include "td/utils/port/sleep.h"
#include "td/utils/port/thread.h"
#include "td/utils/port/thread_local.h"
#include "td/utils/Slice.h"
#include "td/utils/tests.h"
#include "td/utils/Time.h"
#include <atomic>
#include <set>
using namespace td;
@ -155,9 +159,12 @@ static void on_user_signal(int sig) {
ptrs.push_back(std::string(ptr));
}
TEST(Post, SignalsAndThread) {
TEST(Port, SignalsAndThread) {
setup_signals_alt_stack().ensure();
set_signal_handler(SignalType::User, on_user_signal).ensure();
SCOPE_EXIT {
set_signal_handler(SignalType::User, nullptr).ensure();
};
std::vector<std::string> ans = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};
{
std::vector<td::thread> threads;
@ -212,4 +219,42 @@ TEST(Post, SignalsAndThread) {
//LOG(ERROR) << addrs;
}
}
TEST(Port, EventFdAndSignals) {
set_signal_handler(SignalType::User, [](int signal) {}).ensure();
SCOPE_EXIT {
set_signal_handler(SignalType::User, nullptr).ensure();
};
std::atomic_flag flag;
flag.test_and_set();
auto main_thread = pthread_self();
td::thread interrupt_thread{[&flag, &main_thread] {
setup_signals_alt_stack().ensure();
while (flag.test_and_set()) {
pthread_kill(main_thread, SIGUSR1);
td::usleep_for(1000 * td::Random::fast(1, 10)); // 0.001s - 0.01s
}
}};
for (int timeout_ms : {0, 1, 2, 10, 100, 500}) {
double min_diff = 10000000;
double max_diff = 0;
for (int t = 0; t < max(5, 1000 / max(timeout_ms, 1)); t++) {
td::EventFd event_fd;
event_fd.init();
auto start = td::Timestamp::now();
event_fd.wait(timeout_ms);
auto end = td::Timestamp::now();
auto passed = end.at() - start.at();
auto diff = passed * 1000 - timeout_ms;
min_diff = min(min_diff, diff);
max_diff = max(max_diff, diff);
}
LOG_CHECK(min_diff >= 0) << min_diff;
LOG_CHECK(max_diff < 10) << max_diff;
LOG(ERROR) << min_diff << " " << max_diff;
}
flag.clear();
}
#endif