tdlight/tdutils/td/utils/port/UdpSocketFd.cpp
levlam 3442a88413 Unify constant names style.
GitOrigin-RevId: 6e4475366b94cea6ab0331d57f254311490bdee2
2020-06-16 05:10:16 +03:00

866 lines
26 KiB
C++

//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2020
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/port/UdpSocketFd.h"
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/detail/skip_eintr.h"
#include "td/utils/port/PollFlags.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/VectorQueue.h"
#if TD_PORT_WINDOWS
#include "td/utils/port/detail/Iocp.h"
#include "td/utils/SpinLock.h"
#endif
#if TD_PORT_POSIX
#include <cerrno>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#if TD_LINUX
#include <linux/errqueue.h>
#endif
#endif // TD_PORT_POSIX
#include <array>
#include <atomic>
#include <cstring>
namespace td {
namespace detail {
#if TD_PORT_WINDOWS
// Translates between a UdpMessage and the WSAMSG header used by an overlapped WSARecvMsg() call.
// The helper owns the WSABUF and the sender-address storage the header points to, so it must
// outlive the receive operation.
class UdpSocketReceiveHelper {
 public:
  // Points `message_header` at `message.data` (receive buffer) and at addr_ (sender address).
  void to_native(const UdpMessage &message, WSAMSG &message_header) {
    buf_.len = narrow_cast<DWORD>(message.data.size());
    buf_.buf = const_cast<char *>(message.data.as_slice().begin());

    message_header.name = reinterpret_cast<struct sockaddr *>(&addr_);
    message_header.namelen = narrow_cast<socklen_t>(sizeof(addr_));
    message_header.lpBuffers = &buf_;
    message_header.dwBufferCount = 1;
    // Ancillary data is not requested; control_buf_ is reserved for that purpose.
    message_header.Control.buf = nullptr;
    message_header.Control.len = 0;
    message_header.dwFlags = 0;
  }

  // Fills `message` from a completed receive of `message_size` bytes: the sender address, the
  // error status and the data trimmed to the received length. A truncated datagram (MSG_TRUNC or
  // MSG_CTRUNC) yields error 501 and an empty BufferSlice.
  void from_native(WSAMSG &message_header, size_t message_size, UdpMessage &message) {
    auto *source = reinterpret_cast<struct sockaddr *>(message_header.name);
    message.address.init_sockaddr(source, message_header.namelen).ignore();

    bool is_truncated = (message_header.dwFlags & (MSG_TRUNC | MSG_CTRUNC)) != 0;
    if (is_truncated) {
      message.error = Status::Error(501, "message too long");
      message.data = BufferSlice();
      return;
    }
    message.error = Status::OK();

    CHECK(message_size <= message.data.size());
    message.data.truncate(message_size);
    CHECK(message_size == message.data.size());
  }

 private:
  std::array<char, 1024> control_buf_;
  sockaddr_storage addr_;
  WSABUF buf_;
};
// Builds the WSAMSG header for an overlapped WSASendMsg() of one UdpMessage. The WSABUF the
// header points to lives in this helper, so the helper must outlive the call.
class UdpSocketSendHelper {
 public:
  // Points `message_header` at the destination address and payload of `message`.
  void to_native(const UdpMessage &message, WSAMSG &message_header) {
    buf_.len = narrow_cast<DWORD>(message.data.size());
    buf_.buf = const_cast<char *>(message.data.as_slice().begin());

    const auto &destination = message.address;
    message_header.name = const_cast<struct sockaddr *>(destination.get_sockaddr());
    message_header.namelen = narrow_cast<socklen_t>(destination.get_sockaddr_len());
    message_header.lpBuffers = &buf_;
    message_header.dwBufferCount = 1;
    // No ancillary data is sent.
    message_header.Control.buf = nullptr;
    message_header.Control.len = 0;
    message_header.dwFlags = 0;
  }

 private:
  WSABUF buf_;
};
// Windows (IOCP) implementation of a UDP socket.
//
// Receives and sends are issued as overlapped WSARecvMsg()/WSASendMsg() calls. Their completions
// (and notifications posted with Iocp::post) arrive in on_iocp(), which runs on another thread.
// Received datagrams and errors are queued under lock_ and handed out by receive(); datagrams
// queued by send() are drained one at a time by loop_send().
//
// Lifetime: refcnt_ starts at 1. It is incremented for every started overlapped operation and for
// every posted notification except the close notification, and on_iocp() decrements it on every
// completion. The object deletes itself when the count reaches zero. The owner's deleter only
// calls close(), so the object stays alive until all pending completions have been delivered.
class UdpSocketFdImpl : private Iocp::Callback {
public:
// Registers the socket with IOCP and posts a fake completion on receive_overlapped_, so that
// on_connected() (which starts the receive and send loops) runs on the IOCP thread.
explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) {
get_poll_info().add_flags(PollFlags::Write());
Iocp::get()->subscribe(get_native_fd(), this);
// on_connected() CHECKs that this is set.
is_receive_active_ = true;
notify_iocp_connected();
}
PollableFdInfo &get_poll_info() {
return info_;
}
const PollableFdInfo &get_poll_info() const {
return info_;
}
const NativeFd &get_native_fd() const {
return info_.native_fd();
}
// Asynchronous close: posts the close notification; on_close() then sets close_flag_ and
// releases the native fd on the IOCP thread.
void close() {
notify_iocp_close();
}
// Returns the next pending error or received datagram, or an empty optional if there is none.
// Errors that UdpSocketFd::is_critical_read_error() does not classify as critical are returned
// as a UdpMessage carrying the error; critical ones are returned as the error itself.
Result<optional<UdpMessage>> receive() {
auto lock = lock_.lock();
if (!pending_errors_.empty()) {
auto status = pending_errors_.pop();
if (!UdpSocketFd::is_critical_read_error(status)) {
return UdpMessage{{}, {}, std::move(status)};
}
return std::move(status);
}
if (!receive_queue_.empty()) {
return receive_queue_.pop();
}
return optional<UdpMessage>{};
}
// Only queues the datagram; flush_send() wakes the send loop.
void send(UdpMessage message) {
auto lock = lock_.lock();
send_queue_.push(std::move(message));
}
// If the send loop ran out of data (is_send_waiting_), posts a notification so that the IOCP
// thread resumes loop_send().
// NOTE(review): is_send_waiting_ is read here before taking lock_, while loop_send() writes it
// under lock_ from the IOCP thread -- looks like a data race on a plain bool; confirm.
Status flush_send() {
if (is_send_waiting_) {
auto lock = lock_.lock();
is_send_waiting_ = false;
notify_iocp_send();
}
return Status::OK();
}
private:
PollableFdInfo info_;
// Guards send_queue_, receive_queue_, pending_errors_ and is_send_waiting_.
SpinLock lock_;
std::atomic<int> refcnt_{1};
bool is_connected_{false};
bool close_flag_{false};
bool is_send_active_{false};
bool is_send_waiting_{false};
VectorQueue<UdpMessage> send_queue_;
WSAOVERLAPPED send_overlapped_;
bool is_receive_active_{false};
VectorQueue<UdpMessage> receive_queue_;
VectorQueue<Status> pending_errors_;
UdpMessage to_receive_;
WSAMSG receive_message_;
UdpSocketReceiveHelper receive_helper_;
static constexpr size_t MAX_PACKET_SIZE = 2048;
static constexpr size_t RESERVED_SIZE = MAX_PACKET_SIZE * 8;
// Shared buffer that received datagrams are cloned from; reallocated in loop_receive() when
// fewer than MAX_PACKET_SIZE bytes remain.
BufferSlice receive_buffer_;
UdpMessage to_send_;
WSAOVERLAPPED receive_overlapped_;
// Never read; its address is only used as a unique tag identifying the close notification.
char close_overlapped_;
// Call after an overlapped WSA* call returned failure. Returns true if the operation was merely
// queued (ERROR_IO_PENDING); otherwise reports the error through on_error() and returns false.
bool check_status(Slice message) {
auto last_error = WSAGetLastError();
if (last_error == ERROR_IO_PENDING) {
return true;
}
on_error(OS_SOCKET_ERROR(message));
return false;
}
// Starts the next overlapped receive into a fresh clone of receive_buffer_.
void loop_receive() {
CHECK(!is_receive_active_);
if (close_flag_) {
return;
}
std::memset(&receive_overlapped_, 0, sizeof(receive_overlapped_));
if (receive_buffer_.size() < MAX_PACKET_SIZE) {
receive_buffer_ = BufferSlice(RESERVED_SIZE);
}
to_receive_.data = receive_buffer_.clone();
receive_helper_.to_native(to_receive_, receive_message_);
// WSARecvMsg is an extension function and must be looked up via WSAIoctl.
// NOTE(review): the lookup is repeated for every receive; the pointer could be cached.
LPFN_WSARECVMSG WSARecvMsgPtr = nullptr;
GUID guid = WSAID_WSARECVMSG;
DWORD numBytes;
auto error = ::WSAIoctl(get_native_fd().socket(), SIO_GET_EXTENSION_FUNCTION_POINTER, static_cast<void *>(&guid),
sizeof(guid), static_cast<void *>(&WSARecvMsgPtr), sizeof(WSARecvMsgPtr), &numBytes,
nullptr, nullptr);
if (error) {
on_error(OS_SOCKET_ERROR("WSAIoctl failed"));
return;
}
auto status = WSARecvMsgPtr(get_native_fd().socket(), &receive_message_, nullptr, &receive_overlapped_, nullptr);
if (status == 0 || check_status("WSARecvMsg failed")) {
// One reference per outstanding overlapped operation.
inc_refcnt();
is_receive_active_ = true;
}
}
// Pops the next queued datagram and starts an overlapped send. If the queue is empty, it sets
// is_send_waiting_ so that flush_send() restarts the loop later.
void loop_send() {
CHECK(!is_send_active_);
{
auto lock = lock_.lock();
if (send_queue_.empty()) {
is_send_waiting_ = true;
return;
}
to_send_ = send_queue_.pop();
}
std::memset(&send_overlapped_, 0, sizeof(send_overlapped_));
WSAMSG message;
UdpSocketSendHelper send_helper;
send_helper.to_native(to_send_, message);
auto status = WSASendMsg(get_native_fd().socket(), &message, 0, nullptr, &send_overlapped_, nullptr);
if (status == 0 || check_status("WSASendMsg failed")) {
inc_refcnt();
is_send_active_ = true;
}
}
// Dispatches a completion by its overlapped pointer:
//   receive_overlapped_ before connect -> on_connected (posted by the constructor)
//   send_overlapped_                   -> on_send
//   nullptr                            -> on_send(0) (posted by flush_send)
//   receive_overlapped_                -> on_receive
//   &close_overlapped_                 -> on_close
// NOTE(review): after a failed completion only on_error() runs and the corresponding receive or
// send loop is not restarted (is_*_active_ stays true) -- confirm this is intended.
void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) override {
// called from other thread
// Every completion releases one reference; after close, completions are only counted.
if (dec_refcnt() || close_flag_) {
VLOG(fd) << "Ignore IOCP (UDP socket is closing)";
return;
}
if (r_size.is_error()) {
return on_error(get_socket_pending_error(get_native_fd(), overlapped, r_size.move_as_error()));
}
if (!is_connected_ && overlapped == &receive_overlapped_) {
return on_connected();
}
auto size = r_size.move_as_ok();
if (overlapped == &send_overlapped_) {
return on_send(size);
}
if (overlapped == nullptr) {
CHECK(size == 0);
return on_send(size);
}
if (overlapped == &receive_overlapped_) {
return on_receive(size);
}
if (overlapped == reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_)) {
return on_close();
}
UNREACHABLE();
}
// Queues `status` for receive() and raises the Error poll flag.
void on_error(Status status) {
VLOG(fd) << get_native_fd() << " on error " << status;
{
auto lock = lock_.lock();
pending_errors_.push(std::move(status));
}
get_poll_info().add_flags_from_poll(PollFlags::Error());
}
// Handles the fake completion posted by the constructor and starts both I/O loops.
void on_connected() {
VLOG(fd) << get_native_fd() << " on connected";
CHECK(!is_connected_);
CHECK(is_receive_active_);
is_connected_ = true;
is_receive_active_ = false;
loop_receive();
loop_send();
}
// Handles a completed receive of `size` bytes: queues the datagram, raises the Read flag and
// starts the next receive.
void on_receive(size_t size) {
VLOG(fd) << get_native_fd() << " on receive " << size;
CHECK(is_receive_active_);
is_receive_active_ = false;
receive_helper_.from_native(receive_message_, size, to_receive_);
// Consumes the used part of the shared buffer, rounded up to a multiple of 8 bytes.
receive_buffer_.confirm_read((to_receive_.data.size() + 7) & ~7);
{
auto lock = lock_.lock();
// LOG(ERROR) << format::escaped(to_receive_.data.as_slice());
receive_queue_.push(std::move(to_receive_));
}
get_poll_info().add_flags_from_poll(PollFlags::Read());
loop_receive();
}
// Handles a completed send, or (size == 0) the wake-up posted by flush_send(). A wake-up while a
// send is still in flight is ignored; otherwise it restarts loop_send().
void on_send(size_t size) {
VLOG(fd) << get_native_fd() << " on send " << size;
if (size == 0) {
if (is_send_active_) {
return;
}
is_send_active_ = true;
}
CHECK(is_send_active_);
is_send_active_ = false;
loop_send();
}
// Handles the close notification: later completions are ignored and the native fd is released.
void on_close() {
VLOG(fd) << get_native_fd() << " on close";
close_flag_ = true;
info_.set_native_fd({});
}
// Drops one reference; deletes the object and returns true when it was the last one.
bool dec_refcnt() {
if (--refcnt_ == 0) {
delete this;
return true;
}
return false;
}
void inc_refcnt() {
CHECK(refcnt_ != 0);
refcnt_++;
}
void notify_iocp_send() {
inc_refcnt();
Iocp::get()->post(0, this, nullptr);
}
// Deliberately does not call inc_refcnt(): this completion consumes the initial reference.
void notify_iocp_close() {
Iocp::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&close_overlapped_));
}
void notify_iocp_connected() {
inc_refcnt();
Iocp::get()->post(0, this, reinterpret_cast<WSAOVERLAPPED *>(&receive_overlapped_));
}
};
// Windows: the implementation is not deleted here. close() posts an asynchronous close through
// IOCP, and the object deletes itself once its last pending completion has been delivered.
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
  UdpSocketFdImpl &target = *impl;
  target.close();
}
#elif TD_PORT_POSIX
// Reference layouts of the POSIX structures filled in by the helpers below:
//
// struct iovec {            // scatter/gather array item
//   void *iov_base;         // starting address
//   size_t iov_len;         // number of bytes to transfer
// };
//
// struct msghdr {
//   void *msg_name;         // optional address
//   socklen_t msg_namelen;  // size of address
//   struct iovec *msg_iov;  // scatter/gather array
//   size_t msg_iovlen;      // number of elements in msg_iov
//   void *msg_control;      // ancillary data
//   size_t msg_controllen;  // ancillary data buffer length
//   int msg_flags;          // flags on received message
// };
class UdpSocketReceiveHelper {
public:
void to_native(const UdpSocketFd::InboundMessage &message, struct msghdr &message_header) {
socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
message_header.msg_name = &addr_;
message_header.msg_namelen = addr_len;
io_vec_.iov_base = message.data.begin();
io_vec_.iov_len = message.data.size();
message_header.msg_iov = &io_vec_;
message_header.msg_iovlen = 1;
message_header.msg_control = control_buf_.data();
message_header.msg_controllen = narrow_cast<decltype(message_header.msg_controllen)>(control_buf_.size());
message_header.msg_flags = 0;
}
void from_native(struct msghdr &message_header, size_t message_size, UdpSocketFd::InboundMessage &message) {
#if TD_LINUX
struct cmsghdr *cmsg;
struct sock_extended_err *ee = nullptr;
for (cmsg = CMSG_FIRSTHDR(&message_header); cmsg != nullptr; cmsg = CMSG_NXTHDR(&message_header, cmsg)) {
if (cmsg->cmsg_type == IP_PKTINFO && cmsg->cmsg_level == IPPROTO_IP) {
//auto *pi = reinterpret_cast<struct in_pktinfo *>(CMSG_DATA(cmsg));
} else if (cmsg->cmsg_type == IPV6_PKTINFO && cmsg->cmsg_level == IPPROTO_IPV6) {
//auto *pi = reinterpret_cast<struct in6_pktinfo *>(CMSG_DATA(cmsg));
} else if ((cmsg->cmsg_type == IP_RECVERR && cmsg->cmsg_level == IPPROTO_IP) ||
(cmsg->cmsg_type == IPV6_RECVERR && cmsg->cmsg_level == IPPROTO_IPV6)) {
ee = reinterpret_cast<struct sock_extended_err *>(CMSG_DATA(cmsg));
}
}
if (ee != nullptr) {
auto *addr = reinterpret_cast<struct sockaddr *>(SO_EE_OFFENDER(ee));
IPAddress address;
address.init_sockaddr(addr).ignore();
if (message.from != nullptr) {
*message.from = address;
}
if (message.error) {
*message.error = Status::PosixError(ee->ee_errno, "");
}
//message.data = MutableSlice();
message.data.truncate(0);
return;
}
#endif
if (message.from != nullptr) {
message.from
->init_sockaddr(reinterpret_cast<struct sockaddr *>(message_header.msg_name), message_header.msg_namelen)
.ignore();
}
if (message.error) {
*message.error = Status::OK();
}
if (message_header.msg_flags & MSG_TRUNC) {
if (message.error) {
*message.error = Status::Error(501, "message too long");
}
message.data.truncate(0);
return;
}
CHECK(message_size <= message.data.size());
message.data.truncate(message_size);
CHECK(message_size == message.data.size());
}
private:
std::array<char, 1024> control_buf_;
sockaddr_storage addr_;
struct iovec io_vec_;
};
// Builds the msghdr for sendmsg()/sendmmsg() of one outbound datagram. The iovec the header points
// to lives in this helper, so the helper must outlive the send call.
class UdpSocketSendHelper {
 public:
  // Requires a valid destination address in `message.to`.
  void to_native(const UdpSocketFd::OutboundMessage &message, struct msghdr &message_header) {
    CHECK(message.to != nullptr && message.to->is_valid());
    const auto &destination = *message.to;

    io_vec_.iov_len = message.data.size();
    io_vec_.iov_base = const_cast<char *>(message.data.begin());

    message_header.msg_name = const_cast<struct sockaddr *>(destination.get_sockaddr());
    message_header.msg_namelen = narrow_cast<socklen_t>(destination.get_sockaddr_len());
    message_header.msg_iov = &io_vec_;
    message_header.msg_iovlen = 1;
    // TODO: no ancillary data is sent yet.
    message_header.msg_control = nullptr;
    message_header.msg_controllen = 0;
    message_header.msg_flags = 0;
  }

 private:
  struct iovec io_vec_;
};
// POSIX implementation of a non-blocking UDP socket.
//
// Readiness is tracked in info_: the Read and Write flags are cleared when the kernel reports
// EAGAIN/EWOULDBLOCK, and the Error flag marks a pending socket error. Where MSG_ERRQUEUE is
// available, a pending error is drained by reading from the socket error queue; otherwise it is
// fetched with detail::get_socket_pending_error().
class UdpSocketFdImpl {
public:
explicit UdpSocketFdImpl(NativeFd fd) : info_(std::move(fd)) {
}
PollableFdInfo &get_poll_info() {
return info_;
}
const PollableFdInfo &get_poll_info() const {
return info_;
}
const NativeFd &get_native_fd() const {
return info_.native_fd();
}
// If the Error flag is set, returns the pending socket error. The flag is cleared only when no
// error is actually pending.
Status get_pending_error() {
if (!get_poll_info().get_flags().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
get_poll_info().clear_flags(PollFlags::Error());
return Status::OK();
}
// Receives one datagram into `message`; is_received is true if recvmsg() succeeded.
// While an error is pending, reads from the error queue (MSG_ERRQUEUE) instead.
// EAGAIN is not an error: it yields Status::OK() with is_received == false.
Status receive_message(UdpSocketFd::InboundMessage &message, bool &is_received) {
is_received = false;
int flags = 0;
if (get_poll_info().get_flags().has_pending_error()) {
#ifdef MSG_ERRQUEUE
flags = MSG_ERRQUEUE;
#else
return get_pending_error();
#endif
}
struct msghdr message_header;
detail::UdpSocketReceiveHelper helper;
helper.to_native(message, message_header);
auto native_fd = get_native_fd().socket();
auto recvmsg_res = detail::skip_eintr([&] { return recvmsg(native_fd, &message_header, flags); });
auto recvmsg_errno = errno;
if (recvmsg_res >= 0) {
helper.from_native(message_header, recvmsg_res, message);
is_received = true;
return Status::OK();
}
return process_recvmsg_error(recvmsg_errno, is_received);
}
// Maps a failed recvmsg()/recvmmsg() errno to a Status and updates the poll flags.
// On EAGAIN/EWOULDBLOCK it clears the Error flag if set (the error queue is drained), otherwise
// it clears the Read flag. Errors that indicate misuse of the socket are fatal. Unknown errors,
// ENOBUFS and ENOMEM are returned, and (with MSG_ERRQUEUE) the Error flag is raised.
Status process_recvmsg_error(int recvmsg_errno, bool &is_received) {
is_received = false;
if (recvmsg_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| recvmsg_errno == EWOULDBLOCK
#endif
) {
if (get_poll_info().get_flags_local().has_pending_error()) {
get_poll_info().clear_flags(PollFlags::Error());
} else {
get_poll_info().clear_flags(PollFlags::Read());
}
return Status::OK();
}
auto error = Status::PosixError(recvmsg_errno, PSLICE() << "Receive from " << get_native_fd() << " has failed");
switch (recvmsg_errno) {
case EBADF:
case EFAULT:
case EINVAL:
case ENOTCONN:
case ECONNRESET:
case ETIMEDOUT:
LOG(FATAL) << error;
UNREACHABLE();
default:
LOG(WARNING) << "Unknown error: " << error;
// fallthrough
case ENOBUFS:
case ENOMEM:
#ifdef MSG_ERRQUEUE
get_poll_info().add_flags(PollFlags::Error());
#endif
return error;
}
}
// Sends one datagram; is_sent is true if sendmsg() succeeded, or if the datagram should be
// dropped (see process_sendmsg_error). EAGAIN yields Status::OK() with is_sent == false.
Status send_message(const UdpSocketFd::OutboundMessage &message, bool &is_sent) {
is_sent = false;
struct msghdr message_header;
detail::UdpSocketSendHelper helper;
helper.to_native(message, message_header);
auto native_fd = get_native_fd().socket();
auto sendmsg_res = detail::skip_eintr([&] { return sendmsg(native_fd, &message_header, 0); });
auto sendmsg_errno = errno;
if (sendmsg_res >= 0) {
is_sent = true;
return Status::OK();
}
return process_sendmsg_error(sendmsg_errno, is_sent);
}
// Maps a failed sendmsg()/sendmmsg() errno to a Status and updates the poll flags.
// is_sent is set to true only for per-packet errors after which the datagram should be dropped
// rather than retried; otherwise it is left unchanged, so callers must initialize it.
Status process_sendmsg_error(int sendmsg_errno, bool &is_sent) {
if (sendmsg_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| sendmsg_errno == EWOULDBLOCK
#endif
) {
get_poll_info().clear_flags(PollFlags::Write());
return Status::OK();
}
auto error = Status::PosixError(sendmsg_errno, PSLICE() << "Send from " << get_native_fd() << " has failed");
switch (sendmsg_errno) {
// Still may send some other packets, but there is no point to resend this particular message
case EACCES:
case EMSGSIZE:
case EPERM:
LOG(WARNING) << "Silently drop packet :( " << error;
// TODO: getting the errors from MSG_ERRQUEUE is possible
is_sent = true;
return error;
// Some general problems, which may be fixed in future
case ENOMEM:
case EDQUOT:
case EFBIG:
case ENETDOWN:
case ENETUNREACH:
case ENOSPC:
case EHOSTUNREACH:
case ENOBUFS:
default:
#ifdef MSG_ERRQUEUE
get_poll_info().add_flags(PollFlags::Error());
#endif
return error;
case EBADF: // impossible
case ENOTSOCK: // impossible
case EPIPE: // impossible for udp
case ECONNRESET: // impossible for udp
case EDESTADDRREQ: // we checked that address is valid
case ENOTCONN: // we checked that address is valid
case EINTR: // we already skipped all EINTR
case EISCONN: // impossible for udp socket
case EOPNOTSUPP:
case ENOTDIR:
case EFAULT:
case EINVAL:
case EAFNOSUPPORT:
LOG(FATAL) << error;
UNREACHABLE();
return error;
}
}
// Sends a prefix of `messages`; cnt is the number of datagrams consumed (sent or dropped).
// Uses sendmmsg() when TD_HAS_MMSG is set.
Status send_messages(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
#if TD_HAS_MMSG
return send_messages_fast(messages, cnt);
#else
return send_messages_slow(messages, cnt);
#endif
}
// Receives into a prefix of `messages`; cnt is the number of entries filled.
// Uses recvmmsg() when TD_HAS_MMSG is set.
Status receive_messages(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
#if TD_HAS_MMSG
return receive_messages_fast(messages, cnt);
#else
return receive_messages_slow(messages, cnt);
#endif
}
private:
PollableFdInfo info_;
// One sendmsg() per datagram; stops at the first error.
Status send_messages_slow(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
cnt = 0;
for (auto &message : messages) {
CHECK(!message.data.empty());
bool is_sent;
auto error = send_message(message, is_sent);
cnt += is_sent;
TRY_STATUS(std::move(error));
}
return Status::OK();
}
#if TD_HAS_MMSG
// Single sendmmsg() call for at most 16 datagrams.
Status send_messages_fast(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
// struct mmsghdr {
//   struct msghdr msg_hdr;  // message header
//   unsigned int msg_len;   // number of bytes transmitted
// };
struct std::array<detail::UdpSocketSendHelper, 16> helpers;
struct std::array<struct mmsghdr, 16> headers;
size_t to_send = min(messages.size(), headers.size());
for (size_t i = 0; i < to_send; i++) {
helpers[i].to_native(messages[i], headers[i].msg_hdr);
headers[i].msg_len = 0;
}
auto native_fd = get_native_fd().socket();
auto sendmmsg_res =
detail::skip_eintr([&] { return sendmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_send), 0); });
auto sendmmsg_errno = errno;
if (sendmmsg_res >= 0) {
cnt = sendmmsg_res;
return Status::OK();
}
bool is_sent = false;
auto status = process_sendmsg_error(sendmmsg_errno, is_sent);
// On failure at most the first datagram is consumed (when it is to be dropped).
cnt = is_sent;
return status;
}
#endif
// One recvmsg() per entry while the socket is readable; stops at the first error.
Status receive_messages_slow(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
cnt = 0;
while (cnt < messages.size() && get_poll_info().get_flags().can_read()) {
auto &message = messages[cnt];
CHECK(!message.data.empty());
bool is_received;
auto error = receive_message(message, is_received);
cnt += is_received;
TRY_STATUS(std::move(error));
}
return Status::OK();
}
#if TD_HAS_MMSG
// Single recvmmsg() call for at most 16 datagrams; reads from the error queue while an error is
// pending (where MSG_ERRQUEUE is available).
Status receive_messages_fast(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
int flags = 0;
cnt = 0;
if (get_poll_info().get_flags().has_pending_error()) {
#ifdef MSG_ERRQUEUE
flags = MSG_ERRQUEUE;
#else
return get_pending_error();
#endif
}
// struct mmsghdr {
//   struct msghdr msg_hdr;  // message header
//   unsigned int msg_len;   // number of bytes transmitted
// };
struct std::array<detail::UdpSocketReceiveHelper, 16> helpers;
struct std::array<struct mmsghdr, 16> headers;
size_t to_receive = min(messages.size(), headers.size());
for (size_t i = 0; i < to_receive; i++) {
helpers[i].to_native(messages[i], headers[i].msg_hdr);
headers[i].msg_len = 0;
}
auto native_fd = get_native_fd().socket();
auto recvmmsg_res = detail::skip_eintr(
[&] { return recvmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_receive), flags, nullptr); });
auto recvmmsg_errno = errno;
if (recvmmsg_res >= 0) {
cnt = narrow_cast<size_t>(recvmmsg_res);
for (size_t i = 0; i < cnt; i++) {
helpers[i].from_native(headers[i].msg_hdr, headers[i].msg_len, messages[i]);
}
return Status::OK();
}
bool is_received;
auto status = process_recvmsg_error(recvmmsg_errno, is_received);
cnt = is_received;
return status;
}
#endif
};
// POSIX: there is no asynchronous I/O outstanding, so the implementation is deleted immediately.
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
delete impl;
}
#endif
} // namespace detail
// Special members are defaulted here rather than in the header, presumably so that
// detail::UdpSocketFdImpl (and its deleter) are complete at this point -- NOTE(review): confirm.
UdpSocketFd::UdpSocketFd() = default;
UdpSocketFd::UdpSocketFd(UdpSocketFd &&) = default;
UdpSocketFd &UdpSocketFd::operator=(UdpSocketFd &&) = default;
UdpSocketFd::~UdpSocketFd() = default;
// Returns the readiness state of the socket; requires a non-empty UdpSocketFd.
PollableFdInfo &UdpSocketFd::get_poll_info() {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.get_poll_info();
}
// Read-only access to the readiness state; requires a non-empty UdpSocketFd.
const PollableFdInfo &UdpSocketFd::get_poll_info() const {
  const detail::UdpSocketFdImpl &impl = *impl_;
  return impl.get_poll_info();
}
// Creates a non-blocking UDP socket in the address family of `address`. It sets SO_REUSEADDR
// (best effort: the result is not checked) and binds to the wildcard address of that family on
// `address`'s port.
// Returns an error if socket creation, switching to non-blocking mode or bind() fails.
Result<UdpSocketFd> UdpSocketFd::open(const IPAddress &address) {
  NativeFd fd{socket(address.get_address_family(), SOCK_DGRAM, IPPROTO_UDP)};
  if (!fd) {
    return OS_SOCKET_ERROR("Failed to create a socket");
  }
  TRY_STATUS(fd.set_is_blocking_unsafe(false));

  auto raw_socket = fd.socket();
#if TD_PORT_POSIX
  int reuse_addr = 1;
#elif TD_PORT_WINDOWS
  BOOL reuse_addr = TRUE;
#endif
  setsockopt(raw_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&reuse_addr), sizeof(reuse_addr));
  // TODO: consider tuning SO_SNDBUF/SO_RCVBUF here as well

  auto local_address = address.get_any_addr();
  local_address.set_port(address.get_port());
  if (bind(raw_socket, local_address.get_sockaddr(), narrow_cast<int>(local_address.get_sockaddr_len())) != 0) {
    return OS_SOCKET_ERROR("Failed to bind a socket");
  }
  return UdpSocketFd(make_unique<detail::UdpSocketFdImpl>(std::move(fd)));
}
// Takes ownership of `impl`. The raw pointer is re-wrapped so that UdpSocketFdImplDeleter (not
// plain delete) disposes of it.
UdpSocketFd::UdpSocketFd(unique_ptr<detail::UdpSocketFdImpl> impl) {
  impl_.reset(impl.release());
}
// Releases the implementation through UdpSocketFdImplDeleter; the object becomes empty().
void UdpSocketFd::close() {
  impl_ = nullptr;
}
// True if no socket is held (default-constructed, moved-from or closed).
bool UdpSocketFd::empty() const {
  return impl_ == nullptr;
}
// Native descriptor of the socket; requires a non-empty UdpSocketFd.
const NativeFd &UdpSocketFd::get_native_fd() const {
  const PollableFdInfo &info = get_poll_info();
  return info.native_fd();
}
#if TD_PORT_POSIX
// Tries to grow the socket buffer selected by `optname` (SO_SNDBUF or SO_RCVBUF) of `socket_fd`
// up to `max` bytes.
//
// The current size reported by getsockopt() is the lower bound. A binary search over
// [current size, max] finds the largest value that setsockopt() accepts; it assumes that if a size
// is accepted, every smaller size is accepted too. Returns the last value successfully passed to
// setsockopt(), or the current size if none was. The returned value need not match what
// getsockopt() reports afterwards.
// Returns an error only if the initial getsockopt() fails.
static Result<uint32> maximize_buffer(int socket_fd, int optname, uint32 max) {
  /* Start with the default size. */
  uint32 old_size;
  socklen_t intsize = sizeof(old_size);
  if (getsockopt(socket_fd, SOL_SOCKET, optname, &old_size, &intsize)) {
    return OS_ERROR("getsockopt() failed");
  }

  /* Binary-search for the real maximum. */
  uint32 last_good = old_size;
  uint32 lo = old_size;
  uint32 hi = max;
  while (lo <= hi) {
    uint32 avg = lo + (hi - lo) / 2;
    if (setsockopt(socket_fd, SOL_SOCKET, optname, &avg, intsize) == 0) {
      last_good = avg;
      if (avg == hi) {
        // The search is finished; also avoids `avg + 1` wrapping to 0 when hi == UINT32_MAX,
        // which previously made the loop run forever.
        break;
      }
      lo = avg + 1;
    } else {
      if (avg == lo) {
        // The search is finished; also avoids `avg - 1` wrapping around when avg == 0.
        break;
      }
      hi = avg - 1;
    }
  }
  return last_good;
}
// Grows SO_SNDBUF as far as the system allows, up to `max` bytes
// (0 means DEFAULT_UDP_MAX_SND_BUFFER_SIZE). Returns the size that was set.
Result<uint32> UdpSocketFd::maximize_snd_buffer(uint32 max) {
  uint32 limit = max != 0 ? max : DEFAULT_UDP_MAX_SND_BUFFER_SIZE;
  return maximize_buffer(get_native_fd().fd(), SO_SNDBUF, limit);
}
// Grows SO_RCVBUF as far as the system allows, up to `max` bytes
// (0 means DEFAULT_UDP_MAX_RCV_BUFFER_SIZE). Returns the size that was set.
Result<uint32> UdpSocketFd::maximize_rcv_buffer(uint32 max) {
  uint32 limit = max != 0 ? max : DEFAULT_UDP_MAX_RCV_BUFFER_SIZE;
  return maximize_buffer(get_native_fd().fd(), SO_RCVBUF, limit);
}
#else
// Not implemented on this platform: leaves the send buffer untouched and reports 0.
Result<uint32> UdpSocketFd::maximize_snd_buffer(uint32 max) {
  static_cast<void>(max);
  return 0;
}
// Not implemented on this platform: leaves the receive buffer untouched and reports 0.
Result<uint32> UdpSocketFd::maximize_rcv_buffer(uint32 max) {
  static_cast<void>(max);
  return 0;
}
#endif
#if TD_PORT_POSIX
// POSIX I/O entry points: thin forwards to detail::UdpSocketFdImpl, which documents the
// is_sent/is_received/count semantics. All require a non-empty UdpSocketFd.
Status UdpSocketFd::send_message(const OutboundMessage &message, bool &is_sent) {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.send_message(message, is_sent);
}
Status UdpSocketFd::receive_message(InboundMessage &message, bool &is_received) {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.receive_message(message, is_received);
}
Status UdpSocketFd::send_messages(Span<OutboundMessage> messages, size_t &count) {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.send_messages(messages, count);
}
Status UdpSocketFd::receive_messages(MutableSpan<InboundMessage> messages, size_t &count) {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.receive_messages(messages, count);
}
#endif
#if TD_PORT_WINDOWS
// Windows I/O entry points: thin forwards to the IOCP-based detail::UdpSocketFdImpl.
// All require a non-empty UdpSocketFd.
Result<optional<UdpMessage>> UdpSocketFd::receive() {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.receive();
}
void UdpSocketFd::send(UdpMessage message) {
  detail::UdpSocketFdImpl &impl = *impl_;
  impl.send(std::move(message));
}
Status UdpSocketFd::flush_send() {
  detail::UdpSocketFdImpl &impl = *impl_;
  return impl.flush_send();
}
#endif
// A read error is treated as critical when its code is ENOBUFS or ENOMEM.
bool UdpSocketFd::is_critical_read_error(const Status &status) {
  const auto code = status.code();
  return code == ENOBUFS || code == ENOMEM;
}
} // namespace td