diff --git a/tdactor/td/actor/impl/ConcurrentScheduler.cpp b/tdactor/td/actor/impl/ConcurrentScheduler.cpp index 46a726975..167402a52 100644 --- a/tdactor/td/actor/impl/ConcurrentScheduler.cpp +++ b/tdactor/td/actor/impl/ConcurrentScheduler.cpp @@ -14,10 +14,6 @@ #include "td/utils/MpscPollableQueue.h" #include "td/utils/port/thread_local.h" -#if TD_PORT_WINDOWS -#include "td/utils/port/detail/WineventPoll.h" -#endif - #include namespace td { @@ -61,7 +57,7 @@ void ConcurrentScheduler::init(int32 threads_n) { } #if TD_PORT_WINDOWS - iocp_ = std::make_unique(); + iocp_ = std::make_unique(); iocp_->init(); #endif @@ -86,7 +82,7 @@ void ConcurrentScheduler::start() { threads_.push_back(td::thread([&, tid = i]() { set_thread_id(static_cast(tid)); #if TD_PORT_WINDOWS - td::detail::IOCP::Guard iocp_guard(iocp_.get()); + td::detail::Iocp::Guard iocp_guard(iocp_.get()); #endif while (!is_finished()) { sched->run(10); @@ -110,7 +106,7 @@ bool ConcurrentScheduler::run_main(double timeout) { auto &main_sched = schedulers_[0]; if (!is_finished()) { #if TD_PORT_WINDOWS - td::detail::IOCP::Guard iocp_guard(iocp_.get()); + td::detail::Iocp::Guard iocp_guard(iocp_.get()); #endif main_sched->run(timeout); } @@ -126,7 +122,7 @@ void ConcurrentScheduler::finish() { SCOPE_EXIT { iocp_->clear(); }; - td::detail::IOCP::Guard iocp_guard(iocp_.get()); + td::detail::Iocp::Guard iocp_guard(iocp_.get()); #endif #if !TD_THREAD_UNSUPPORTED && !TD_EVENTFD_UNSUPPORTED diff --git a/tdactor/td/actor/impl/ConcurrentScheduler.h b/tdactor/td/actor/impl/ConcurrentScheduler.h index 52c6c5ab4..f835badba 100644 --- a/tdactor/td/actor/impl/ConcurrentScheduler.h +++ b/tdactor/td/actor/impl/ConcurrentScheduler.h @@ -13,6 +13,10 @@ #include "td/utils/port/thread.h" #include "td/utils/Slice.h" +#if TD_PORT_WINDOWS +#include "td/utils/port/detail/Iocp.h" +#endif + #include #include #include @@ -20,12 +24,6 @@ namespace td { -#if TD_PORT_WINDOWS -namespace detail { -class IOCP; -} -#endif - class ConcurrentScheduler : private Scheduler::Callback { public: void init(int32 threads_n); @@ -87,7 +85,7 @@ class ConcurrentScheduler : private Scheduler::Callback { std::vector threads_; #endif #if TD_PORT_WINDOWS - std::unique_ptr iocp_; + std::unique_ptr iocp_; td::thread iocp_thread_; #endif diff --git a/tdutils/CMakeLists.txt b/tdutils/CMakeLists.txt index 414e3d3c5..c63175801 100644 --- a/tdutils/CMakeLists.txt +++ b/tdutils/CMakeLists.txt @@ -52,6 +52,7 @@ set(TDUTILS_SOURCE td/utils/port/detail/EventFdBsd.cpp td/utils/port/detail/EventFdLinux.cpp td/utils/port/detail/EventFdWindows.cpp + td/utils/port/detail/Iocp.cpp td/utils/port/detail/KQueue.cpp td/utils/port/detail/NativeFd.cpp td/utils/port/detail/Poll.cpp @@ -117,6 +118,7 @@ set(TDUTILS_SOURCE td/utils/port/detail/EventFdBsd.h td/utils/port/detail/EventFdLinux.h td/utils/port/detail/EventFdWindows.h + td/utils/port/detail/Iocp.h td/utils/port/detail/KQueue.h td/utils/port/detail/NativeFd.h td/utils/port/detail/Poll.h diff --git a/tdutils/td/utils/port/ServerSocketFd.cpp b/tdutils/td/utils/port/ServerSocketFd.cpp index 66a405f94..712d983ec 100644 --- a/tdutils/td/utils/port/ServerSocketFd.cpp +++ b/tdutils/td/utils/port/ServerSocketFd.cpp @@ -25,7 +25,7 @@ #endif #if TD_PORT_WINDOWS -#include "td/utils/port/detail/WineventPoll.h" +#include "td/utils/port/detail/Iocp.h" #include "td/utils/SpinLock.h" #include "td/utils/VectorQueue.h" #endif @@ -37,11 +37,11 @@ namespace td { namespace detail { #if TD_PORT_WINDOWS -class ServerSocketFdImpl : private IOCP::Callback { +class ServerSocketFdImpl : private Iocp::Callback { public: ServerSocketFdImpl(NativeFd fd, int socket_family) : info_(std::move(fd)), socket_family_(socket_family) { VLOG(fd) << get_native_fd().socket() << " create ServerSocketFd"; - IOCP::get()->subscribe(get_native_fd(), this); + Iocp::get()->subscribe(get_native_fd(), this); notify_iocp_read(); } void close() { @@ -185,11 +185,11 @@ class ServerSocketFdImpl : private IOCP::Callback { void notify_iocp_read() { VLOG(fd) << get_native_fd().socket() << " notify_read"; inc_refcnt(); - IOCP::get()->post(0, this, nullptr); + Iocp::get()->post(0, this, nullptr); } void notify_iocp_close() { VLOG(fd) << get_native_fd().socket() << " notify_close"; - IOCP::get()->post(0, this, reinterpret_cast(&close_overlapped_)); + Iocp::get()->post(0, this, reinterpret_cast(&close_overlapped_)); } }; void ServerSocketFdImplDeleter::operator()(ServerSocketFdImpl *impl) { diff --git a/tdutils/td/utils/port/SocketFd.cpp b/tdutils/td/utils/port/SocketFd.cpp index a52748a91..ecd7311a5 100644 --- a/tdutils/td/utils/port/SocketFd.cpp +++ b/tdutils/td/utils/port/SocketFd.cpp @@ -12,7 +12,7 @@ #if TD_PORT_WINDOWS #include "td/utils/buffer.h" -#include "td/utils/port/detail/WineventPoll.h" +#include "td/utils/port/detail/Iocp.h" #include "td/utils/SpinLock.h" #include "td/utils/VectorQueue.h" #endif @@ -33,12 +33,12 @@ namespace td { namespace detail { #if TD_PORT_WINDOWS -class SocketFdImpl : private IOCP::Callback { +class SocketFdImpl : private Iocp::Callback { public: explicit SocketFdImpl(NativeFd native_fd) : info(std::move(native_fd)) { VLOG(fd) << get_native_fd().socket() << " create from native_fd"; get_poll_info().add_flags(PollFlags::Write()); - IOCP::get()->subscribe(get_native_fd(), this); + Iocp::get()->subscribe(get_native_fd(), this); is_read_active_ = true; notify_iocp_connected(); } @@ -46,7 +46,7 @@ class SocketFdImpl : private IOCP::Callback { SocketFdImpl(NativeFd native_fd, const IPAddress &addr) : info(std::move(native_fd)) { VLOG(fd) << get_native_fd().socket() << " create from native_fd and connect"; get_poll_info().add_flags(PollFlags::Write()); - IOCP::get()->subscribe(get_native_fd(), this); + Iocp::get()->subscribe(get_native_fd(), this); LPFN_CONNECTEX ConnectExPtr = nullptr; GUID guid = WSAID_CONNECTEX; DWORD numBytes; @@ -294,14 +294,14 @@ class SocketFdImpl : private IOCP::Callback { void notify_iocp_write() { inc_refcnt(); - IOCP::get()->post(0, this, nullptr); + Iocp::get()->post(0, this, nullptr); } void notify_iocp_close() { - IOCP::get()->post(0, this, reinterpret_cast(&close_overlapped_)); + Iocp::get()->post(0, this, reinterpret_cast(&close_overlapped_)); } void notify_iocp_connected() { inc_refcnt(); - IOCP::get()->post(0, this, &read_overlapped_); + Iocp::get()->post(0, this, &read_overlapped_); } }; @@ -498,7 +498,8 @@ Result SocketFd::open(const IPAddress &address) { TRY_STATUS(detail::init_socket_options(native_fd)); #if TD_PORT_POSIX - int e_connect = connect(native_fd.socket(), address.get_sockaddr(), narrow_cast(address.get_sockaddr_len())); + int e_connect = + connect(native_fd.socket(), address.get_sockaddr(), narrow_cast(address.get_sockaddr_len())); if (e_connect == -1) { auto connect_errno = errno; if (connect_errno != EINPROGRESS) { diff --git a/tdutils/td/utils/port/UdpSocketFd.cpp b/tdutils/td/utils/port/UdpSocketFd.cpp index 4ec12bc35..d0fba9c63 100644 --- a/tdutils/td/utils/port/UdpSocketFd.cpp +++ b/tdutils/td/utils/port/UdpSocketFd.cpp @@ -15,7 +15,7 @@ #include "td/utils/VectorQueue.h" #if TD_PORT_WINDOWS -#include "td/utils/port/detail/WineventPoll.h" +#include "td/utils/port/detail/Iocp.h" #include "td/utils/SpinLock.h" #endif @@ -98,11 +98,11 @@ class UdpSocketSendHelper { WSABUF buf_; }; -class UdpSocketFdImpl : private IOCP::Callback { +class UdpSocketFdImpl : private Iocp::Callback { public: explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) { get_poll_info().add_flags(PollFlags::Write()); - IOCP::get()->subscribe(get_native_fd(), this); + Iocp::get()->subscribe(get_native_fd(), this); is_receive_active_ = true; notify_iocp_connected(); } @@ -339,14 +339,14 @@ class UdpSocketFdImpl : private IOCP::Callback { void notify_iocp_send() { inc_refcnt(); - IOCP::get()->post(0, this, nullptr); + Iocp::get()->post(0, this, nullptr); } void notify_iocp_close() { - IOCP::get()->post(0, this, reinterpret_cast(&close_overlapped_)); + Iocp::get()->post(0, this, reinterpret_cast(&close_overlapped_)); } void notify_iocp_connected() { inc_refcnt(); - IOCP::get()->post(0, this, reinterpret_cast(&receive_overlapped_)); + Iocp::get()->post(0, this, reinterpret_cast(&receive_overlapped_)); } }; diff --git a/tdutils/td/utils/port/detail/Iocp.cpp b/tdutils/td/utils/port/detail/Iocp.cpp new file mode 100644 index 000000000..f651d54ea --- /dev/null +++ b/tdutils/td/utils/port/detail/Iocp.cpp @@ -0,0 +1,87 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#include "td/utils/port/detail/Iocp.h" + +char disable_linker_warning_about_empty_file_iocp_cpp TD_UNUSED; + +#ifdef TD_PORT_WINDOWS + +#include "td/utils/logging.h" + +namespace td { +namespace detail { +Iocp::~Iocp() { + clear(); +} + +void Iocp::loop() { + Iocp::Guard guard(this); + while (true) { + DWORD bytes = 0; + ULONG_PTR key = 0; + WSAOVERLAPPED *overlapped = nullptr; + BOOL ok = + GetQueuedCompletionStatus(iocp_handle_.fd(), &bytes, &key, reinterpret_cast(&overlapped), 1000); + if (bytes || key || overlapped) { + // LOG(ERROR) << "Got IOCP " << bytes << " " << key << " " << overlapped; + } + if (ok) { + auto callback = reinterpret_cast(key); + if (callback == nullptr) { + // LOG(ERROR) << "Interrupt IOCP loop"; + return; + } + callback->on_iocp(bytes, overlapped); + } else { + if (overlapped != nullptr) { + auto error = OS_ERROR("Received from IOCP"); + auto callback = reinterpret_cast(key); + CHECK(callback != nullptr); + callback->on_iocp(std::move(error), overlapped); + } + } + } +} + +void Iocp::interrupt_loop() { + post(0, nullptr, nullptr); +} + +void Iocp::init() { + CHECK(!iocp_handle_); + auto res = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0); + if (res == nullptr) { + auto error = OS_ERROR("IOCP creation failed"); + LOG(FATAL) << error; + } + iocp_handle_ = NativeFd(res); +} + +void Iocp::clear() { + iocp_handle_.close(); +} + +void Iocp::subscribe(const NativeFd &native_fd, Callback *callback) { + CHECK(iocp_handle_); + auto iocp_handle = + CreateIoCompletionPort(native_fd.fd(), iocp_handle_.fd(), reinterpret_cast(callback), 0); + if (iocp_handle == INVALID_HANDLE_VALUE) { + auto error = OS_ERROR("CreateIoCompletionPort"); + LOG(FATAL) << error; + } + CHECK(iocp_handle == iocp_handle_.fd()) << iocp_handle << " " << iocp_handle_.fd(); +} + +void Iocp::post(size_t size, Callback *callback, WSAOVERLAPPED *overlapped) { + PostQueuedCompletionStatus(iocp_handle_.fd(), DWORD(size), reinterpret_cast(callback), + reinterpret_cast(overlapped)); +} + +} // namespace detail +} // namespace td + +#endif diff --git a/tdutils/td/utils/port/detail/Iocp.h b/tdutils/td/utils/port/detail/Iocp.h new file mode 100644 index 000000000..9c7caec4f --- /dev/null +++ b/tdutils/td/utils/port/detail/Iocp.h @@ -0,0 +1,52 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include "td/utils/port/config.h" + +#ifdef TD_PORT_WINDOWS + +#include "td/utils/common.h" +#include "td/utils/Context.h" +#include "td/utils/port/detail/NativeFd.h" +#include "td/utils/port/thread.h" +#include "td/utils/Status.h" + +namespace td { +namespace detail { + +class Iocp final : public Context { + public: + Iocp() = default; + Iocp(const Iocp &) = delete; + Iocp &operator=(const Iocp &) = delete; + Iocp(Iocp &&) = delete; + Iocp &operator=(Iocp &&) = delete; + ~Iocp(); + + class Callback { + public: + virtual ~Callback() = default; + virtual void on_iocp(Result r_size, WSAOVERLAPPED *overlapped) = 0; + }; + + void init(); + void subscribe(const NativeFd &fd, Callback *callback); + void post(size_t size, Callback *callback, WSAOVERLAPPED *overlapped); + void loop(); + void interrupt_loop(); + void clear(); + + private: + NativeFd iocp_handle_; + std::vector workers_; +}; + +} // namespace detail +} // namespace td + +#endif diff --git a/tdutils/td/utils/port/detail/WineventPoll.cpp b/tdutils/td/utils/port/detail/WineventPoll.cpp index e433b32f4..05be85cac 100644 --- a/tdutils/td/utils/port/detail/WineventPoll.cpp +++ b/tdutils/td/utils/port/detail/WineventPoll.cpp @@ -10,81 +10,10 @@ char disable_linker_warning_about_empty_file_wineventpoll_cpp TD_UNUSED; #ifdef TD_POLL_WINEVENT -#include "td/utils/common.h" #include "td/utils/logging.h" -#include "td/utils/port/PollBase.h" -#include "td/utils/Status.h" - -#include namespace td { namespace detail { -IOCP::~IOCP() { - clear(); -} - -void IOCP::loop() { - IOCP::Guard guard(this); - while (true) { - DWORD bytes = 0; - ULONG_PTR key = 0; - WSAOVERLAPPED *overlapped = nullptr; - BOOL ok = - GetQueuedCompletionStatus(iocp_handle_.fd(), &bytes, &key, reinterpret_cast(&overlapped), 1000); - if (bytes || key || overlapped) { - // LOG(ERROR) << "Got IOCP " << bytes << " " << key << " " << overlapped; - } - if (ok) { - auto callback = reinterpret_cast(key); - if (callback == nullptr) { - // LOG(ERROR) << "Interrupt IOCP loop"; - return; - } - callback->on_iocp(bytes, overlapped); - } else { - if (overlapped != nullptr) { - auto error = OS_ERROR("Received from IOCP"); - auto callback = reinterpret_cast(key); - CHECK(callback != nullptr); - callback->on_iocp(std::move(error), overlapped); - } - } - } -} - -void IOCP::interrupt_loop() { - post(0, nullptr, nullptr); -} - -void IOCP::init() { - CHECK(!iocp_handle_); - auto res = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0); - if (res == nullptr) { - auto error = OS_ERROR("IOCP creation failed"); - LOG(FATAL) << error; - } - iocp_handle_ = NativeFd(res); -} - -void IOCP::clear() { - iocp_handle_.close(); -} - -void IOCP::subscribe(const NativeFd &native_fd, Callback *callback) { - CHECK(iocp_handle_); - auto iocp_handle = - CreateIoCompletionPort(native_fd.fd(), iocp_handle_.fd(), reinterpret_cast(callback), 0); - if (iocp_handle == INVALID_HANDLE_VALUE) { - auto error = OS_ERROR("CreateIoCompletionPort"); - LOG(FATAL) << error; - } - CHECK(iocp_handle == iocp_handle_.fd()) << iocp_handle << " " << iocp_handle_.fd(); -} - -void IOCP::post(size_t size, Callback *callback, WSAOVERLAPPED *overlapped) { - PostQueuedCompletionStatus(iocp_handle_.fd(), DWORD(size), reinterpret_cast(callback), - reinterpret_cast(overlapped)); -} void WineventPoll::init() { } diff --git a/tdutils/td/utils/port/detail/WineventPoll.h b/tdutils/td/utils/port/detail/WineventPoll.h index 6213541aa..bc9bbd16e 100644 --- a/tdutils/td/utils/port/detail/WineventPoll.h +++ b/tdutils/td/utils/port/detail/WineventPoll.h @@ -11,44 +11,13 @@ #ifdef TD_POLL_WINEVENT #include "td/utils/common.h" -#include "td/utils/Context.h" -#include "td/utils/port/detail/NativeFd.h" #include "td/utils/port/detail/PollableFd.h" #include "td/utils/port/PollBase.h" #include "td/utils/port/PollFlags.h" -#include "td/utils/port/thread.h" -#include "td/utils/Status.h" namespace td { namespace detail { -class IOCP final : public Context { - public: - IOCP() = default; - IOCP(const IOCP &) = delete; - IOCP &operator=(const IOCP &) = delete; - IOCP(IOCP &&) = delete; - IOCP &operator=(IOCP &&) = delete; - ~IOCP(); - - class Callback { - public: - virtual ~Callback() = default; - virtual void on_iocp(Result r_size, WSAOVERLAPPED *overlapped) = 0; - }; - - void init(); - void subscribe(const NativeFd &fd, Callback *callback); - void post(size_t size, Callback *callback, WSAOVERLAPPED *overlapped); - void loop(); - void interrupt_loop(); - void clear(); - - private: - NativeFd iocp_handle_; - std::vector workers_; -}; - class WineventPoll final : public PollBase { public: WineventPoll() = default;