Use WSARecv/WSASend instead of ReadFile/WriteFile. Fix check_status.
GitOrigin-RevId: bbfdf27d508f5c985b14bc13bd8549201ae1cb15
This commit is contained in:
parent
3e991d2464
commit
f4c85df878
@ -93,7 +93,7 @@ class ServerSocketFdImpl : private IOCP::Callback {
|
||||
bool close_flag_{false};
|
||||
std::atomic<int> refcnt_{1};
|
||||
bool is_read_active_{false};
|
||||
OVERLAPPED read_overlapped_;
|
||||
WSAOVERLAPPED read_overlapped_;
|
||||
|
||||
char close_overlapped_;
|
||||
|
||||
@ -126,17 +126,14 @@ class ServerSocketFdImpl : private IOCP::Callback {
|
||||
accept_socket_ = NativeFd(socket(socket_family_, SOCK_STREAM, 0));
|
||||
std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
|
||||
VLOG(fd) << get_native_fd().io_handle() << " start accept";
|
||||
auto status = AcceptEx(get_native_fd().socket(), accept_socket_.socket(), addr_buf_, 0, MAX_ADDR_SIZE,
|
||||
BOOL status = AcceptEx(get_native_fd().socket(), accept_socket_.socket(), addr_buf_, 0, MAX_ADDR_SIZE,
|
||||
MAX_ADDR_SIZE, nullptr, &read_overlapped_);
|
||||
if (check_status(status, "Failed to accept connection")) {
|
||||
if (status == TRUE || check_status("Failed to accept connection")) {
|
||||
inc_refcnt();
|
||||
is_read_active_ = true;
|
||||
}
|
||||
}
|
||||
bool check_status(DWORD status, Slice message) {
|
||||
if (status == 0) {
|
||||
return true;
|
||||
}
|
||||
bool check_status(Slice message) {
|
||||
auto last_error = WSAGetLastError();
|
||||
if (last_error == ERROR_IO_PENDING) {
|
||||
return true;
|
||||
@ -164,7 +161,7 @@ class ServerSocketFdImpl : private IOCP::Callback {
|
||||
get_poll_info().add_flags_from_poll(PollFlags::Error());
|
||||
}
|
||||
|
||||
void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) override {
|
||||
void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) override {
|
||||
// called from other thread
|
||||
if (dec_refcnt() || close_flag_) {
|
||||
return;
|
||||
@ -180,7 +177,7 @@ class ServerSocketFdImpl : private IOCP::Callback {
|
||||
return on_read();
|
||||
}
|
||||
|
||||
if (overlapped == reinterpret_cast<OVERLAPPED *>(&close_overlapped_)) {
|
||||
if (overlapped == reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_)) {
|
||||
return on_close();
|
||||
}
|
||||
UNREACHABLE();
|
||||
@ -192,7 +189,7 @@ class ServerSocketFdImpl : private IOCP::Callback {
|
||||
}
|
||||
void notify_iocp_close() {
|
||||
VLOG(fd) << get_native_fd().io_handle() << " notify_close";
|
||||
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&close_overlapped_));
|
||||
IOCP::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_));
|
||||
}
|
||||
};
|
||||
void ServerSocketFdImplDeleter::operator()(ServerSocketFdImpl *impl) {
|
||||
|
@ -63,7 +63,7 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
auto status = ConnectExPtr(get_native_fd().socket(), addr.get_sockaddr(), narrow_cast<int>(addr.get_sockaddr_len()),
|
||||
nullptr, 0, nullptr, &read_overlapped_);
|
||||
|
||||
if (!check_status(status, "Failed to connect")) {
|
||||
if (status == TRUE || !check_status("Failed to connect")) {
|
||||
is_read_active_ = false;
|
||||
dec_refcnt();
|
||||
}
|
||||
@ -131,21 +131,18 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
bool is_read_active_{false};
|
||||
ChainBufferWriter input_writer_;
|
||||
ChainBufferReader input_reader_ = input_writer_.extract_reader();
|
||||
OVERLAPPED read_overlapped_;
|
||||
WSAOVERLAPPED read_overlapped_;
|
||||
VectorQueue<Status> pending_errors_;
|
||||
|
||||
bool is_write_active_{false};
|
||||
std::atomic<bool> is_write_waiting_{false};
|
||||
ChainBufferWriter output_writer_;
|
||||
ChainBufferReader output_reader_ = output_writer_.extract_reader();
|
||||
OVERLAPPED write_overlapped_;
|
||||
WSAOVERLAPPED write_overlapped_;
|
||||
|
||||
char close_overlapped_;
|
||||
|
||||
bool check_status(DWORD status, Slice message) {
|
||||
if (status == 0) {
|
||||
return true;
|
||||
}
|
||||
bool check_status(Slice message) {
|
||||
auto last_error = WSAGetLastError();
|
||||
if (last_error == ERROR_IO_PENDING) {
|
||||
return true;
|
||||
@ -162,9 +159,12 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
}
|
||||
std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
|
||||
auto dest = input_writer_.prepare_append();
|
||||
auto status =
|
||||
ReadFile(get_native_fd().io_handle(), dest.data(), narrow_cast<DWORD>(dest.size()), nullptr, &read_overlapped_);
|
||||
if (check_status(status, "Failed to read from connection")) {
|
||||
WSABUF buf;
|
||||
buf.len = narrow_cast<ULONG>(dest.size());
|
||||
buf.buf = dest.data();
|
||||
DWORD flags = 0;
|
||||
int status = WSARecv(get_native_fd().socket(), &buf, 1, nullptr, &flags, &read_overlapped_, nullptr);
|
||||
if (status == 0 || check_status("Failed to read from connection")) {
|
||||
inc_refcnt();
|
||||
is_read_active_ = true;
|
||||
}
|
||||
@ -189,15 +189,17 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
}
|
||||
auto dest = output_reader_.prepare_read();
|
||||
std::memset(&write_overlapped_, 0, sizeof(write_overlapped_));
|
||||
auto status = WriteFile(get_native_fd().io_handle(), dest.data(), narrow_cast<DWORD>(dest.size()), nullptr,
|
||||
&write_overlapped_);
|
||||
if (check_status(status, "Failed to write to connection")) {
|
||||
WSABUF buf;
|
||||
buf.len = narrow_cast<ULONG>(dest.size());
|
||||
buf.buf = const_cast<CHAR *>(dest.data());
|
||||
int status = WSASend(get_native_fd().socket(), &buf, 1, nullptr, 0, &write_overlapped_, nullptr);
|
||||
if (status == 0 || check_status("Failed to write to connection")) {
|
||||
inc_refcnt();
|
||||
is_write_active_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) override {
|
||||
void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) override {
|
||||
// called from other thread
|
||||
if (dec_refcnt() || close_flag_) {
|
||||
VLOG(fd) << "ignore iocp (file is closing)";
|
||||
@ -223,7 +225,7 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
if (overlapped == &read_overlapped_) {
|
||||
return on_read(size);
|
||||
}
|
||||
if (overlapped == reinterpret_cast<OVERLAPPED *>(&close_overlapped_)) {
|
||||
if (overlapped == reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_)) {
|
||||
return on_close();
|
||||
}
|
||||
UNREACHABLE();
|
||||
@ -277,6 +279,7 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
info.set_native_fd({});
|
||||
}
|
||||
bool dec_refcnt() {
|
||||
VLOG(fd) << get_native_fd().io_handle() << " dec_refcnt from " << refcnt_;
|
||||
if (--refcnt_ == 0) {
|
||||
delete this;
|
||||
return true;
|
||||
@ -286,6 +289,7 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
void inc_refcnt() {
|
||||
CHECK(refcnt_ != 0);
|
||||
refcnt_++;
|
||||
VLOG(fd) << get_native_fd().io_handle() << " inc_refcnt to " << refcnt_;
|
||||
}
|
||||
|
||||
void notify_iocp_write() {
|
||||
@ -293,11 +297,11 @@ class SocketFdImpl : private IOCP::Callback {
|
||||
IOCP::get()->post(0, this, nullptr);
|
||||
}
|
||||
void notify_iocp_close() {
|
||||
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&close_overlapped_));
|
||||
IOCP::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_));
|
||||
}
|
||||
void notify_iocp_connected() {
|
||||
inc_refcnt();
|
||||
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&read_overlapped_));
|
||||
IOCP::get()->post(0, this, &read_overlapped_);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -162,7 +162,7 @@ class UdpSocketFdImpl : private IOCP::Callback {
|
||||
bool is_send_active_{false};
|
||||
bool is_send_waiting_{false};
|
||||
VectorQueue<UdpMessage> send_queue_;
|
||||
OVERLAPPED send_overlapped_;
|
||||
WSAOVERLAPPED send_overlapped_;
|
||||
|
||||
bool is_receive_active_{false};
|
||||
VectorQueue<UdpMessage> receive_queue_;
|
||||
@ -174,14 +174,11 @@ class UdpSocketFdImpl : private IOCP::Callback {
|
||||
BufferSlice receive_buffer_;
|
||||
|
||||
UdpMessage to_send_;
|
||||
OVERLAPPED receive_overlapped_;
|
||||
WSAOVERLAPPED receive_overlapped_;
|
||||
|
||||
char close_overlapped_;
|
||||
|
||||
bool check_status(DWORD status, Slice message) {
|
||||
if (status == 0) {
|
||||
return true;
|
||||
}
|
||||
bool check_status(Slice message) {
|
||||
auto last_error = WSAGetLastError();
|
||||
if (last_error == ERROR_IO_PENDING) {
|
||||
return true;
|
||||
@ -214,7 +211,7 @@ class UdpSocketFdImpl : private IOCP::Callback {
|
||||
}
|
||||
|
||||
auto status = WSARecvMsgPtr(get_native_fd().socket(), &receive_message_, nullptr, &receive_overlapped_, nullptr);
|
||||
if (check_status(status, "receive")) {
|
||||
if (status == 0 || check_status("WSARecvMsg failed")) {
|
||||
inc_refcnt();
|
||||
is_receive_active_ = true;
|
||||
}
|
||||
@ -236,13 +233,13 @@ class UdpSocketFdImpl : private IOCP::Callback {
|
||||
UdpSocketSendHelper send_helper;
|
||||
send_helper.to_native(to_send_, message);
|
||||
auto status = WSASendMsg(get_native_fd().socket(), &message, 0, nullptr, &send_overlapped_, nullptr);
|
||||
if (check_status(status, "send")) {
|
||||
if (status == 0 || check_status("WSASendMsg failed")) {
|
||||
inc_refcnt();
|
||||
is_send_active_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) override {
|
||||
void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) override {
|
||||
// called from other thread
|
||||
if (dec_refcnt() || close_flag_) {
|
||||
VLOG(fd) << "ignore iocp (file is closing)";
|
||||
@ -268,15 +265,14 @@ class UdpSocketFdImpl : private IOCP::Callback {
|
||||
if (overlapped == &receive_overlapped_) {
|
||||
return on_receive(size);
|
||||
}
|
||||
if (overlapped == reinterpret_cast<OVERLAPPED *>(&close_overlapped_)) {
|
||||
if (overlapped == reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_)) {
|
||||
return on_close();
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
void on_error(Status status) {
|
||||
VLOG(fd) << get_native_fd().io_handle() << " "
|
||||
<< "on error " << status;
|
||||
VLOG(fd) << get_native_fd().io_handle() << " on error " << status;
|
||||
{
|
||||
auto lock = lock_.lock();
|
||||
pending_errors_.push(std::move(status));
|
||||
@ -346,11 +342,11 @@ class UdpSocketFdImpl : private IOCP::Callback {
|
||||
IOCP::get()->post(0, this, nullptr);
|
||||
}
|
||||
void notify_iocp_close() {
|
||||
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&close_overlapped_));
|
||||
IOCP::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_));
|
||||
}
|
||||
void notify_iocp_connected() {
|
||||
inc_refcnt();
|
||||
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&receive_overlapped_));
|
||||
IOCP::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&receive_overlapped_));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -28,8 +28,9 @@ void IOCP::loop() {
|
||||
while (true) {
|
||||
DWORD bytes = 0;
|
||||
ULONG_PTR key = 0;
|
||||
OVERLAPPED *overlapped = nullptr;
|
||||
BOOL ok = GetQueuedCompletionStatus(iocp_handle_.io_handle(), &bytes, &key, &overlapped, 1000);
|
||||
WSAOVERLAPPED *overlapped = nullptr;
|
||||
BOOL ok = GetQueuedCompletionStatus(iocp_handle_.io_handle(), &bytes, &key,
|
||||
reinterpret_cast<OVERLAPPED **>(&overlapped), 1000);
|
||||
if (bytes || key || overlapped) {
|
||||
// LOG(ERROR) << "Got iocp " << bytes << " " << key << " " << overlapped;
|
||||
}
|
||||
@ -80,8 +81,9 @@ void IOCP::subscribe(const NativeFd &native_fd, Callback *callback) {
|
||||
CHECK(iocp_handle == iocp_handle_.io_handle()) << iocp_handle << " " << iocp_handle_.io_handle();
|
||||
}
|
||||
|
||||
void IOCP::post(size_t size, Callback *callback, OVERLAPPED *overlapped) {
|
||||
PostQueuedCompletionStatus(iocp_handle_.io_handle(), DWORD(size), reinterpret_cast<ULONG_PTR>(callback), overlapped);
|
||||
void IOCP::post(size_t size, Callback *callback, WSAOVERLAPPED *overlapped) {
|
||||
PostQueuedCompletionStatus(iocp_handle_.io_handle(), DWORD(size), reinterpret_cast<ULONG_PTR>(callback),
|
||||
reinterpret_cast<OVERLAPPED *>(overlapped));
|
||||
}
|
||||
|
||||
void WineventPoll::init() {
|
||||
|
@ -34,12 +34,12 @@ class IOCP final : public Context<IOCP> {
|
||||
class Callback {
|
||||
public:
|
||||
virtual ~Callback() = default;
|
||||
virtual void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) = 0;
|
||||
virtual void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) = 0;
|
||||
};
|
||||
|
||||
void init();
|
||||
void subscribe(const NativeFd &fd, Callback *callback);
|
||||
void post(size_t size, Callback *callback, OVERLAPPED *overlapped);
|
||||
void post(size_t size, Callback *callback, WSAOVERLAPPED *overlapped);
|
||||
void loop();
|
||||
void interrupt_loop();
|
||||
void clear();
|
||||
|
Loading…
Reference in New Issue
Block a user