Update tdutils from other project

GitOrigin-RevId: 83ec94032ccecef812b01963ac0506655a57e2af
This commit is contained in:
Arseny Smirnov 2018-08-13 20:15:09 +03:00
parent 89fe455514
commit 093651fb22
73 changed files with 4131 additions and 654 deletions

View File

@ -33,7 +33,7 @@ endif()
set(TDUTILS_SOURCE set(TDUTILS_SOURCE
td/utils/port/Clocks.cpp td/utils/port/Clocks.cpp
td/utils/port/Fd.cpp #td/utils/port/Fd.cpp
td/utils/port/FileFd.cpp td/utils/port/FileFd.cpp
td/utils/port/IPAddress.cpp td/utils/port/IPAddress.cpp
td/utils/port/path.cpp td/utils/port/path.cpp
@ -51,8 +51,10 @@ set(TDUTILS_SOURCE
td/utils/port/detail/EventFdWindows.cpp td/utils/port/detail/EventFdWindows.cpp
td/utils/port/detail/KQueue.cpp td/utils/port/detail/KQueue.cpp
td/utils/port/detail/Poll.cpp td/utils/port/detail/Poll.cpp
td/utils/port/detail/PollableFd.cpp
td/utils/port/detail/Select.cpp td/utils/port/detail/Select.cpp
td/utils/port/detail/ThreadIdGuard.cpp td/utils/port/detail/ThreadIdGuard.cpp
td/utils/port/UdpSocketFd.cpp
td/utils/port/detail/WineventPoll.cpp td/utils/port/detail/WineventPoll.cpp
${TDMIME_AUTO} ${TDMIME_AUTO}
@ -60,8 +62,9 @@ set(TDUTILS_SOURCE
td/utils/base64.cpp td/utils/base64.cpp
td/utils/BigNum.cpp td/utils/BigNum.cpp
td/utils/buffer.cpp td/utils/buffer.cpp
td/utils/BufferedUdp.cpp
td/utils/crypto.cpp td/utils/crypto.cpp
td/utils/FileLog.cpp #td/utils/FileLog.cpp
td/utils/filesystem.cpp td/utils/filesystem.cpp
td/utils/find_boundary.cpp td/utils/find_boundary.cpp
td/utils/Gzip.cpp td/utils/Gzip.cpp
@ -88,7 +91,7 @@ set(TDUTILS_SOURCE
td/utils/port/CxCli.h td/utils/port/CxCli.h
td/utils/port/EventFd.h td/utils/port/EventFd.h
td/utils/port/EventFdBase.h td/utils/port/EventFdBase.h
td/utils/port/Fd.h #td/utils/port/Fd.h
td/utils/port/FileFd.h td/utils/port/FileFd.h
td/utils/port/IPAddress.h td/utils/port/IPAddress.h
td/utils/port/path.h td/utils/port/path.h
@ -103,6 +106,7 @@ set(TDUTILS_SOURCE
td/utils/port/Stat.h td/utils/port/Stat.h
td/utils/port/thread.h td/utils/port/thread.h
td/utils/port/thread_local.h td/utils/port/thread_local.h
td/utils/port/UdpSocketFd.h
td/utils/port/wstring_convert.h td/utils/port/wstring_convert.h
td/utils/port/detail/Epoll.h td/utils/port/detail/Epoll.h
@ -110,7 +114,10 @@ set(TDUTILS_SOURCE
td/utils/port/detail/EventFdLinux.h td/utils/port/detail/EventFdLinux.h
td/utils/port/detail/EventFdWindows.h td/utils/port/detail/EventFdWindows.h
td/utils/port/detail/KQueue.h td/utils/port/detail/KQueue.h
td/utils/port/detail/NativeFd.h
td/utils/port/detail/NativeFd.cpp
td/utils/port/detail/Poll.h td/utils/port/detail/Poll.h
td/utils/port/detail/PollableFd.h
td/utils/port/detail/Select.h td/utils/port/detail/Select.h
td/utils/port/detail/ThreadIdGuard.h td/utils/port/detail/ThreadIdGuard.h
td/utils/port/detail/ThreadPthread.h td/utils/port/detail/ThreadPthread.h
@ -122,6 +129,7 @@ set(TDUTILS_SOURCE
td/utils/benchmark.h td/utils/benchmark.h
td/utils/BigNum.h td/utils/BigNum.h
td/utils/buffer.h td/utils/buffer.h
td/utils/BufferedUdp.h
td/utils/BufferedFd.h td/utils/BufferedFd.h
td/utils/BufferedReader.h td/utils/BufferedReader.h
td/utils/ByteFlow.h td/utils/ByteFlow.h
@ -129,9 +137,11 @@ set(TDUTILS_SOURCE
td/utils/Closure.h td/utils/Closure.h
td/utils/common.h td/utils/common.h
td/utils/Container.h td/utils/Container.h
td/utils/Context.h
td/utils/crypto.h td/utils/crypto.h
td/utils/Enumerator.h td/utils/Enumerator.h
td/utils/FileLog.h td/utils/Destructor.h
#td/utils/FileLog.h
td/utils/filesystem.h td/utils/filesystem.h
td/utils/find_boundary.h td/utils/find_boundary.h
td/utils/FloodControlFast.h td/utils/FloodControlFast.h
@ -171,6 +181,7 @@ set(TDUTILS_SOURCE
td/utils/SharedObjectPool.h td/utils/SharedObjectPool.h
td/utils/Slice-decl.h td/utils/Slice-decl.h
td/utils/Slice.h td/utils/Slice.h
td/utils/Span.h
td/utils/SpinLock.h td/utils/SpinLock.h
td/utils/StackAllocator.h td/utils/StackAllocator.h
td/utils/Status.h td/utils/Status.h
@ -189,21 +200,24 @@ set(TDUTILS_SOURCE
td/utils/unicode.h td/utils/unicode.h
td/utils/utf8.h td/utils/utf8.h
td/utils/Variant.h td/utils/Variant.h
td/utils/VectorQueue.h
) )
set(TDUTILS_TEST_SOURCE set(TDUTILS_TEST_SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/test/buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/crypto.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/crypto.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/Enumerator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/Enumerator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/filesystem.cpp #${CMAKE_CURRENT_SOURCE_DIR}/test/filesystem.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/gzip.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/gzip.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/HazardPointers.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/HazardPointers.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/heap.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/heap.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/json.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/json.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/misc.cpp #${CMAKE_CURRENT_SOURCE_DIR}/test/misc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/MpmcQueue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/MpmcQueue.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/MpmcWaiter.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/MpmcWaiter.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/MpscLinkQueue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/MpscLinkQueue.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/OrderedEventsProcessor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/OrderedEventsProcessor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/port.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/pq.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/pq.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/SharedObjectPool.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/SharedObjectPool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/variant.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/variant.cpp

View File

@ -85,6 +85,15 @@ BigNum BigNum::from_binary(Slice str) {
return BigNum(make_unique<Impl>(BN_bin2bn(str.ubegin(), narrow_cast<int>(str.size()), nullptr))); return BigNum(make_unique<Impl>(BN_bin2bn(str.ubegin(), narrow_cast<int>(str.size()), nullptr)));
} }
BigNum BigNum::from_le_binary(Slice str) {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
return BigNum(make_unique<Impl>(BN_lebin2bn(str.ubegin(), narrow_cast<int>(str.size()), nullptr)));
#else
LOG(FATAL) << "Unsupported from_le_binary";
return BigNum();
#endif
}
Result<BigNum> BigNum::from_decimal(CSlice str) { Result<BigNum> BigNum::from_decimal(CSlice str) {
BigNum result; BigNum result;
int res = BN_dec2bn(&result.impl_->big_num, str.c_str()); int res = BN_dec2bn(&result.impl_->big_num, str.c_str());
@ -94,6 +103,13 @@ Result<BigNum> BigNum::from_decimal(CSlice str) {
return result; return result;
} }
BigNum BigNum::from_hex(CSlice str) {
BigNum result;
int err = BN_hex2bn(&result.impl_->big_num, str.c_str());
LOG_IF(FATAL, err == 0);
return result;
}
BigNum BigNum::from_raw(void *openssl_big_num) { BigNum BigNum::from_raw(void *openssl_big_num) {
return BigNum(make_unique<Impl>(static_cast<BIGNUM *>(openssl_big_num))); return BigNum(make_unique<Impl>(static_cast<BIGNUM *>(openssl_big_num)));
} }
@ -182,10 +198,27 @@ string BigNum::to_binary(int exact_size) const {
CHECK(exact_size >= num_size); CHECK(exact_size >= num_size);
} }
string res(exact_size, '\0'); string res(exact_size, '\0');
BN_bn2bin(impl_->big_num, reinterpret_cast<unsigned char *>(&res[exact_size - num_size])); BN_bn2bin(impl_->big_num, MutableSlice(res).ubegin() + (exact_size - num_size));
return res; return res;
} }
string BigNum::to_le_binary(int exact_size) const {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
int num_size = get_num_bytes();
if (exact_size == -1) {
exact_size = num_size;
} else {
CHECK(exact_size >= num_size);
}
string res(exact_size, '\0');
BN_bn2lebinpad(impl_->big_num, MutableSlice(res).ubegin(), exact_size);
return res;
#else
LOG(FATAL) << "Unsupported to_le_binary";
return "";
#endif
}
string BigNum::to_decimal() const { string BigNum::to_decimal() const {
char *result = BN_bn2dec(impl_->big_num); char *result = BN_bn2dec(impl_->big_num);
CHECK(result != nullptr); CHECK(result != nullptr);
@ -216,15 +249,27 @@ void BigNum::mul(BigNum &r, BigNum &a, BigNum &b, BigNumContext &context) {
LOG_IF(FATAL, result != 1); LOG_IF(FATAL, result != 1);
} }
void BigNum::mod_add(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context) {
int result = BN_mod_add(r.impl_->big_num, a.impl_->big_num, b.impl_->big_num, m.impl_->big_num,
context.impl_->big_num_context);
LOG_IF(FATAL, result != 1);
}
void BigNum::mod_sub(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context) {
int result = BN_mod_sub(r.impl_->big_num, a.impl_->big_num, b.impl_->big_num, m.impl_->big_num,
context.impl_->big_num_context);
LOG_IF(FATAL, result != 1);
}
void BigNum::mod_mul(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context) { void BigNum::mod_mul(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context) {
int result = BN_mod_mul(r.impl_->big_num, a.impl_->big_num, b.impl_->big_num, m.impl_->big_num, int result = BN_mod_mul(r.impl_->big_num, a.impl_->big_num, b.impl_->big_num, m.impl_->big_num,
context.impl_->big_num_context); context.impl_->big_num_context);
LOG_IF(FATAL, result != 1); LOG_IF(FATAL, result != 1);
} }
void BigNum::mod_inv(BigNum &r, BigNum &a, const BigNum &m, BigNumContext &context) { void BigNum::mod_inverse(BigNum &r, BigNum &a, const BigNum &m, BigNumContext &context) {
BIGNUM *result = BN_mod_inverse(r.impl_->big_num, a.impl_->big_num, m.impl_->big_num, context.impl_->big_num_context); auto result = BN_mod_inverse(r.impl_->big_num, a.impl_->big_num, m.impl_->big_num, context.impl_->big_num_context);
LOG_IF(FATAL, result == nullptr); LOG_IF(FATAL, result != r.impl_->big_num);
} }
void BigNum::div(BigNum *quotient, BigNum *remainder, const BigNum &dividend, const BigNum &divisor, void BigNum::div(BigNum *quotient, BigNum *remainder, const BigNum &dividend, const BigNum &divisor,

View File

@ -43,8 +43,13 @@ class BigNum {
static BigNum from_binary(Slice str); static BigNum from_binary(Slice str);
// Available only if OpenSSL >= 1.1.0
static BigNum from_le_binary(Slice str);
static Result<BigNum> from_decimal(CSlice str); static Result<BigNum> from_decimal(CSlice str);
static BigNum from_hex(CSlice str);
static BigNum from_raw(void *openssl_big_num); static BigNum from_raw(void *openssl_big_num);
void set_value(uint32 new_value); void set_value(uint32 new_value);
@ -67,6 +72,9 @@ class BigNum {
string to_binary(int exact_size = -1) const; string to_binary(int exact_size = -1) const;
// Available only if OpenSSL >= 1.1.0
string to_le_binary(int exact_size = -1) const;
string to_decimal() const; string to_decimal() const;
void operator+=(uint32 value); void operator+=(uint32 value);
@ -87,9 +95,13 @@ class BigNum {
static void mul(BigNum &r, BigNum &a, BigNum &b, BigNumContext &context); static void mul(BigNum &r, BigNum &a, BigNum &b, BigNumContext &context);
static void mod_add(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context);
static void mod_sub(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context);
static void mod_mul(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context); static void mod_mul(BigNum &r, BigNum &a, BigNum &b, const BigNum &m, BigNumContext &context);
static void mod_inv(BigNum &r, BigNum &a, const BigNum &m, BigNumContext &context); static void mod_inverse(BigNum &r, BigNum &a, const BigNum &m, BigNumContext &context);
static void div(BigNum *quotient, BigNum *remainder, const BigNum &dividend, const BigNum &divisor, static void div(BigNum *quotient, BigNum *remainder, const BigNum &dividend, const BigNum &divisor,
BigNumContext &context); BigNumContext &context);

View File

@ -0,0 +1,13 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/BufferedUdp.h"
namespace td {
#if TD_PORT_POSIX
TD_THREAD_LOCAL detail::UdpReader* BufferedUdp::udp_reader_;
#endif
} // namespace td

View File

@ -0,0 +1,163 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/port/UdpSocketFd.h"
#include "td/utils/buffer.h"
#include "td/utils/optional.h"
#include "td/utils/VectorQueue.h"
namespace td {
#if TD_PORT_POSIX
namespace detail {
class UdpWriter {
public:
static Status write_once(UdpSocketFd& fd, VectorQueue<UdpMessage>& queue) TD_WARN_UNUSED_RESULT {
std::array<UdpSocketFd::OutboundMessage, 16> messages;
auto to_send = queue.as_span();
size_t to_send_n = std::min(messages.size(), to_send.size());
to_send.truncate(to_send_n);
for (size_t i = 0; i < to_send_n; i++) {
messages[i].to = &to_send[i].address;
messages[i].data = to_send[i].data.as_slice();
}
size_t cnt;
auto status = fd.send_messages(::td::Span<UdpSocketFd::OutboundMessage>(messages).truncate(to_send_n), cnt);
queue.pop_n(cnt);
return status;
}
};
class UdpReaderHelper {
public:
void init_inbound_message(UdpSocketFd::InboundMessage& message) {
message.from = &message_.address;
message.error = &message_.error;
if (buffer_.size() < MAX_PACKET_SIZE) {
buffer_ = BufferSlice(RESERVED_SIZE);
}
CHECK(buffer_.size() >= MAX_PACKET_SIZE);
message.data = buffer_.as_slice().truncate(MAX_PACKET_SIZE);
}
UdpMessage extract_udp_message(UdpSocketFd::InboundMessage& message) {
message_.data = buffer_.from_slice(message.data);
auto size = message_.data.size();
size = (size + 7) & ~7;
CHECK(size <= MAX_PACKET_SIZE);
buffer_.confirm_read(size);
return std::move(message_);
}
private:
enum : size_t { MAX_PACKET_SIZE = 2048, RESERVED_SIZE = MAX_PACKET_SIZE * 8 };
UdpMessage message_;
BufferSlice buffer_;
};
//One for thread is enough
class UdpReader {
public:
UdpReader() {
for (size_t i = 0; i < messages_.size(); i++) {
helpers_[i].init_inbound_message(messages_[i]);
}
}
Status read_once(UdpSocketFd& fd, VectorQueue<UdpMessage>& queue) TD_WARN_UNUSED_RESULT {
for (size_t i = 0; i < messages_.size(); i++) {
CHECK(messages_[i].data.size() == 2048);
}
size_t cnt = 0;
auto status = fd.receive_messages(messages_, cnt);
for (size_t i = 0; i < cnt; i++) {
queue.push(helpers_[i].extract_udp_message(messages_[i]));
helpers_[i].init_inbound_message(messages_[i]);
}
for (size_t i = cnt; i < messages_.size(); i++) {
CHECK(messages_[i].data.size() == 2048)
<< " cnt = " << cnt << " i = " << i << " size = " << messages_[i].data.size() << " status = " << status;
;
}
if (status.is_error() && !UdpSocketFd::is_critical_read_error(status)) {
queue.push(UdpMessage{{}, {}, std::move(status)});
}
return status;
}
private:
enum : size_t { BUFFER_SIZE = 16 };
std::array<UdpSocketFd::InboundMessage, BUFFER_SIZE> messages_;
std::array<UdpReaderHelper, BUFFER_SIZE> helpers_;
};
} // namespace detail
#endif
class BufferedUdp : public UdpSocketFd {
public:
BufferedUdp(UdpSocketFd fd) : UdpSocketFd(std::move(fd)) {
}
#if TD_PORT_POSIX
Result<optional<UdpMessage>> receive() {
if (input_.empty() && can_read(*this)) {
TRY_STATUS(flush_read_once());
}
if (input_.empty()) {
return optional<UdpMessage>();
}
return input_.pop();
}
void send(UdpMessage message) {
output_.push(std::move(message));
}
Status flush_send() {
Status status;
while (status.is_ok() && can_write(*this) && !output_.empty()) {
status = flush_send_once();
}
return status;
}
#endif
UdpSocketFd move_as_udp_socket_fd() {
return std::move(as_fd());
}
UdpSocketFd& as_fd() {
return *static_cast<UdpSocketFd*>(this);
}
private:
#if TD_PORT_POSIX
VectorQueue<UdpMessage> input_;
VectorQueue<UdpMessage> output_;
VectorQueue<UdpMessage>& input() {
return input_;
}
VectorQueue<UdpMessage>& output() {
return output_;
}
Status flush_send_once() TD_WARN_UNUSED_RESULT {
return detail::UdpWriter::write_once(as_fd(), output_);
}
Status flush_read_once() TD_WARN_UNUSED_RESULT {
init_thread_local<detail::UdpReader>(udp_reader_);
return udp_reader_->read_once(as_fd(), input_);
}
static TD_THREAD_LOCAL detail::UdpReader* udp_reader_;
#endif
};
} // namespace td

View File

@ -208,7 +208,7 @@ class ByteFlowSink : public ByteFlowInterface {
CHECK(buffer_ == nullptr); CHECK(buffer_ == nullptr);
buffer_ = input; buffer_ = input;
} }
void set_parent(ByteFlowInterface &parent) final { void set_parent(ByteFlowInterface & /*parent*/) final {
UNREACHABLE(); UNREACHABLE();
} }
void close_input(Status status) final { void close_input(Status status) final {
@ -254,7 +254,7 @@ class ByteFlowMoveSink : public ByteFlowInterface {
CHECK(!input_); CHECK(!input_);
input_ = input; input_ = input;
} }
void set_parent(ByteFlowInterface &parent) final { void set_parent(ByteFlowInterface & /*parent*/) final {
UNREACHABLE(); UNREACHABLE();
} }
void close_input(Status status) final { void close_input(Status status) final {

View File

@ -68,17 +68,16 @@ class ImmediateClosure {
friend Delayed; friend Delayed;
using ActorType = ActorT; using ActorType = ActorT;
void run(ActorT *actor) { auto run(ActorT *actor) {
mem_call_tuple(actor, func, std::move(args)); return mem_call_tuple(actor, std::move(args));
} }
// no &&. just save references as references. // no &&. just save references as references.
explicit ImmediateClosure(FunctionT func, ArgsT... args) : func(func), args(std::forward<ArgsT>(args)...) { explicit ImmediateClosure(FunctionT func, ArgsT... args) : args(func, std::forward<ArgsT>(args)...) {
} }
private: private:
FunctionT func; std::tuple<FunctionT, ArgsT...> args;
std::tuple<ArgsT...> args;
}; };
template <class ActorT, class ResultT, class... DestArgsT, class... SrcArgsT> template <class ActorT, class ResultT, class... DestArgsT, class... SrcArgsT>
@ -94,36 +93,34 @@ class DelayedClosure {
using ActorType = ActorT; using ActorType = ActorT;
using Delayed = DelayedClosure<ActorT, FunctionT, ArgsT...>; using Delayed = DelayedClosure<ActorT, FunctionT, ArgsT...>;
void run(ActorT *actor) { auto run(ActorT *actor) {
mem_call_tuple(actor, func, std::move(args)); return mem_call_tuple(actor, std::move(args));
} }
DelayedClosure clone() const { DelayedClosure clone() const {
return do_clone(*this); return do_clone(*this);
} }
explicit DelayedClosure(ImmediateClosure<ActorT, FunctionT, ArgsT...> &&other) explicit DelayedClosure(ImmediateClosure<ActorT, FunctionT, ArgsT...> &&other) : args(std::move(other.args)) {
: func(std::move(other.func)), args(std::move(other.args)) {
} }
explicit DelayedClosure(FunctionT func, ArgsT... args) : func(func), args(std::forward<ArgsT>(args)...) { explicit DelayedClosure(FunctionT func, ArgsT... args) : args(func, std::forward<ArgsT>(args)...) {
} }
template <class F> //template <class F>
void for_each(const F &f) { //void for_each(const F &f) {
tuple_for_each(args, f); //tuple_for_each(args, f);
} //}
private: private:
using ArgsStorageT = std::tuple<typename std::decay<ArgsT>::type...>; using ArgsStorageT = std::tuple<FunctionT, typename std::decay<ArgsT>::type...>;
FunctionT func;
ArgsStorageT args; ArgsStorageT args;
template <class FromActorT, class FromFunctionT, class... FromArgsT> template <class FromActorT, class FromFunctionT, class... FromArgsT>
explicit DelayedClosure(const DelayedClosure<FromActorT, FromFunctionT, FromArgsT...> &other, explicit DelayedClosure(const DelayedClosure<FromActorT, FromFunctionT, FromArgsT...> &other,
std::enable_if_t<LogicAnd<std::is_copy_constructible<FromArgsT>::value...>::value, int> = 0) std::enable_if_t<LogicAnd<std::is_copy_constructible<FromArgsT>::value...>::value, int> = 0)
: func(other.func), args(other.args) { : args(other.args) {
} }
template <class FromActorT, class FromFunctionT, class... FromArgsT> template <class FromActorT, class FromFunctionT, class... FromArgsT>

View File

@ -0,0 +1,43 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/port/thread_local.h"
namespace td {
template <class Impl>
class Context {
public:
static Impl *get() {
return context_;
}
class Guard {
public:
explicit Guard(Impl *new_context) {
old_context_ = context_;
context_ = new_context;
}
~Guard() {
context_ = old_context_;
}
Guard(const Guard &) = delete;
Guard &operator=(const Guard &) = delete;
Guard(Guard &&) = delete;
Guard &operator=(Guard &&) = delete;
private:
Impl *old_context_;
};
private:
static TD_THREAD_LOCAL Impl *context_;
};
template <class Impl>
TD_THREAD_LOCAL Impl *Context<Impl>::context_;
} // namespace td

177
tdutils/td/utils/DecTree.h Normal file
View File

@ -0,0 +1,177 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include <memory>
#include <utility>
#include "int_types.h"
#include "Random.h"
namespace td {
template <typename KeyType, typename ValueType, typename Compare = std::less<KeyType>>
class DecTree {
private:
struct Node {
std::unique_ptr<Node> left_;
std::unique_ptr<Node> right_;
size_t size_;
KeyType key_;
ValueType value_;
uint32 y_;
void relax() {
size_ = 1;
if (left_ != nullptr) {
size_ += left_->size_;
}
if (right_ != nullptr) {
size_ += right_->size_;
}
}
Node(KeyType key, ValueType value, uint32 y) : key_(std::move(key)), value_(std::move(value)), y_(y) {
size_ = 1;
}
};
std::unique_ptr<Node> root_;
std::unique_ptr<Node> create_node(KeyType key, ValueType value, uint32 y) {
return std::make_unique<Node>(std::move(key), std::move(value), y);
}
std::unique_ptr<Node> insert_node(std::unique_ptr<Node> Tree, KeyType key, ValueType value, uint32 y) {
if (Tree == nullptr) {
return create_node(std::move(key), std::move(value), y);
}
if (Tree->y_ < y) {
auto P = split_node(std::move(Tree), key);
auto T = create_node(std::move(key), std::move(value), y);
T->left_ = std::move(P.first);
T->right_ = std::move(P.second);
T->relax();
return std::move(T);
}
if (Compare()(key, Tree->key_)) {
Tree->left_ = insert_node(std::move(Tree->left_), std::move(key), std::move(value), y);
} else if (Compare()(Tree->key_, key)) {
Tree->right_ = insert_node(std::move(Tree->right_), std::move(key), std::move(value), y);
} else {
// ?? assert
}
Tree->relax();
return std::move(Tree);
}
std::unique_ptr<Node> remove_node(std::unique_ptr<Node> Tree, KeyType &key) {
if (Tree == nullptr) {
// ?? assert
return nullptr;
}
if (Compare()(key, Tree->key_)) {
Tree->left_ = remove_node(std::move(Tree->left_), key);
} else if (Compare()(Tree->key_, key)) {
Tree->right_ = remove_node(std::move(Tree->right_), key);
} else {
Tree = merge_node(std::move(Tree->left_), std::move(Tree->right_));
}
if (Tree != nullptr) {
Tree->relax();
}
return std::move(Tree);
}
ValueType *get_node(std::unique_ptr<Node> &Tree, KeyType &key) {
if (Tree == nullptr) {
return nullptr;
}
if (Compare()(key, Tree->key_)) {
return get_node(Tree->left_, key);
} else if (Compare()(Tree->key_, key)) {
return get_node(Tree->right_, key);
} else {
return &Tree->value_;
}
}
ValueType *get_node_by_idx(std::unique_ptr<Node> &Tree, size_t idx) {
CHECK(Tree != nullptr);
auto s = (Tree->left_ != nullptr) ? Tree->left_->size_ : 0;
if (idx < s) {
return get_node_by_idx(Tree->left_, idx);
} else if (idx == s) {
return &Tree->value_;
} else {
return get_node_by_idx(Tree->right_, idx - s - 1);
}
}
std::pair<std::unique_ptr<Node>, std::unique_ptr<Node>> split_node(std::unique_ptr<Node> Tree, KeyType &key) {
if (Tree == nullptr) {
return std::pair<std::unique_ptr<Node>, std::unique_ptr<Node>>(nullptr, nullptr);
}
if (Compare()(key, Tree->key_)) {
auto P = split_node(std::move(Tree->left_), key);
Tree->left_ = std::move(P.second);
Tree->relax();
P.second = std::move(Tree);
return std::move(P);
} else {
auto P = split_node(std::move(Tree->right_), key);
Tree->right_ = std::move(P.first);
Tree->relax();
P.first = std::move(Tree);
return std::move(P);
}
}
std::unique_ptr<Node> merge_node(std::unique_ptr<Node> left, std::unique_ptr<Node> right) {
if (left == nullptr) {
return std::move(right);
}
if (right == nullptr) {
return std::move(left);
}
if (left->y_ < right->y_) {
right->left_ = merge_node(std::move(left), std::move(right->left_));
right->relax();
return std::move(right);
} else {
left->right_ = merge_node(std::move(left->right_), std::move(right));
left->relax();
return std::move(left);
}
}
public:
DecTree() {
}
size_t size() const {
if (root_ == nullptr) {
return 0;
} else {
return root_->size_;
}
}
void insert(KeyType key, ValueType value) {
root_ = insert_node(std::move(root_), std::move(key), std::move(value), td::Random::fast_uint32());
}
void remove(KeyType &key) {
root_ = remove_node(std::move(root_), key);
}
ValueType *get(KeyType &key) {
return get_node(root_, key);
}
ValueType *get_random() {
if (size() == 0) {
return nullptr;
} else {
return get_node_by_idx(root_, td::Random::fast_uint32() % size());
}
}
bool exists(KeyType &key) {
return get_node(root_, key) != nullptr;
}
};
} // namespace td

View File

@ -0,0 +1,46 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/common.h"
namespace td {
class Destructor {
public:
Destructor() = default;
Destructor(const Destructor &other) = delete;
Destructor &operator=(const Destructor &other) = delete;
Destructor(Destructor &&other) = default;
Destructor &operator=(Destructor &&other) = default;
virtual ~Destructor() = default;
};
template <class F>
class LambdaDestructor : public Destructor {
public:
explicit LambdaDestructor(F &&f) : f_(std::move(f)) {
}
LambdaDestructor(const LambdaDestructor &other) = delete;
LambdaDestructor &operator=(const LambdaDestructor &other) = delete;
LambdaDestructor(LambdaDestructor &&other) = default;
LambdaDestructor &operator=(LambdaDestructor &&other) = default;
~LambdaDestructor() override {
f_();
}
private:
F f_;
};
template <class F>
auto create_destructor(F &&f) {
return std::make_unique<LambdaDestructor<F>>(std::forward<F>(f));
}
template <class F>
auto create_shared_destructor(F &&f) {
return std::make_shared<LambdaDestructor<F>>(std::forward<F>(f));
}
} // namespace td

View File

@ -8,7 +8,6 @@
#include "td/utils/misc.h" #include "td/utils/misc.h"
#include "td/utils/port/EventFd.h" #include "td/utils/port/EventFd.h"
#include "td/utils/SpinLock.h"
#if !TD_EVENTFD_UNSUPPORTED #if !TD_EVENTFD_UNSUPPORTED
#if !TD_WINDOWS #if !TD_WINDOWS
@ -18,6 +17,8 @@
#include <utility> #include <utility>
#include "td/utils/SpinLock.h"
namespace td { namespace td {
// interface like in PollableQueue // interface like in PollableQueue
template <class T> template <class T>

View File

@ -92,7 +92,7 @@ uint32 Random::fast_uint32() {
if (!gen) { if (!gen) {
auto &rg = rand_device_helper; auto &rg = rand_device_helper;
std::seed_seq seq{rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg()}; std::seed_seq seq{rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg()};
init_thread_local<std::mt19937>(gen, seq); init_thread_local<std::mt19937>(gen);
} }
return static_cast<uint32>((*gen)()); return static_cast<uint32>((*gen)());
} }
@ -102,7 +102,7 @@ uint64 Random::fast_uint64() {
if (!gen) { if (!gen) {
auto &rg = rand_device_helper; auto &rg = rand_device_helper;
std::seed_seq seq{rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg()}; std::seed_seq seq{rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg(), rg()};
init_thread_local<std::mt19937_64>(gen, seq); init_thread_local<std::mt19937_64>(gen);
} }
return static_cast<uint64>((*gen)()); return static_cast<uint64>((*gen)());
} }
@ -112,8 +112,34 @@ int Random::fast(int min, int max) {
// to prevent integer overflow and division by zero // to prevent integer overflow and division by zero
min++; min++;
} }
CHECK(min <= max); DCHECK(min <= max);
return static_cast<int>(min + fast_uint32() % (max - min + 1)); // TODO signed_cast return static_cast<int>(min + fast_uint32() % (max - min + 1)); // TODO signed_cast
} }
Random::Xorshift128plus::Xorshift128plus(uint32 seed) {
auto next = [&]() {
// splitmix64
uint64_t z = (seed += UINT64_C(0x9E3779B97F4A7C15));
z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB);
return z ^ (z >> 31);
};
seed_[0] = next();
seed_[1] = next();
}
Random::Xorshift128plus::Xorshift128plus(uint64 seed_a, uint64 seed_b) {
seed_[0] = seed_a;
seed_[1] = seed_b;
}
uint64 Random::Xorshift128plus::operator()() {
uint64_t x = seed_[0];
uint64_t const y = seed_[1];
seed_[0] = y;
x ^= x << 23; // a
seed_[1] = x ^ y ^ (x >> 17) ^ (y >> 26); // b, c
return seed_[1] + y;
}
} // namespace td } // namespace td

View File

@ -28,6 +28,16 @@ class Random {
// distribution is not uniform, min and max are included // distribution is not uniform, min and max are included
static int fast(int min, int max); static int fast(int min, int max);
class Xorshift128plus {
public:
Xorshift128plus(uint32 seed);
Xorshift128plus(uint64 seed_a, uint64 seed_b);
uint64 operator()();
private:
uint64 seed_[2];
};
}; };
} // namespace td } // namespace td

View File

@ -38,9 +38,7 @@ class AtomicRefCnt {
}; };
template <class DataT, class DeleterT> template <class DataT, class DeleterT>
class SharedPtrRaw class SharedPtrRaw : public DeleterT, private MpscLinkQueueImpl::Node {
: public DeleterT
, private MpscLinkQueueImpl::Node {
public: public:
explicit SharedPtrRaw(DeleterT deleter) : DeleterT(std::move(deleter)), ref_cnt_{0}, option_magic_(Magic) { explicit SharedPtrRaw(DeleterT deleter) : DeleterT(std::move(deleter)), ref_cnt_{0}, option_magic_(Magic) {
} }
@ -97,12 +95,16 @@ class SharedPtr {
reset(); reset();
} }
explicit SharedPtr(Raw *raw) : raw_(raw) { explicit SharedPtr(Raw *raw) : raw_(raw) {
raw_->inc(); if (raw_) {
raw_->inc();
}
} }
SharedPtr(const SharedPtr &other) : SharedPtr(other.raw_) { SharedPtr(const SharedPtr &other) : SharedPtr(other.raw_) {
} }
SharedPtr &operator=(const SharedPtr &other) { SharedPtr &operator=(const SharedPtr &other) {
other.raw_->inc(); if (other.raw_) {
other.raw_->inc();
}
reset(other.raw_); reset(other.raw_);
return *this; return *this;
} }
@ -160,6 +162,9 @@ class SharedPtr {
raw->init_data(std::forward<ArgsT>(args)...); raw->init_data(std::forward<ArgsT>(args)...);
return SharedPtr<T, DeleterT>(raw.release()); return SharedPtr<T, DeleterT>(raw.release());
} }
bool operator==(const SharedPtr<T, DeleterT> &other) const {
return raw_ == other.raw_;
}
private: private:
Raw *raw_{nullptr}; Raw *raw_{nullptr};

View File

@ -112,6 +112,7 @@ class Slice {
bool operator==(const Slice &a, const Slice &b); bool operator==(const Slice &a, const Slice &b);
bool operator!=(const Slice &a, const Slice &b); bool operator!=(const Slice &a, const Slice &b);
bool operator<(const Slice &a, const Slice &b);
class MutableCSlice : public MutableSlice { class MutableCSlice : public MutableSlice {
struct private_tag {}; struct private_tag {};

View File

@ -280,6 +280,14 @@ inline bool operator!=(const Slice &a, const Slice &b) {
return !(a == b); return !(a == b);
} }
inline bool operator<(const Slice &a, const Slice &b) {
auto x = std::memcmp(a.data(), b.data(), std::min(a.size(), b.size()));
if (x == 0) {
return a.size() < b.size();
}
return x < 0;
}
inline MutableCSlice::MutableCSlice(char *s, char *t) : MutableSlice(s, t) { inline MutableCSlice::MutableCSlice(char *s, char *t) : MutableSlice(s, t) {
CHECK(*t == '\0'); CHECK(*t == '\0');
} }

83
tdutils/td/utils/Span.h Normal file
View File

@ -0,0 +1,83 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/Status.h"
namespace td {
namespace detail {
template <class T, class InnerT>
class SpanImpl {
private:
InnerT *data_{nullptr};
size_t size_{0};
public:
SpanImpl() = default;
SpanImpl(InnerT *data, size_t size) : data_(data), size_(size) {
}
SpanImpl(InnerT &data) : SpanImpl(&data, 1) {
}
SpanImpl(const SpanImpl &other) = default;
SpanImpl &operator=(const SpanImpl &other) = default;
template <class OtherInnerT>
SpanImpl(const SpanImpl<T, OtherInnerT> &other) : SpanImpl(other.data(), other.size()) {
}
template <size_t N>
SpanImpl(const std::array<T, N> &arr) : SpanImpl(arr.data(), arr.size()) {
}
template <size_t N>
SpanImpl(std::array<T, N> &arr) : SpanImpl(arr.data(), arr.size()) {
}
SpanImpl(const std::vector<T> &v) : SpanImpl(v.data(), v.size()) {
}
SpanImpl(std::vector<T> &v) : SpanImpl(v.data(), v.size()) {
}
template <class OtherInnerT>
SpanImpl &operator=(const SpanImpl<T, OtherInnerT> &other) {
SpanImpl copy{other};
*this = copy;
}
InnerT &operator[](size_t i) {
DCHECK(i < size());
return data_[i];
}
InnerT *data() const {
return data_;
}
InnerT *begin() const {
return data_;
}
InnerT *end() const {
return data_ + size_;
}
size_t size() const {
return size_;
}
SpanImpl &truncate(size_t size) {
CHECK(size <= size_);
size_ = size;
return *this;
}
SpanImpl substr(size_t offset) {
CHECK(offset <= size_);
return SpanImpl(begin() + offset, size_ - offset);
}
};
} // namespace detail
template <class T>
using Span = detail::SpanImpl<T, const T>;
template <class T>
using MutableSpan = detail::SpanImpl<T, T>;
}; // namespace td

View File

@ -0,0 +1,64 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/Span.h"
#include <vector>
namespace td {
template <class T>
class VectorQueue {
public:
template <class S>
void push(S&& s) {
vector_.push_back(std::forward<S>(s));
}
T pop() {
try_shrink();
return std::move(vector_[read_pos_++]);
}
void pop_n(size_t n) {
read_pos_ += n;
try_shrink();
}
T& front() {
return vector_[read_pos_];
}
T& back() {
vector_.back();
}
bool empty() const {
return size() == 0;
}
size_t size() const {
return vector_.size() - read_pos_;
}
T* data() {
return vector_.data() + read_pos_;
}
const T* data() const {
return vector_.data() + read_pos_;
}
Span<T> as_span() const {
return {data(), size()};
}
MutableSpan<T> as_mutable_span() {
return {data(), size()};
}
private:
std::vector<T> vector_;
size_t read_pos_{0};
void try_shrink() {
if (read_pos_ * 2 > vector_.size() && read_pos_ > 4) {
vector_.erase(vector_.begin(), vector_.begin() + read_pos_);
read_pos_ = 0;
}
}
};
} // namespace td

View File

@ -110,4 +110,77 @@ BufferRaw *BufferAllocator::create_buffer_raw(size_t size) {
buffer_raw->was_reader_ = false; buffer_raw->was_reader_ = false;
return buffer_raw; return buffer_raw;
} }
void BufferBuilder::append(BufferSlice slice) {
if (append_inplace(slice.as_slice())) {
return;
}
append_slow(std::move(slice));
}
void BufferBuilder::append(Slice slice) {
if (append_inplace(slice)) {
return;
}
append_slow(BufferSlice(slice));
}
void BufferBuilder::prepend(BufferSlice slice) {
if (prepend_inplace(slice.as_slice())) {
return;
}
prepend_slow(std::move(slice));
}
void BufferBuilder::prepend(Slice slice) {
if (prepend_inplace(slice)) {
return;
}
prepend_slow(BufferSlice(slice));
}
BufferSlice BufferBuilder::extract() {
if (to_append_.empty() && to_prepend_.empty()) {
return buffer_writer_.as_buffer_slice();
}
size_t total_size = 0;
for_each([&](auto &&slice) { total_size += slice.size(); });
BufferWriter writer(0, 0, total_size);
for_each([&](auto &&slice) {
writer.prepare_append().truncate(slice.size()).copy_from(slice.as_slice());
writer.confirm_append(slice.size());
});
*this = {};
return writer.as_buffer_slice();
}
bool BufferBuilder::append_inplace(Slice slice) {
if (!to_append_.empty()) {
return false;
}
auto dest = buffer_writer_.prepare_append();
if (dest.size() < slice.size()) {
return false;
}
dest.remove_suffix(dest.size() - slice.size());
dest.copy_from(slice);
buffer_writer_.confirm_append(slice.size());
return true;
}
void BufferBuilder::append_slow(BufferSlice slice) {
to_append_.push_back(std::move(slice));
}
bool BufferBuilder::prepend_inplace(Slice slice) {
if (!to_prepend_.empty()) {
return false;
}
auto dest = buffer_writer_.prepare_prepend();
if (dest.size() < slice.size()) {
return false;
}
dest.remove_prefix(dest.size() - slice.size());
dest.copy_from(slice);
buffer_writer_.confirm_prepend(slice.size());
return true;
}
void BufferBuilder::prepend_slow(BufferSlice slice) {
to_prepend_.push_back(std::move(slice));
}
} // namespace td } // namespace td

View File

@ -10,6 +10,7 @@
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/port/thread_local.h" #include "td/utils/port/thread_local.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/misc.h"
#include <atomic> #include <atomic>
#include <cstring> #include <cstring>
@ -143,6 +144,10 @@ class BufferSlice {
return Slice(buffer_->data_ + begin_, size()); return Slice(buffer_->data_ + begin_, size());
} }
operator Slice() const {
return as_slice();
}
MutableSlice as_slice() { MutableSlice as_slice() {
if (is_null()) { if (is_null()) {
return MutableSlice(); return MutableSlice();
@ -202,9 +207,17 @@ class BufferSlice {
} }
size_t size() const { size_t size() const {
if (is_null()) {
return 0;
}
return end_ - begin_; return end_ - begin_;
} }
// like in std::string
size_t length() const {
return size();
}
// set end_ into writer's end_ // set end_ into writer's end_
size_t sync_with_writer() { size_t sync_with_writer() {
CHECK(!is_null()); CHECK(!is_null());
@ -216,6 +229,11 @@ class BufferSlice {
CHECK(!is_null()); CHECK(!is_null());
return buffer_->has_writer_.load(std::memory_order_acquire); return buffer_->has_writer_.load(std::memory_order_acquire);
} }
void clear() {
begin_ = 0;
end_ = 0;
buffer_ = nullptr;
}
private: private:
BufferReaderPtr buffer_; BufferReaderPtr buffer_;
@ -241,6 +259,10 @@ class BufferWriter {
BufferWriter(size_t size, size_t prepend, size_t append) BufferWriter(size_t size, size_t prepend, size_t append)
: BufferWriter(BufferAllocator::create_writer(size, prepend, append)) { : BufferWriter(BufferAllocator::create_writer(size, prepend, append)) {
} }
BufferWriter(Slice slice, size_t prepend, size_t append)
: BufferWriter(BufferAllocator::create_writer(slice.size(), prepend, append)) {
as_slice().copy_from(slice);
}
explicit BufferWriter(BufferWriterPtr buffer_ptr) : buffer_(std::move(buffer_ptr)) { explicit BufferWriter(BufferWriterPtr buffer_ptr) : buffer_(std::move(buffer_ptr)) {
} }
@ -616,7 +638,7 @@ class ChainBufferWriter {
} }
// legacy // legacy
static ChainBufferWriter create_empty(size_t size = 0) { static ChainBufferWriter create_empty(size_t /*size*/ = 0) {
return ChainBufferWriter(); return ChainBufferWriter();
} }
@ -705,4 +727,43 @@ class ChainBufferWriter {
BufferWriter writer_; BufferWriter writer_;
}; };
class BufferBuilder {
public:
BufferBuilder() = default;
BufferBuilder(Slice slice, size_t prepend_size, size_t append_size)
: buffer_writer_(slice, prepend_size, append_size) {
}
void append(BufferSlice slice);
void append(Slice slice);
void prepend(BufferSlice slice);
void prepend(Slice slice);
template <class F>
void for_each(F &&f) {
for (auto &slice : reversed(to_prepend_)) {
f(slice);
}
if (!buffer_writer_.empty()) {
f(buffer_writer_.as_buffer_slice());
}
for (auto &slice : to_append_) {
f(slice);
}
}
BufferSlice extract();
private:
BufferWriter buffer_writer_;
std::vector<BufferSlice> to_append_;
std::vector<BufferSlice> to_prepend_;
bool append_inplace(Slice slice);
void append_slow(BufferSlice slice);
bool prepend_inplace(Slice slice);
void prepend_slow(BufferSlice slice);
}; // namespace td
} // namespace td } // namespace td

View File

@ -295,6 +295,22 @@ auto concat(const ArgsT &... args) {
return Concat<decltype(std::tie(args...))>{std::tie(args...)}; return Concat<decltype(std::tie(args...))>{std::tie(args...)};
} }
template <class F>
struct Lambda {
const F &f;
};
template <class F>
StringBuilder &operator<<(StringBuilder &sb, const Lambda<F> &f) {
f.f(sb);
return sb;
}
template <class LambdaT>
Lambda<LambdaT> lambda(const LambdaT &lambda) {
return Lambda<LambdaT>{lambda};
}
} // namespace format } // namespace format
using format::tag; using format::tag;

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "td/utils/port/platform.h" #include "td/utils/port/platform.h"
//#include "td/utils/format.h"
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
@ -54,11 +55,42 @@ inline bool operator==(const UInt<size> &a, const UInt<size> &b) {
return std::memcmp(a.raw, b.raw, sizeof(a.raw)) == 0; return std::memcmp(a.raw, b.raw, sizeof(a.raw)) == 0;
} }
template <size_t size>
inline td::UInt<size> operator^(const UInt<size> &a, const UInt<size> &b) {
td::UInt<size> res;
for (size_t i = 0; i * 8 < size; i++) {
res.raw[i] = a.raw[i] ^ b.raw[i];
}
return res;
}
template <size_t size> template <size_t size>
inline bool operator!=(const UInt<size> &a, const UInt<size> &b) { inline bool operator!=(const UInt<size> &a, const UInt<size> &b) {
return !(a == b); return !(a == b);
} }
template <size_t size>
inline bool is_zero(const UInt<size> &a) {
for (size_t i = 0; i * 8 < size; i++) {
if (a.raw[i]) {
return false;
}
}
return true;
}
template <size_t size>
inline int get_kth_bit(const UInt<size> &a, uint32 bit) {
uint8 b = a.raw[bit / 8];
bit &= 7;
return (b >> (7 - bit)) & 1;
}
template <size_t size>
inline bool operator<(const UInt<size> &a, const UInt<size> &b) {
return memcmp(a.raw, b.raw, sizeof(a.raw)) < 0;
}
using UInt128 = UInt<128>; using UInt128 = UInt<128>;
using UInt256 = UInt<256>; using UInt256 = UInt<256>;

View File

@ -116,18 +116,18 @@ auto invoke(F &&f,
} }
template <class F, class... Args, std::size_t... S> template <class F, class... Args, std::size_t... S>
void call_tuple_impl(F &func, std::tuple<Args...> &&tuple, IntSeq<S...>) { auto call_tuple_impl(F &func, std::tuple<Args...> &&tuple, IntSeq<S...>) {
func(std::forward<Args>(std::get<S>(tuple))...); return func(std::forward<Args>(std::get<S>(tuple))...);
} }
template <class... Args, std::size_t... S> template <class... Args, std::size_t... S>
void invoke_tuple_impl(std::tuple<Args...> &&tuple, IntSeq<S...>) { auto invoke_tuple_impl(std::tuple<Args...> &&tuple, IntSeq<S...>) {
invoke(std::forward<Args>(std::get<S>(tuple))...); return invoke(std::forward<Args>(std::get<S>(tuple))...);
} }
template <class ActorT, class F, class... Args, std::size_t... S> template <class ActorT, class F, class... Args, std::size_t... S>
void mem_call_tuple_impl(ActorT *actor, F &func, std::tuple<Args...> &&tuple, IntSeq<S...>) { auto mem_call_tuple_impl(ActorT *actor, std::tuple<F, Args...> &&tuple, IntSeq<0, S...>) {
(actor->*func)(std::forward<Args>(std::get<S>(tuple))...); return (actor->*std::get<0>(tuple))(std::forward<Args>(std::get<S>(tuple))...);
} }
template <class F, class... Args, std::size_t... S> template <class F, class... Args, std::size_t... S>
@ -151,18 +151,18 @@ class LogicAnd {
}; };
template <class F, class... Args> template <class F, class... Args>
void call_tuple(F &func, std::tuple<Args...> &&tuple) { auto call_tuple(F &func, std::tuple<Args...> &&tuple) {
detail::call_tuple_impl(func, std::move(tuple), detail::IntRange<sizeof...(Args)>()); return detail::call_tuple_impl(func, std::move(tuple), detail::IntRange<sizeof...(Args)>());
} }
template <class... Args> template <class... Args>
void invoke_tuple(std::tuple<Args...> &&tuple) { auto invoke_tuple(std::tuple<Args...> &&tuple) {
detail::invoke_tuple_impl(std::move(tuple), detail::IntRange<sizeof...(Args)>()); return detail::invoke_tuple_impl(std::move(tuple), detail::IntRange<sizeof...(Args)>());
} }
template <class ActorT, class F, class... Args> template <class ActorT, class... Args>
void mem_call_tuple(ActorT *actor, F &func, std::tuple<Args...> &&tuple) { auto mem_call_tuple(ActorT *actor, std::tuple<Args...> &&tuple) {
detail::mem_call_tuple_impl(actor, func, std::move(tuple), detail::IntRange<sizeof...(Args)>()); return detail::mem_call_tuple_impl(actor, std::move(tuple), detail::IntRange<sizeof...(Args)>());
} }
template <class F, class... Args> template <class F, class... Args>
@ -175,4 +175,36 @@ void tuple_for_each(const std::tuple<Args...> &tuple, const F &func) {
detail::tuple_for_each_impl(tuple, func, detail::IntRange<sizeof...(Args)>()); detail::tuple_for_each_impl(tuple, func, detail::IntRange<sizeof...(Args)>());
} }
template <size_t N, class Arg, std::enable_if_t<N == 0> = 0>
auto &&get_nth_argument(Arg &&arg) {
return std::forward<Arg>(arg);
}
template <size_t N, class Arg, class... Args, std::enable_if_t<N == 0, int> = 0>
auto &&get_nth_argument(Arg &&arg, Args &&... args) {
return std::forward<Arg>(arg);
}
template <size_t N, class Arg, class... Args, std::enable_if_t<N != 0, int> = 0>
auto &&get_nth_argument(Arg &&arg, Args &&... args) {
return get_nth_argument<N - 1>(std::forward<Args &&>(args)...);
}
template <class... Args>
auto &&get_last_argument(Args &&... args) {
return get_nth_argument<sizeof...(Args) - 1>(std::forward<Args &&>(args)...);
}
namespace detail {
template <class F, class... Args, std::size_t... S>
auto call_n_arguments_impl(IntSeq<S...>, F &&f, Args &&... args) {
return f(get_nth_argument<S>(std::forward<Args>(args)...)...);
}
} // namespace detail
template <size_t N, class F, class... Args>
auto call_n_arguments(F &&f, Args &&... args) {
return detail::call_n_arguments_impl(detail::IntRange<N>(), f, std::forward<Args>(args)...);
}
} // namespace td } // namespace td

View File

@ -7,7 +7,7 @@
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/port/Clocks.h" #include "td/utils/port/Clocks.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/FileFd.h"
#include "td/utils/port/thread_local.h" #include "td/utils/port/thread_local.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/Time.h" #include "td/utils/Time.h"
@ -27,7 +27,6 @@
namespace td { namespace td {
int VERBOSITY_NAME(level) = VERBOSITY_NAME(DEBUG) + 1;
int VERBOSITY_NAME(net_query) = VERBOSITY_NAME(INFO); int VERBOSITY_NAME(net_query) = VERBOSITY_NAME(INFO);
int VERBOSITY_NAME(td_requests) = VERBOSITY_NAME(INFO); int VERBOSITY_NAME(td_requests) = VERBOSITY_NAME(INFO);
int VERBOSITY_NAME(dc) = VERBOSITY_NAME(DEBUG) + 2; int VERBOSITY_NAME(dc) = VERBOSITY_NAME(DEBUG) + 2;
@ -39,55 +38,79 @@ int VERBOSITY_NAME(actor) = VERBOSITY_NAME(DEBUG) + 10;
int VERBOSITY_NAME(buffer) = VERBOSITY_NAME(DEBUG) + 10; int VERBOSITY_NAME(buffer) = VERBOSITY_NAME(DEBUG) + 10;
int VERBOSITY_NAME(sqlite) = VERBOSITY_NAME(DEBUG) + 10; int VERBOSITY_NAME(sqlite) = VERBOSITY_NAME(DEBUG) + 10;
LogOptions log_options;
TD_THREAD_LOCAL const char *Logger::tag_ = nullptr; TD_THREAD_LOCAL const char *Logger::tag_ = nullptr;
TD_THREAD_LOCAL const char *Logger::tag2_ = nullptr; TD_THREAD_LOCAL const char *Logger::tag2_ = nullptr;
Logger::Logger(LogInterface &log, int log_level, Slice file_name, int line_num, Slice comment, bool simple_mode) Logger::Logger(LogInterface &log, const LogOptions &options, int log_level, Slice file_name, int line_num,
: Logger(log, log_level, simple_mode) { Slice comment)
if (simple_mode) { : Logger(log, options, log_level) {
if (!options_.add_info) {
return; return;
} }
auto last_slash_ = static_cast<int32>(file_name.size()) - 1; // log level
while (last_slash_ >= 0 && file_name[last_slash_] != '/' && file_name[last_slash_] != '\\') {
last_slash_--;
}
file_name = file_name.substr(last_slash_ + 1);
auto thread_id = get_thread_id();
sb_ << '['; sb_ << '[';
if (log_level < 10) { if (log_level < 10) {
sb_ << ' '; sb_ << ' ';
} }
sb_ << log_level << "][t"; sb_ << log_level << "]";
// thread id
auto thread_id = get_thread_id();
sb_ << "[t";
if (thread_id < 10) { if (thread_id < 10) {
sb_ << ' '; sb_ << ' ';
} }
sb_ << thread_id << "][" << StringBuilder::FixedDouble(Clocks::system(), 9) << "][" << file_name << ':' << line_num sb_ << thread_id << "]";
<< ']';
// timestamp
sb_ << "[" << StringBuilder::FixedDouble(Clocks::system(), 9) << "]";
// file : line
if (!file_name.empty()) {
auto last_slash_ = static_cast<int32>(file_name.size()) - 1;
while (last_slash_ >= 0 && file_name[last_slash_] != '/' && file_name[last_slash_] != '\\') {
last_slash_--;
}
file_name = file_name.substr(last_slash_ + 1);
sb_ << "[" << file_name << ':' << line_num << ']';
}
// context from tag_
if (tag_ != nullptr && *tag_) { if (tag_ != nullptr && *tag_) {
sb_ << "[#" << Slice(tag_) << ']'; sb_ << "[#" << Slice(tag_) << ']';
} }
// context from tag2_
if (tag2_ != nullptr && *tag2_) { if (tag2_ != nullptr && *tag2_) {
sb_ << "[!" << Slice(tag2_) << ']'; sb_ << "[!" << Slice(tag2_) << ']';
} }
// comment (e.g. condition in LOG_IF)
if (!comment.empty()) { if (!comment.empty()) {
sb_ << "[&" << comment << ']'; sb_ << "[&" << comment << ']';
} }
sb_ << '\t'; sb_ << '\t';
} }
Logger::~Logger() { Logger::~Logger() {
if (!simple_mode_) { if (options_.fix_newlines) {
sb_ << '\n'; sb_ << '\n';
auto slice = as_cslice(); auto slice = as_cslice();
if (slice.back() != '\n') { if (slice.back() != '\n') {
slice.back() = '\n'; slice.back() = '\n';
} }
while (slice.size() > 1 && slice[slice.size() - 2] == '\n') {
slice.back() = 0;
slice = MutableCSlice(slice.begin(), slice.begin() + slice.size() - 1);
}
log_.append(slice, log_level_);
} else {
log_.append(as_cslice(), log_level_);
} }
log_.append(as_cslice(), log_level_);
} }
TsCerr::TsCerr() { TsCerr::TsCerr() {
@ -96,8 +119,14 @@ TsCerr::TsCerr() {
TsCerr::~TsCerr() { TsCerr::~TsCerr() {
exitCritical(); exitCritical();
} }
namespace {
FileFd &Stderr() {
static FileFd res = FileFd::from_native_fd(NativeFd(2, true)).move_as_ok();
return res;
}
} // namespace
TsCerr &TsCerr::operator<<(Slice slice) { TsCerr &TsCerr::operator<<(Slice slice) {
auto &fd = Fd::Stderr(); auto &fd = Stderr();
if (fd.empty()) { if (fd.empty()) {
return *this; return *this;
} }

View File

@ -30,9 +30,11 @@
#include "td/utils/StringBuilder.h" #include "td/utils/StringBuilder.h"
#include <atomic> #include <atomic>
#include <cstdlib>
#include <iostream>
#include <type_traits> #include <type_traits>
#define PSTR_IMPL() ::td::Logger(::td::NullLog().ref(), 0, true) #define PSTR_IMPL() ::td::Logger(::td::NullLog().ref(), ::td::LogOptions::plain(), 0)
#define PSLICE() ::td::detail::Slicify() & PSTR_IMPL() #define PSLICE() ::td::detail::Slicify() & PSTR_IMPL()
#define PSTRING() ::td::detail::Stringify() & PSTR_IMPL() #define PSTRING() ::td::detail::Stringify() & PSTR_IMPL()
#define PSLICE_SAFE() ::td::detail::SlicifySafe() & PSTR_IMPL() #define PSLICE_SAFE() ::td::detail::SlicifySafe() & PSTR_IMPL()
@ -40,8 +42,8 @@
#define VERBOSITY_NAME(x) verbosity_##x #define VERBOSITY_NAME(x) verbosity_##x
#define GET_VERBOSITY_LEVEL() (::td::VERBOSITY_NAME(level)) #define GET_VERBOSITY_LEVEL() (::td::log_options.level)
#define SET_VERBOSITY_LEVEL(new_level) (::td::VERBOSITY_NAME(level) = (new_level)) #define SET_VERBOSITY_LEVEL(new_level) (::td::log_options.level = (new_level))
#ifndef STRIP_LOG #ifndef STRIP_LOG
#define STRIP_LOG VERBOSITY_NAME(DEBUG) #define STRIP_LOG VERBOSITY_NAME(DEBUG)
@ -49,14 +51,15 @@
#define LOG_IS_STRIPPED(strip_level) \ #define LOG_IS_STRIPPED(strip_level) \
(std::integral_constant<int, VERBOSITY_NAME(strip_level)>() > std::integral_constant<int, STRIP_LOG>()) (std::integral_constant<int, VERBOSITY_NAME(strip_level)>() > std::integral_constant<int, STRIP_LOG>())
#define LOGGER(level, comment) \ #define LOGGER(interface, options, level, comment) ::td::Logger(interface, options, level, __FILE__, __LINE__, comment)
::td::Logger(*::td::log_interface, VERBOSITY_NAME(level), __FILE__, __LINE__, comment, \
VERBOSITY_NAME(level) == VERBOSITY_NAME(PLAIN))
#define LOG_IMPL(strip_level, level, condition, comment) \ #define LOG_IMPL_FULL(interface, options, strip_level, runtime_level, condition, comment) \
LOG_IS_STRIPPED(strip_level) || VERBOSITY_NAME(level) > GET_VERBOSITY_LEVEL() || !(condition) \ LOG_IS_STRIPPED(strip_level) || runtime_level > options.level || !(condition) \
? (void)0 \ ? (void)0 \
: ::td::detail::Voidify() & LOGGER(level, comment) : ::td::detail::Voidify() & LOGGER(interface, options, runtime_level, comment)
#define LOG_IMPL(strip_level, level, condition, comment) \
LOG_IMPL_FULL(*::td::log_interface, ::td::log_options, strip_level, VERBOSITY_NAME(level), condition, comment)
#define LOG(level) LOG_IMPL(level, level, true, ::td::Slice()) #define LOG(level) LOG_IMPL(level, level, true, ::td::Slice())
#define LOG_IF(level, condition) LOG_IMPL(level, level, condition, #condition) #define LOG_IF(level, condition) LOG_IMPL(level, level, condition, #condition)
@ -81,6 +84,7 @@ inline bool no_return_func() {
#ifdef CHECK #ifdef CHECK
#undef CHECK #undef CHECK
#endif #endif
#define DUMMY_CHECK(condition) LOG_IF(NEVER, !(condition))
#ifdef TD_DEBUG #ifdef TD_DEBUG
#if TD_MSVC #if TD_MSVC
#define CHECK(condition) \ #define CHECK(condition) \
@ -90,7 +94,12 @@ inline bool no_return_func() {
#define CHECK(condition) LOG_IMPL(FATAL, FATAL, !(condition) && no_return_func(), #condition) #define CHECK(condition) LOG_IMPL(FATAL, FATAL, !(condition) && no_return_func(), #condition)
#endif #endif
#else #else
#define CHECK(condition) LOG_IF(NEVER, !(condition)) #define CHECK DUMMY_CHECK
#endif
#if NDEBUG
#define DCHECK DUMMY_CHECK
#else
#define DCHECK CHECK
#endif #endif
// clang-format on // clang-format on
@ -107,7 +116,6 @@ constexpr int VERBOSITY_NAME(DEBUG) = 4;
constexpr int VERBOSITY_NAME(NEVER) = 1024; constexpr int VERBOSITY_NAME(NEVER) = 1024;
namespace td { namespace td {
extern int VERBOSITY_NAME(level);
// TODO Not part of utils. Should be in some separate file // TODO Not part of utils. Should be in some separate file
extern int VERBOSITY_NAME(mtproto); extern int VERBOSITY_NAME(mtproto);
extern int VERBOSITY_NAME(raw_mtproto); extern int VERBOSITY_NAME(raw_mtproto);
@ -120,6 +128,18 @@ extern int VERBOSITY_NAME(buffer);
extern int VERBOSITY_NAME(files); extern int VERBOSITY_NAME(files);
extern int VERBOSITY_NAME(sqlite); extern int VERBOSITY_NAME(sqlite);
struct LogOptions {
int level{VERBOSITY_NAME(DEBUG) + 1};
bool fix_newlines{true};
bool add_info{true};
static constexpr LogOptions plain() {
return {0, false, false};
}
};
extern LogOptions log_options;
class LogInterface { class LogInterface {
public: public:
LogInterface() = default; LogInterface() = default;
@ -128,13 +148,19 @@ class LogInterface {
LogInterface(LogInterface &&) = delete; LogInterface(LogInterface &&) = delete;
LogInterface &operator=(LogInterface &&) = delete; LogInterface &operator=(LogInterface &&) = delete;
virtual ~LogInterface() = default; virtual ~LogInterface() = default;
virtual void append(CSlice slice, int log_level_) = 0; virtual void append(CSlice slice) {
virtual void rotate() = 0; append(slice, -1);
}
virtual void append(CSlice slice, int /*log_level_*/) {
append(slice);
}
virtual void rotate() {
}
}; };
class NullLog : public LogInterface { class NullLog : public LogInterface {
public: public:
void append(CSlice slice, int log_level_) override { void append(CSlice /*slice*/, int /*log_level_*/) override {
} }
void rotate() override { void rotate() override {
} }
@ -179,15 +205,15 @@ class TsCerr {
class Logger { class Logger {
public: public:
static const int BUFFER_SIZE = 128 * 1024; static const int BUFFER_SIZE = 128 * 1024;
Logger(LogInterface &log, int log_level, bool simple_mode = false) Logger(LogInterface &log, const LogOptions &options, int log_level)
: buffer_(StackAllocator::alloc(BUFFER_SIZE)) : buffer_(StackAllocator::alloc(BUFFER_SIZE))
, log_(log) , log_(log)
, log_level_(log_level)
, sb_(buffer_.as_slice()) , sb_(buffer_.as_slice())
, simple_mode_(simple_mode) { , options_(options)
, log_level_(log_level) {
} }
Logger(LogInterface &log, int log_level, Slice file_name, int line_num, Slice comment, bool simple_mode); Logger(LogInterface &log, const LogOptions &options, int log_level, Slice file_name, int line_num, Slice comment);
template <class T> template <class T>
Logger &operator<<(const T &other) { Logger &operator<<(const T &other) {
@ -213,9 +239,9 @@ class Logger {
private: private:
decltype(StackAllocator::alloc(0)) buffer_; decltype(StackAllocator::alloc(0)) buffer_;
LogInterface &log_; LogInterface &log_;
int log_level_;
StringBuilder sb_; StringBuilder sb_;
bool simple_mode_; const LogOptions &options_;
int log_level_;
}; };
namespace detail { namespace detail {
@ -273,6 +299,7 @@ class TsLog : public LogInterface {
lock_.clear(std::memory_order_release); lock_.clear(std::memory_order_release);
} }
}; };
} // namespace td } // namespace td
#include "td/utils/Slice.h" #include "td/utils/Slice.h"

View File

@ -284,8 +284,8 @@ string url_encode(Slice str);
namespace detail { namespace detail {
template <class T, class U> template <class T, class U>
struct is_same_signedness struct is_same_signedness : public std::integral_constant<bool, std::is_signed<T>::value == std::is_signed<U>::value> {
: public std::integral_constant<bool, std::is_signed<T>::value == std::is_signed<U>::value> {}; };
template <class T, class Enable = void> template <class T, class Enable = void>
struct safe_undeflying_type { struct safe_undeflying_type {
@ -350,4 +350,24 @@ bool is_aligned_pointer(const T *pointer) {
return (reinterpret_cast<std::uintptr_t>(static_cast<const void *>(pointer)) & (Alignment - 1)) == 0; return (reinterpret_cast<std::uintptr_t>(static_cast<const void *>(pointer)) & (Alignment - 1)) == 0;
} }
template <typename T>
struct reversion_wrapper {
T &iterable;
};
template <typename T>
auto begin(reversion_wrapper<T> w) {
return rbegin(w.iterable);
}
template <typename T>
auto end(reversion_wrapper<T> w) {
return rend(w.iterable);
}
template <typename T>
reversion_wrapper<T> reversed(T &&iterable) {
return {iterable};
}
} // namespace td } // namespace td

View File

@ -20,9 +20,7 @@ struct overload<F> : public F {
}; };
template <class F, class... Fs> template <class F, class... Fs>
struct overload<F, Fs...> struct overload<F, Fs...> : public overload<F>, overload<Fs...> {
: public overload<F>
, overload<Fs...> {
overload(F f, Fs... fs) : overload<F>(f), overload<Fs...>(fs...) { overload(F f, Fs... fs) : overload<F>(f), overload<Fs...>(fs...) {
} }
using overload<F>::operator(); using overload<F>::operator();

View File

@ -7,7 +7,7 @@
#pragma once #pragma once
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/detail/PollableFd.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
@ -23,8 +23,7 @@ class EventFdBase {
virtual void init() = 0; virtual void init() = 0;
virtual bool empty() = 0; virtual bool empty() = 0;
virtual void close() = 0; virtual void close() = 0;
virtual const Fd &get_fd() const = 0; virtual PollableFdInfo &get_poll_info() = 0;
virtual Fd &get_fd() = 0;
virtual Status get_pending_error() TD_WARN_UNUSED_RESULT = 0; virtual Status get_pending_error() TD_WARN_UNUSED_RESULT = 0;
virtual void release() = 0; virtual void release() = 0;
virtual void acquire() = 0; virtual void acquire() = 0;

View File

@ -5,6 +5,7 @@
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
// //
#include "td/utils/port/Fd.h" #include "td/utils/port/Fd.h"
#if 0
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/format.h" #include "td/utils/format.h"
@ -47,6 +48,45 @@ Fd::Info &Fd::InfoSet::get_info(int32 id) {
} }
Fd::InfoSet Fd::fd_info_set_; Fd::InfoSet Fd::fd_info_set_;
bool Fd::FlagsSet::write_flags(Flags flags) {
if (!flags) {
return false;
}
auto old_flags = to_write_.fetch_or(flags, std::memory_order_relaxed);
return (flags & ~old_flags) != 0;
}
bool Fd::FlagsSet::write_flags_local(Flags flags) {
auto old_flags = flags_;
flags_ |= flags;
return flags_ != old_flags;
}
bool Fd::FlagsSet::flush() const {
if (to_write_.load(std::memory_order_relaxed) == 0) {
return false;
}
Flags to_write = to_write_.exchange(0, std::memory_order_relaxed);
auto old_flags = flags_;
flags_ |= to_write;
if (flags_ & Close) {
flags_ &= ~Write;
}
return flags_ != old_flags;
}
Fd::Flags Fd::FlagsSet::read_flags() const {
flush();
return flags_;
}
Fd::Flags Fd::FlagsSet::read_flags_local() const {
return flags_;
}
void Fd::FlagsSet::clear_flags(Flags flags) {
flags_ &= ~flags;
}
void Fd::FlagsSet::clear() {
to_write_ = 0;
flags_ = 0;
}
// TODO(bug) if constuctor call tries to output something to the LOG it will fail, because log is not initialized // TODO(bug) if constuctor call tries to output something to the LOG it will fail, because log is not initialized
Fd Fd::stderr_(2, Mode::Reference); Fd Fd::stderr_(2, Mode::Reference);
Fd Fd::stdout_(1, Mode::Reference); Fd Fd::stdout_(1, Mode::Reference);
@ -70,7 +110,7 @@ Fd::Fd(int fd, Mode mode) : mode_(mode), fd_(fd) {
info->refcnt.store(1, std::memory_order_relaxed); info->refcnt.store(1, std::memory_order_relaxed);
CHECK(mode_ != Mode::Reference); CHECK(mode_ != Mode::Reference);
CHECK(info->observer == nullptr); CHECK(info->observer == nullptr);
info->flags = 0; info->flags.clear();
info->observer = nullptr; info->observer = nullptr;
} else { } else {
CHECK(mode_ == Mode::Reference) << tag("fd", fd_); CHECK(mode_ == Mode::Reference) << tag("fd", fd_);
@ -209,66 +249,62 @@ void Fd::clear_info() {
auto *info = get_info(); auto *info = get_info();
int old_ref_cnt = info->refcnt.load(std::memory_order_relaxed); int old_ref_cnt = info->refcnt.load(std::memory_order_relaxed);
CHECK(old_ref_cnt == 1); CHECK(old_ref_cnt == 1) << old_ref_cnt;
info->flags = 0; info->flags.clear();
info->observer = nullptr; info->observer = nullptr;
info->refcnt.store(0, std::memory_order_release); info->refcnt.store(0, std::memory_order_release);
} }
void Fd::update_flags_notify(Flags flags) { void Fd::after_notify() {
update_flags_inner(flags, true); get_info()->flags.flush();
}
void Fd::update_flags_notify(Flags new_flags) {
auto *info = get_info();
auto &flags = info->flags;
if (!flags.write_flags(new_flags)) {
return;
}
VLOG(fd) << "Add flags " << tag("fd", fd_) << tag("to", format::as_binary(new_flags));
auto observer = info->observer;
if (observer == nullptr) {
return;
}
observer->notify();
} }
void Fd::update_flags(Flags flags) { void Fd::update_flags(Flags flags) {
update_flags_inner(flags, false); get_info()->flags.write_flags_local(flags);
}
void Fd::update_flags_inner(int32 new_flags, bool notify_flag) {
if (new_flags & Error) {
new_flags |= Error;
new_flags |= Close;
}
auto *info = get_info();
int32 &flags = info->flags;
int32 old_flags = flags;
flags |= new_flags;
if (new_flags & Close) {
// TODO: ???
flags &= ~Write;
}
if (flags != old_flags) {
VLOG(fd) << "Update flags " << tag("fd", fd_) << tag("from", format::as_binary(old_flags))
<< tag("to", format::as_binary(flags));
}
if (flags != old_flags && notify_flag) {
auto observer = info->observer;
if (observer != nullptr) {
observer->notify();
}
}
} }
Fd::Flags Fd::get_flags() const { Fd::Flags Fd::get_flags() const {
return get_info()->flags; return get_info()->flags.read_flags();
}
Fd::Flags Fd::get_flags_local() const {
return get_info()->flags.read_flags_local();
} }
void Fd::clear_flags(Flags flags) { void Fd::clear_flags(Flags flags) {
get_info()->flags &= ~flags; get_info()->flags.clear_flags(flags);
} }
bool Fd::has_pending_error() const { bool Fd::has_pending_error() const {
return (get_flags() & Fd::Flag::Error) != 0; return (get_flags() & Fd::Flag::Error) != 0;
} }
bool Fd::has_pending_error_local() const {
return (get_flags_local() & Fd::Flag::Error) != 0;
}
Status Fd::get_pending_error() { Status Fd::get_pending_error() {
if (!has_pending_error()) { if (!has_pending_error()) {
return Status::OK(); return Status::OK();
} }
clear_flags(Fd::Error);
int error = 0; int error = 0;
socklen_t errlen = sizeof(error); socklen_t errlen = sizeof(error);
if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, static_cast<void *>(&error), &errlen) == 0) { if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, static_cast<void *>(&error), &errlen) == 0) {
if (error == 0) { if (error == 0) {
clear_flags(Fd::Error);
return Status::OK(); return Status::OK();
} }
return Status::PosixError(error, PSLICE() << "Error on socket [fd_ = " << fd_ << "]"); return Status::PosixError(error, PSLICE() << "Error on socket [fd_ = " << fd_ << "]");
@ -331,6 +367,9 @@ Result<size_t> Fd::write(Slice slice) {
} }
Result<size_t> Fd::read(MutableSlice slice) { Result<size_t> Fd::read(MutableSlice slice) {
if (has_pending_error()) {
return get_pending_error();
}
int native_fd = get_native_fd(); int native_fd = get_native_fd();
CHECK(slice.size() > 0); CHECK(slice.size() > 0);
auto read_res = skip_eintr([&] { return ::read(native_fd, slice.begin(), slice.size()); }); auto read_res = skip_eintr([&] { return ::read(native_fd, slice.begin(), slice.size()); });
@ -427,35 +466,23 @@ class Fd::FdImpl {
} }
void update_flags_inner(int32 new_flags, bool notify_flag) { void update_flags_inner(int32 new_flags, bool notify_flag) {
if (new_flags & Fd::Error) { if (!flags_.write_flags_local(new_flags)) {
new_flags |= Fd::Error; return;
new_flags |= Fd::Close;
} }
int32 old_flags = flags_; VLOG(fd) << "Add flags " << tag("fd", get_io_handle()) << tag("to", format::as_binary(new_flags));
flags_ |= new_flags; auto observer = observer_;
if (new_flags & Fd::Close) { if (!notify_flag || observer == nullptr) {
// TODO: ??? return;
flags_ &= ~Fd::Write;
internal_flags_ &= ~Fd::Write;
}
if (flags_ != old_flags) {
VLOG(fd) << "Update flags " << tag("fd", get_io_handle()) << tag("from", format::as_binary(old_flags))
<< tag("to", format::as_binary(flags_));
}
if (flags_ != old_flags && notify_flag) {
auto observer = get_observer();
if (observer != nullptr) {
observer->notify();
}
} }
observer->notify();
} }
int32 get_flags() const { int32 get_flags() const {
return flags_; return flags_.read_flags_local();
} }
void clear_flags(Fd::Flags mask) { void clear_flags(Fd::Flags mask) {
flags_ &= ~mask; flags_.clear_flags(mask);
} }
Status get_pending_error() { Status get_pending_error() {
@ -688,7 +715,7 @@ class Fd::FdImpl {
bool async_mode_ = false; bool async_mode_ = false;
ObserverBase *observer_ = nullptr; ObserverBase *observer_ = nullptr;
Fd::Flags flags_ = Fd::Flag::Write; Fd::FlagsSet flags_;
Status pending_error_; Status pending_error_;
Fd::Flags internal_flags_ = Fd::Flag::Write | Fd::Flag::Read; Fd::Flags internal_flags_ = Fd::Flag::Write | Fd::Flag::Read;
@ -712,6 +739,7 @@ class Fd::FdImpl {
ChainBufferReader output_reader_ = output_writer_.extract_reader(); ChainBufferReader output_reader_ = output_writer_.extract_reader();
void init() { void init() {
flags_.write_flags_local(Fd::Write);
if (async_mode_) { if (async_mode_) {
if (type_ != Fd::Type::EventFd) { if (type_ != Fd::Type::EventFd) {
write_event_ = CreateEventW(nullptr, true, false, nullptr); write_event_ = CreateEventW(nullptr, true, false, nullptr);
@ -932,6 +960,9 @@ Fd &Fd::get_fd() {
} }
Result<size_t> Fd::read(MutableSlice slice) { Result<size_t> Fd::read(MutableSlice slice) {
if (has_pending_error()) {
return get_pending_error();
}
return impl_->read(slice); return impl_->read(slice);
} }
@ -1102,3 +1133,4 @@ Status set_native_socket_is_blocking(SOCKET fd, bool is_blocking) {
} // namespace detail } // namespace detail
} // namespace td } // namespace td
#endif

View File

@ -6,15 +6,17 @@
// //
#pragma once #pragma once
#if 0
#include "td/utils/port/config.h" #include "td/utils/port/config.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/format.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
#if TD_PORT_WINDOWS
#include "td/utils/port/IPAddress.h" #include "td/utils/port/IPAddress.h"
#if TD_PORT_WINDOWS
#include <memory> #include <memory>
#endif #endif
@ -49,6 +51,21 @@ class Fd {
None = 0 None = 0
}; };
using Flags = int32; using Flags = int32;
class FlagsSet {
public:
bool write_flags(Flags flags);
bool write_flags_local(Flags flags);
bool flush() const;
Flags read_flags() const;
Flags read_flags_local() const;
void clear_flags(Flags flags);
void clear();
private:
mutable std::atomic<Flags> to_write_;
mutable Flags flags_;
};
enum class Mode { Reference, Owner }; enum class Mode { Reference, Owner };
Fd(); Fd();
@ -91,9 +108,13 @@ class Fd {
void update_flags(Flags flags); void update_flags(Flags flags);
void after_notify();
Flags get_flags() const; Flags get_flags() const;
Flags get_flags_local() const;
bool has_pending_error() const; bool has_pending_error() const;
bool has_pending_error_local() const;
Status get_pending_error() TD_WARN_UNUSED_RESULT; Status get_pending_error() TD_WARN_UNUSED_RESULT;
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT; Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT;
@ -132,7 +153,7 @@ class Fd {
#if TD_PORT_POSIX #if TD_PORT_POSIX
struct Info { struct Info {
std::atomic<int> refcnt; std::atomic<int> refcnt;
int32 flags; FlagsSet flags;
ObserverBase *observer; ObserverBase *observer;
}; };
struct InfoSet { struct InfoSet {
@ -201,7 +222,7 @@ auto skip_eintr_cstr(F &&f) {
template <class FdT> template <class FdT>
bool can_read(const FdT &fd) { bool can_read(const FdT &fd) {
return (fd.get_flags() & Fd::Read) != 0; return (fd.get_flags() & (Fd::Read | Fd::Error)) != 0;
} }
template <class FdT> template <class FdT>
@ -214,13 +235,5 @@ bool can_close(const FdT &fd) {
return (fd.get_flags() & Fd::Close) != 0; return (fd.get_flags() & Fd::Close) != 0;
} }
namespace detail {
#if TD_PORT_POSIX
Status set_native_socket_is_blocking(int fd, bool is_blocking);
#endif
#if TD_PORT_WINDOWS
Status set_native_socket_is_blocking(SOCKET fd, bool is_blocking);
#endif
} // namespace detail
} // namespace td } // namespace td
#endif

View File

@ -18,6 +18,8 @@
#include "td/utils/port/sleep.h" #include "td/utils/port/sleep.h"
#include "td/utils/StringBuilder.h" #include "td/utils/StringBuilder.h"
#include "td/utils/port/detail/PollableFd.h"
#include <cstring> #include <cstring>
#if TD_PORT_POSIX #if TD_PORT_POSIX
@ -77,12 +79,19 @@ StringBuilder &operator<<(StringBuilder &sb, const PrintFlags &print_flags) {
} // namespace } // namespace
const Fd &FileFd::get_fd() const { namespace detail {
return fd_; class FileFdImpl {
} public:
PollableFdInfo info;
};
} // namespace detail
Fd &FileFd::get_fd() { FileFd::FileFd() = default;
return fd_; FileFd::FileFd(FileFd &&) = default;
FileFd &FileFd::operator=(FileFd &&) = default;
FileFd::~FileFd() = default;
FileFd::FileFd(std::unique_ptr<detail::FileFdImpl> impl) : impl_(std::move(impl)) {
} }
Result<FileFd> FileFd::open(CSlice filepath, int32 flags, int32 mode) { Result<FileFd> FileFd::open(CSlice filepath, int32 flags, int32 mode) {
@ -126,8 +135,7 @@ Result<FileFd> FileFd::open(CSlice filepath, int32 flags, int32 mode) {
return OS_ERROR(PSLICE() << "File \"" << filepath << "\" can't be " << PrintFlags{flags}); return OS_ERROR(PSLICE() << "File \"" << filepath << "\" can't be " << PrintFlags{flags});
} }
FileFd result; return from_native_fd(NativeFd(native_fd));
result.fd_ = Fd(native_fd, Fd::Mode::Owner);
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
// TODO: support modes // TODO: support modes
auto r_filepath = to_wstring(filepath); auto r_filepath = to_wstring(filepath);
@ -173,27 +181,30 @@ Result<FileFd> FileFd::open(CSlice filepath, int32 flags, int32 mode) {
if (handle == INVALID_HANDLE_VALUE) { if (handle == INVALID_HANDLE_VALUE) {
return OS_ERROR(PSLICE() << "File \"" << filepath << "\" can't be " << PrintFlags{flags}); return OS_ERROR(PSLICE() << "File \"" << filepath << "\" can't be " << PrintFlags{flags});
} }
auto native_fd = NativeFd(handle);
if (flags & Append) { if (flags & Append) {
LARGE_INTEGER offset; LARGE_INTEGER offset;
offset.QuadPart = 0; offset.QuadPart = 0;
auto set_pointer_res = SetFilePointerEx(handle, offset, nullptr, FILE_END); auto set_pointer_res = SetFilePointerEx(handle, offset, nullptr, FILE_END);
if (!set_pointer_res) { if (!set_pointer_res) {
auto res = OS_ERROR(PSLICE() << "Failed to seek to the end of file \"" << filepath << "\""); auto res = OS_ERROR(PSLICE() << "Failed to seek to the end of file \"" << filepath << "\"");
CloseHandle(handle);
return res; return res;
} }
} }
FileFd result; return from_native_fd(std::move(native_fd));
result.fd_ = Fd::create_file_fd(handle);
#endif #endif
result.fd_.update_flags(Fd::Flag::Write); }
return std::move(result);
Result<FileFd> FileFd::from_native_fd(NativeFd native_fd) {
auto impl = std::make_unique<detail::FileFdImpl>();
impl->info.set_native_fd(std::move(native_fd));
impl->info.add_flags(PollFlags::Write());
return FileFd(std::move(impl));
} }
Result<size_t> FileFd::write(Slice slice) { Result<size_t> FileFd::write(Slice slice) {
#if TD_PORT_POSIX #if TD_PORT_POSIX
CHECK(!fd_.empty()); auto native_fd = get_native_fd().fd();
int native_fd = get_native_fd();
auto write_res = skip_eintr([&] { return ::write(native_fd, slice.begin(), slice.size()); }); auto write_res = skip_eintr([&] { return ::write(native_fd, slice.begin(), slice.size()); });
if (write_res >= 0) { if (write_res >= 0) {
return narrow_cast<size_t>(write_res); return narrow_cast<size_t>(write_res);
@ -210,20 +221,25 @@ Result<size_t> FileFd::write(Slice slice) {
} }
return std::move(error); return std::move(error);
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
return fd_.write(slice); auto native_fd = get_native_fd().io_handle();
DWORD bytes_written = 0;
auto res = WriteFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_written, nullptr);
if (!res) {
return OS_ERROR("Failed to write_sync");
}
return bytes_written;
#endif #endif
} }
Result<size_t> FileFd::read(MutableSlice slice) { Result<size_t> FileFd::read(MutableSlice slice) {
#if TD_PORT_POSIX #if TD_PORT_POSIX
CHECK(!fd_.empty()); auto native_fd = get_native_fd().fd();
int native_fd = get_native_fd();
auto read_res = skip_eintr([&] { return ::read(native_fd, slice.begin(), slice.size()); }); auto read_res = skip_eintr([&] { return ::read(native_fd, slice.begin(), slice.size()); });
auto read_errno = errno; auto read_errno = errno;
if (read_res >= 0) { if (read_res >= 0) {
if (narrow_cast<size_t>(read_res) < slice.size()) { if (narrow_cast<size_t>(read_res) < slice.size()) {
fd_.clear_flags(Read); get_poll_info().clear_flags(PollFlags::Read());
} }
return static_cast<size_t>(read_res); return static_cast<size_t>(read_res);
} }
@ -238,7 +254,16 @@ Result<size_t> FileFd::read(MutableSlice slice) {
} }
return std::move(error); return std::move(error);
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
return fd_.read(slice); auto native_fd = get_native_fd().io_handle();
DWORD bytes_read = 0;
auto res = ReadFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_read, nullptr);
if (!res) {
return OS_ERROR("Failed to read_sync");
}
if (bytes_read == 0) {
get_poll_info().clear_flags(PollFlags::Read());
}
return bytes_read;
#endif #endif
} }
@ -247,9 +272,8 @@ Result<size_t> FileFd::pwrite(Slice slice, int64 offset) {
return Status::Error("Offset must be non-negative"); return Status::Error("Offset must be non-negative");
} }
#if TD_PORT_POSIX #if TD_PORT_POSIX
auto native_fd = get_native_fd().fd();
TRY_RESULT(offset_off_t, narrow_cast_safe<off_t>(offset)); TRY_RESULT(offset_off_t, narrow_cast_safe<off_t>(offset));
CHECK(!fd_.empty());
int native_fd = get_native_fd();
auto pwrite_res = skip_eintr([&] { return ::pwrite(native_fd, slice.begin(), slice.size(), offset_off_t); }); auto pwrite_res = skip_eintr([&] { return ::pwrite(native_fd, slice.begin(), slice.size(), offset_off_t); });
if (pwrite_res >= 0) { if (pwrite_res >= 0) {
return narrow_cast<size_t>(pwrite_res); return narrow_cast<size_t>(pwrite_res);
@ -267,13 +291,13 @@ Result<size_t> FileFd::pwrite(Slice slice, int64 offset) {
} }
return std::move(error); return std::move(error);
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
auto native_fd = get_native_fd().io_handle();
DWORD bytes_written = 0; DWORD bytes_written = 0;
OVERLAPPED overlapped; OVERLAPPED overlapped;
std::memset(&overlapped, 0, sizeof(overlapped)); std::memset(&overlapped, 0, sizeof(overlapped));
overlapped.Offset = static_cast<DWORD>(offset); overlapped.Offset = static_cast<DWORD>(offset);
overlapped.OffsetHigh = static_cast<DWORD>(offset >> 32); overlapped.OffsetHigh = static_cast<DWORD>(offset >> 32);
auto res = auto res = WriteFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_written, &overlapped);
WriteFile(fd_.get_io_handle(), slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_written, &overlapped);
if (!res) { if (!res) {
return OS_ERROR("Failed to pwrite"); return OS_ERROR("Failed to pwrite");
} }
@ -286,9 +310,8 @@ Result<size_t> FileFd::pread(MutableSlice slice, int64 offset) {
return Status::Error("Offset must be non-negative"); return Status::Error("Offset must be non-negative");
} }
#if TD_PORT_POSIX #if TD_PORT_POSIX
auto native_fd = get_native_fd().fd();
TRY_RESULT(offset_off_t, narrow_cast_safe<off_t>(offset)); TRY_RESULT(offset_off_t, narrow_cast_safe<off_t>(offset));
CHECK(!fd_.empty());
int native_fd = get_native_fd();
auto pread_res = skip_eintr([&] { return ::pread(native_fd, slice.begin(), slice.size(), offset_off_t); }); auto pread_res = skip_eintr([&] { return ::pread(native_fd, slice.begin(), slice.size(), offset_off_t); });
if (pread_res >= 0) { if (pread_res >= 0) {
return narrow_cast<size_t>(pread_res); return narrow_cast<size_t>(pread_res);
@ -306,12 +329,13 @@ Result<size_t> FileFd::pread(MutableSlice slice, int64 offset) {
} }
return std::move(error); return std::move(error);
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
auto native_fd = get_native_fd().io_handle();
DWORD bytes_read = 0; DWORD bytes_read = 0;
OVERLAPPED overlapped; OVERLAPPED overlapped;
std::memset(&overlapped, 0, sizeof(overlapped)); std::memset(&overlapped, 0, sizeof(overlapped));
overlapped.Offset = static_cast<DWORD>(offset); overlapped.Offset = static_cast<DWORD>(offset);
overlapped.OffsetHigh = static_cast<DWORD>(offset >> 32); overlapped.OffsetHigh = static_cast<DWORD>(offset >> 32);
auto res = ReadFile(fd_.get_io_handle(), slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_read, &overlapped); auto res = ReadFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_read, &overlapped);
if (!res) { if (!res) {
return OS_ERROR("Failed to pread"); return OS_ERROR("Failed to pread");
} }
@ -323,7 +347,11 @@ Status FileFd::lock(FileFd::LockFlags flags, int32 max_tries) {
if (max_tries <= 0) { if (max_tries <= 0) {
return Status::Error(0, "Can't lock file: wrong max_tries"); return Status::Error(0, "Can't lock file: wrong max_tries");
} }
#if TD_PORT_POSIX
auto native_fd = get_native_fd().fd();
#elif TD_PORT_WINDOWS
auto native_fd = get_native_fd().io_handle();
#endif
while (true) { while (true) {
#if TD_PORT_POSIX #if TD_PORT_POSIX
struct flock lock; struct flock lock;
@ -344,7 +372,7 @@ Status FileFd::lock(FileFd::LockFlags flags, int32 max_tries) {
}()); }());
lock.l_whence = SEEK_SET; lock.l_whence = SEEK_SET;
if (fcntl(get_native_fd(), F_SETLK, &lock) == -1) { if (fcntl(native_fd, F_SETLK, &lock) == -1) {
if (errno == EAGAIN) { if (errno == EAGAIN) {
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
OVERLAPPED overlapped; OVERLAPPED overlapped;
@ -352,14 +380,14 @@ Status FileFd::lock(FileFd::LockFlags flags, int32 max_tries) {
BOOL result; BOOL result;
if (flags == LockFlags::Unlock) { if (flags == LockFlags::Unlock) {
result = UnlockFileEx(fd_.get_io_handle(), 0, MAXDWORD, MAXDWORD, &overlapped); result = UnlockFileEx(native_fd, 0, MAXDWORD, MAXDWORD, &overlapped);
} else { } else {
DWORD dw_flags = LOCKFILE_FAIL_IMMEDIATELY; DWORD dw_flags = LOCKFILE_FAIL_IMMEDIATELY;
if (flags == LockFlags::Write) { if (flags == LockFlags::Write) {
dw_flags |= LOCKFILE_EXCLUSIVE_LOCK; dw_flags |= LOCKFILE_EXCLUSIVE_LOCK;
} }
result = LockFileEx(fd_.get_io_handle(), dw_flags, 0, MAXDWORD, MAXDWORD, &overlapped); result = LockFileEx(native_fd, dw_flags, 0, MAXDWORD, MAXDWORD, &overlapped);
} }
if (!result) { if (!result) {
@ -380,25 +408,21 @@ Status FileFd::lock(FileFd::LockFlags flags, int32 max_tries) {
} }
void FileFd::close() { void FileFd::close() {
fd_.close(); impl_.reset();
} }
bool FileFd::empty() const { bool FileFd::empty() const {
return fd_.empty(); return !impl_;
} }
#if TD_PORT_POSIX const NativeFd &FileFd::get_native_fd() const {
int FileFd::get_native_fd() const { return get_poll_info().native_fd();
return fd_.get_native_fd();
}
#endif
int32 FileFd::get_flags() const {
return fd_.get_flags();
} }
void FileFd::update_flags(Fd::Flags mask) { NativeFd FileFd::move_as_native_fd() {
fd_.update_flags(mask); auto res = get_poll_info().move_as_native_fd();
impl_.reset();
return res;
} }
int64 FileFd::get_size() { int64 FileFd::get_size() {
@ -415,12 +439,13 @@ static uint64 filetime_to_unix_time_nsec(LONGLONG filetime) {
Stat FileFd::stat() { Stat FileFd::stat() {
CHECK(!empty()); CHECK(!empty());
#if TD_PORT_POSIX #if TD_PORT_POSIX
return detail::fstat(get_native_fd()); return detail::fstat(get_native_fd().fd());
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
Stat res; Stat res;
FILE_BASIC_INFO basic_info; FILE_BASIC_INFO basic_info;
auto status = GetFileInformationByHandleEx(fd_.get_io_handle(), FileBasicInfo, &basic_info, sizeof(basic_info)); auto status =
GetFileInformationByHandleEx(get_native_fd().io_handle(), FileBasicInfo, &basic_info, sizeof(basic_info));
if (!status) { if (!status) {
auto error = OS_ERROR("Stat failed"); auto error = OS_ERROR("Stat failed");
LOG(FATAL) << error; LOG(FATAL) << error;
@ -431,7 +456,8 @@ Stat FileFd::stat() {
res.is_reg_ = true; res.is_reg_ = true;
FILE_STANDARD_INFO standard_info; FILE_STANDARD_INFO standard_info;
status = GetFileInformationByHandleEx(fd_.get_io_handle(), FileStandardInfo, &standard_info, sizeof(standard_info)); status = GetFileInformationByHandleEx(get_native_fd().io_handle(), FileStandardInfo, &standard_info,
sizeof(standard_info));
if (!status) { if (!status) {
auto error = OS_ERROR("Stat failed"); auto error = OS_ERROR("Stat failed");
LOG(FATAL) << error; LOG(FATAL) << error;
@ -445,9 +471,9 @@ Stat FileFd::stat() {
Status FileFd::sync() { Status FileFd::sync() {
CHECK(!empty()); CHECK(!empty());
#if TD_PORT_POSIX #if TD_PORT_POSIX
if (fsync(fd_.get_native_fd()) != 0) { if (fsync(get_native_fd().fd()) != 0) {
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
if (FlushFileBuffers(fd_.get_io_handle()) == 0) { if (FlushFileBuffers(get_native_fd().io_handle()) == 0) {
#endif #endif
return OS_ERROR("Sync failed"); return OS_ERROR("Sync failed");
} }
@ -458,11 +484,11 @@ Status FileFd::seek(int64 position) {
CHECK(!empty()); CHECK(!empty());
#if TD_PORT_POSIX #if TD_PORT_POSIX
TRY_RESULT(position_off_t, narrow_cast_safe<off_t>(position)); TRY_RESULT(position_off_t, narrow_cast_safe<off_t>(position));
if (skip_eintr([&] { return ::lseek(fd_.get_native_fd(), position_off_t, SEEK_SET); }) < 0) { if (skip_eintr([&] { return ::lseek(get_native_fd().fd(), position_off_t, SEEK_SET); }) < 0) {
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
LARGE_INTEGER offset; LARGE_INTEGER offset;
offset.QuadPart = position; offset.QuadPart = position;
if (SetFilePointerEx(fd_.get_io_handle(), offset, nullptr, FILE_BEGIN) == 0) { if (SetFilePointerEx(get_native_fd().io_handle(), offset, nullptr, FILE_BEGIN) == 0) {
#endif #endif
return OS_ERROR("Seek failed"); return OS_ERROR("Seek failed");
} }
@ -473,13 +499,19 @@ Status FileFd::truncate_to_current_position(int64 current_position) {
CHECK(!empty()); CHECK(!empty());
#if TD_PORT_POSIX #if TD_PORT_POSIX
TRY_RESULT(current_position_off_t, narrow_cast_safe<off_t>(current_position)); TRY_RESULT(current_position_off_t, narrow_cast_safe<off_t>(current_position));
if (skip_eintr([&] { return ::ftruncate(fd_.get_native_fd(), current_position_off_t); }) < 0) { if (skip_eintr([&] { return ::ftruncate(get_native_fd().fd(), current_position_off_t); }) < 0) {
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
if (SetEndOfFile(fd_.get_io_handle()) == 0) { if (SetEndOfFile(get_native_fd().io_handle()) == 0) {
#endif #endif
return OS_ERROR("Truncate failed"); return OS_ERROR("Truncate failed");
} }
return Status::OK(); return Status::OK();
} }
PollableFdInfo &FileFd::get_poll_info() {
return impl_->info;
}
const PollableFdInfo &FileFd::get_poll_info() const {
return impl_->info;
}
} // namespace td } // namespace td

View File

@ -11,21 +11,28 @@
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/Fd.h"
#include "td/utils/port/Stat.h" #include "td/utils/port/Stat.h"
#include "td/utils/port/detail/PollableFd.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
namespace detail {
class FileFdImpl;
}
class FileFd { class FileFd {
public: public:
FileFd() = default; FileFd();
FileFd(FileFd &&);
FileFd &operator=(FileFd &&);
~FileFd();
FileFd(const FileFd &) = delete;
FileFd &operator=(const FileFd &) = delete;
enum Flags : int32 { Write = 1, Read = 2, Truncate = 4, Create = 8, Append = 16, CreateNew = 32 }; enum Flags : int32 { Write = 1, Read = 2, Truncate = 4, Create = 8, Append = 16, CreateNew = 32 };
const Fd &get_fd() const;
Fd &get_fd();
static Result<FileFd> open(CSlice filepath, int32 flags, int32 mode = 0600) TD_WARN_UNUSED_RESULT; static Result<FileFd> open(CSlice filepath, int32 flags, int32 mode = 0600) TD_WARN_UNUSED_RESULT;
static Result<FileFd> from_native_fd(NativeFd fd) TD_WARN_UNUSED_RESULT;
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT; Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT;
Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT; Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT;
@ -36,12 +43,11 @@ class FileFd {
enum class LockFlags { Write, Read, Unlock }; enum class LockFlags { Write, Read, Unlock };
Status lock(LockFlags flags, int32 max_tries = 1) TD_WARN_UNUSED_RESULT; Status lock(LockFlags flags, int32 max_tries = 1) TD_WARN_UNUSED_RESULT;
PollableFdInfo &get_poll_info();
const PollableFdInfo &get_poll_info() const;
void close(); void close();
bool empty() const; bool empty() const;
int32 get_flags() const;
void update_flags(Fd::Flags mask);
int64 get_size(); int64 get_size();
Stat stat(); Stat stat();
@ -52,12 +58,13 @@ class FileFd {
Status truncate_to_current_position(int64 current_position) TD_WARN_UNUSED_RESULT; Status truncate_to_current_position(int64 current_position) TD_WARN_UNUSED_RESULT;
#if TD_PORT_POSIX const NativeFd &get_native_fd() const;
int get_native_fd() const; NativeFd move_as_native_fd();
#endif
private: private:
Fd fd_; std::unique_ptr<detail::FileFdImpl> impl_;
FileFd(std::unique_ptr<detail::FileFdImpl> impl);
}; };
} // namespace td } // namespace td

View File

@ -184,7 +184,7 @@ const sockaddr *IPAddress::get_sockaddr() const {
size_t IPAddress::get_sockaddr_len() const { size_t IPAddress::get_sockaddr_len() const {
CHECK(is_valid()); CHECK(is_valid());
switch (addr_.ss_family) { switch (sockaddr_.sa_family) {
case AF_INET6: case AF_INET6:
return sizeof(ipv6_addr_); return sizeof(ipv6_addr_);
case AF_INET: case AF_INET:
@ -210,7 +210,7 @@ bool IPAddress::is_ipv6() const {
uint32 IPAddress::get_ipv4() const { uint32 IPAddress::get_ipv4() const {
CHECK(is_valid()); CHECK(is_valid());
CHECK(is_ipv4()); CHECK(is_ipv4());
return ipv4_addr_.sin_addr.s_addr; return ntohl(ipv4_addr_.sin_addr.s_addr);
} }
Slice IPAddress::get_ipv6() const { Slice IPAddress::get_ipv6() const {
@ -356,30 +356,36 @@ Status IPAddress::init_host_port(CSlice host_port) {
return init_host_port(host_port.substr(0, pos).str(), host_port.substr(pos + 1).str()); return init_host_port(host_port.substr(0, pos).str(), host_port.substr(pos + 1).str());
} }
Status IPAddress::init_sockaddr(sockaddr *addr) {
if (addr->sa_family == AF_INET6) {
return init_sockaddr(addr, sizeof(ipv6_addr_));
} else if (addr->sa_family == AF_INET) {
return init_sockaddr(addr, sizeof(ipv4_addr_));
} else {
return init_sockaddr(addr, 0);
}
}
Status IPAddress::init_sockaddr(sockaddr *addr, socklen_t len) { Status IPAddress::init_sockaddr(sockaddr *addr, socklen_t len) {
if (addr->sa_family == AF_INET6) { if (addr->sa_family == AF_INET6) {
CHECK(len == sizeof(ipv6_addr_)); CHECK(len == sizeof(ipv6_addr_));
std::memcpy(&ipv6_addr_, reinterpret_cast<sockaddr_in6 *>(addr), sizeof(ipv6_addr_)); std::memcpy(&ipv6_addr_, reinterpret_cast<sockaddr_in6 *>(addr), sizeof(ipv6_addr_));
LOG(DEBUG) << "Have ipv6 address " << get_ip_str() << " with port " << get_port();
} else if (addr->sa_family == AF_INET) { } else if (addr->sa_family == AF_INET) {
CHECK(len == sizeof(ipv4_addr_)); CHECK(len == sizeof(ipv4_addr_));
std::memcpy(&ipv4_addr_, reinterpret_cast<sockaddr_in *>(addr), sizeof(ipv4_addr_)); std::memcpy(&ipv4_addr_, reinterpret_cast<sockaddr_in *>(addr), sizeof(ipv4_addr_));
LOG(DEBUG) << "Have ipv4 address " << get_ip_str() << " with port " << get_port();
} else { } else {
return Status::Error(PSLICE() << "Unknown " << tag("sa_family", addr->sa_family)); return Status::Error(PSLICE() << "Unknown " << tag("sa_family", addr->sa_family));
} }
is_valid_ = true; is_valid_ = true;
LOG(INFO) << "Have address " << get_ip_str() << " with port " << get_port();
return Status::OK(); return Status::OK();
} }
Status IPAddress::init_socket_address(const SocketFd &socket_fd) { Status IPAddress::init_socket_address(const SocketFd &socket_fd) {
is_valid_ = false; is_valid_ = false;
#if TD_WINDOWS auto fd = socket_fd.get_native_fd().socket();
auto fd = socket_fd.get_fd().get_native_socket(); socklen_t len = storage_size();
#else
auto fd = socket_fd.get_fd().get_native_fd();
#endif
socklen_t len = sizeof(addr_);
int ret = getsockname(fd, &sockaddr_, &len); int ret = getsockname(fd, &sockaddr_, &len);
if (ret != 0) { if (ret != 0) {
return OS_SOCKET_ERROR("Failed to get socket address"); return OS_SOCKET_ERROR("Failed to get socket address");
@ -390,12 +396,8 @@ Status IPAddress::init_socket_address(const SocketFd &socket_fd) {
Status IPAddress::init_peer_address(const SocketFd &socket_fd) { Status IPAddress::init_peer_address(const SocketFd &socket_fd) {
is_valid_ = false; is_valid_ = false;
#if TD_WINDOWS auto fd = socket_fd.get_native_fd().socket();
auto fd = socket_fd.get_fd().get_native_socket(); socklen_t len = storage_size();
#else
auto fd = socket_fd.get_fd().get_native_fd();
#endif
socklen_t len = sizeof(addr_);
int ret = getpeername(fd, &sockaddr_, &len); int ret = getpeername(fd, &sockaddr_, &len);
if (ret != 0) { if (ret != 0) {
return OS_SOCKET_ERROR("Failed to get peer socket address"); return OS_SOCKET_ERROR("Failed to get peer socket address");

View File

@ -39,8 +39,6 @@ class IPAddress {
Slice get_ipv6() const; Slice get_ipv6() const;
Slice get_ip_str() const; Slice get_ip_str() const;
static CSlice ipv4_to_str(int32 ipv4);
IPAddress get_any_addr() const; IPAddress get_any_addr() const;
Status init_ipv6_port(CSlice ipv6, int port) TD_WARN_UNUSED_RESULT; Status init_ipv6_port(CSlice ipv6, int port) TD_WARN_UNUSED_RESULT;
@ -59,21 +57,24 @@ class IPAddress {
const sockaddr *get_sockaddr() const; const sockaddr *get_sockaddr() const;
size_t get_sockaddr_len() const; size_t get_sockaddr_len() const;
int get_address_family() const; int get_address_family() const;
static CSlice ipv4_to_str(int32 ipv4);
Status init_sockaddr(sockaddr *addr);
Status init_sockaddr(sockaddr *addr, socklen_t len) TD_WARN_UNUSED_RESULT;
private: private:
union { union {
sockaddr_storage addr_;
sockaddr sockaddr_; sockaddr sockaddr_;
sockaddr_in ipv4_addr_; sockaddr_in ipv4_addr_;
sockaddr_in6 ipv6_addr_; sockaddr_in6 ipv6_addr_;
}; };
socklen_t storage_size() {
return sizeof(ipv6_addr_);
}
bool is_valid_; bool is_valid_;
Status init_sockaddr(sockaddr *addr, socklen_t len) TD_WARN_UNUSED_RESULT;
void init_ipv4_any(); void init_ipv4_any();
void init_ipv6_any(); void init_ipv6_any();
}; };
StringBuilder &operator<<(StringBuilder &builder, const IPAddress &address); StringBuilder &operator<<(StringBuilder &builder, const IPAddress &address);
} // namespace td } // namespace td

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "td/utils/port/Fd.h" #include "td/utils/port/Fd.h"
#include "td/utils/port/detail/PollableFd.h"
namespace td { namespace td {
class PollBase { class PollBase {
@ -19,9 +20,9 @@ class PollBase {
virtual ~PollBase() = default; virtual ~PollBase() = default;
virtual void init() = 0; virtual void init() = 0;
virtual void clear() = 0; virtual void clear() = 0;
virtual void subscribe(const Fd &fd, Fd::Flags flags) = 0; virtual void subscribe(PollableFd fd, PollFlags flags) = 0;
virtual void unsubscribe(const Fd &fd) = 0; virtual void unsubscribe(PollableFdRef fd) = 0;
virtual void unsubscribe_before_close(const Fd &fd) = 0; virtual void unsubscribe_before_close(PollableFdRef fd) = 0;
virtual void run(int timeout_ms) = 0; virtual void run(int timeout_ms) = 0;
}; };
} // namespace td } // namespace td

View File

@ -23,138 +23,328 @@
#endif #endif
#if TD_PORT_WINDOWS
#include "td/utils/port/detail/WineventPoll.h"
#include "td/utils/SpinLock.h"
#include "td/utils/VectorQueue.h"
#endif
namespace td { namespace td {
Result<ServerSocketFd> ServerSocketFd::open(int32 port, CSlice addr) { namespace detail {
ServerSocketFd socket; #if TD_PORT_WINDOWS
TRY_STATUS(socket.init(port, addr)); class ServerSocketFdImpl : private IOCP::Callback {
return std::move(socket); public:
ServerSocketFdImpl(NativeFd fd, int socket_family) : info(std::move(fd)), socket_family_(socket_family) {
VLOG(fd) << get_native_fd().io_handle() << " create ServerSocketFd";
IOCP::get()->subscribe(get_native_fd(), this);
notify_iocp_read();
}
void close() {
notify_iocp_close();
}
PollableFdInfo &get_poll_info() {
return info;
}
const PollableFdInfo &get_poll_info() const {
return info;
}
const NativeFd &get_native_fd() const {
return info.native_fd();
}
Result<SocketFd> accept() {
auto lock = lock_.lock();
if (accepted_.empty()) {
get_poll_info().clear_flags(PollFlags::Read());
return Status::Error(-1, "Operation would block");
}
return accepted_.pop();
}
Status get_pending_error() {
Status res;
{
auto lock = lock_.lock();
if (!pending_errors_.empty()) {
res = pending_errors_.pop();
}
if (res.is_ok()) {
get_poll_info().clear_flags(PollFlags::Error());
}
}
return res;
}
private:
PollableFdInfo info;
SpinLock lock_;
VectorQueue<SocketFd> accepted_;
VectorQueue<Status> pending_errors_;
static constexpr size_t MAX_ADDR_SIZE = sizeof(sockaddr_in6) + 16;
char addr_buf_[MAX_ADDR_SIZE * 2];
bool close_flag_{false};
std::atomic<int> refcnt_{1};
bool is_read_active_{false};
OVERLAPPED read_overlapped_;
char close_overlapped_;
NativeFd accept_socket_;
int socket_family_;
void on_close() {
close_flag_ = true;
info.set_native_fd({});
}
void on_read() {
VLOG(fd) << get_native_fd().io_handle() << " on_read";
if (is_read_active_) {
is_read_active_ = false;
auto r_socket = SocketFd::from_native_fd(std::move(accept_socket_));
VLOG(fd) << get_native_fd().io_handle() << " finish accept";
if (r_socket.is_error()) {
return on_error(r_socket.move_as_error());
}
{
auto lock = lock_.lock();
accepted_.push(r_socket.move_as_ok());
}
get_poll_info().add_flags_from_poll(PollFlags::Read());
}
loop_read();
}
void loop_read() {
CHECK(!is_read_active_);
accept_socket_ = NativeFd(socket(socket_family_, SOCK_STREAM, 0));
std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
VLOG(fd) << get_native_fd().io_handle() << " start accept";
auto status = AcceptEx(get_native_fd().socket(), accept_socket_.socket(), addr_buf_, 0, MAX_ADDR_SIZE,
MAX_ADDR_SIZE, nullptr, &read_overlapped_);
if (check_status(status, "accent")) {
inc_refcnt();
is_read_active_ = true;
}
}
bool check_status(DWORD status, Slice message) {
if (status == 0) {
return true;
}
auto last_error = GetLastError();
if (last_error == ERROR_IO_PENDING) {
return true;
}
on_error(OS_SOCKET_ERROR(message));
return false;
}
bool dec_refcnt() {
if (--refcnt_ == 0) {
delete this;
return true;
}
return false;
}
void inc_refcnt() {
CHECK(refcnt_ != 0);
refcnt_++;
}
void on_error(Status status) {
{
auto lock = lock_.lock();
pending_errors_.push(std::move(status));
}
get_poll_info().add_flags_from_poll(PollFlags::Error());
}
void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) override {
// called from other thread
if (dec_refcnt() || close_flag_) {
return;
}
if (r_size.is_error()) {
return on_error(r_size.move_as_error());
}
if (overlapped == nullptr) {
return on_read();
}
if (overlapped == &read_overlapped_) {
return on_read();
}
if (overlapped == reinterpret_cast<OVERLAPPED *>(&close_overlapped_)) {
return on_close();
}
UNREACHABLE();
}
void notify_iocp_read() {
VLOG(fd) << get_native_fd().io_handle() << " notify_read";
inc_refcnt();
IOCP::get()->post(0, this, nullptr);
}
void notify_iocp_close() {
VLOG(fd) << get_native_fd().io_handle() << " notify_close";
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&close_overlapped_));
}
};
void ServerSocketFdImplDeleter::operator()(ServerSocketFdImpl *impl) {
impl->close();
}
#elif TD_PORT_POSIX
class ServerSocketFdImpl {
public:
ServerSocketFdImpl(NativeFd fd) : info(std::move(fd)) {
}
PollableFdInfo &get_poll_info() {
return info;
}
const PollableFdInfo &get_poll_info() const {
return info;
}
const NativeFd &get_native_fd() const {
return info.native_fd();
}
Result<SocketFd> accept() {
sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
int native_fd = get_native_fd().fd();
int r_fd = skip_eintr([&] { return ::accept(native_fd, reinterpret_cast<sockaddr *>(&addr), &addr_len); });
auto accept_errno = errno;
if (r_fd >= 0) {
return SocketFd::from_native_fd(NativeFd(r_fd));
}
if (accept_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| accept_errno == EWOULDBLOCK
#endif
) {
get_poll_info().clear_flags(PollFlags::Read());
return Status::Error(-1, "Operation would block");
}
auto error = Status::PosixError(accept_errno, PSLICE() << "Accept from [fd = " << native_fd << "] has failed");
switch (accept_errno) {
case EBADF:
case EFAULT:
case EINVAL:
case ENOTSOCK:
case EOPNOTSUPP:
LOG(FATAL) << error;
UNREACHABLE();
break;
default:
LOG(ERROR) << error;
// fallthrough
case EMFILE:
case ENFILE:
case ECONNABORTED: //???
get_poll_info().clear_flags(PollFlags::Read());
get_poll_info().add_flags(PollFlags::Close());
return std::move(error);
}
}
Status get_pending_error() {
if (get_poll_info().get_flags().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
get_poll_info().clear_flags(PollFlags::Error());
return Status::OK();
}
private:
PollableFdInfo info;
};
void ServerSocketFdImplDeleter::operator()(ServerSocketFdImpl *impl) {
delete impl;
}
#endif
} // namespace detail
ServerSocketFd::ServerSocketFd() = default;
ServerSocketFd::ServerSocketFd(ServerSocketFd &&) = default;
ServerSocketFd &ServerSocketFd::operator=(ServerSocketFd &&) = default;
ServerSocketFd::~ServerSocketFd() = default;
ServerSocketFd::ServerSocketFd(std::unique_ptr<detail::ServerSocketFdImpl> impl) : impl_(impl.release()) {
}
PollableFdInfo &ServerSocketFd::get_poll_info() {
return impl_->get_poll_info();
} }
const Fd &ServerSocketFd::get_fd() const { const PollableFdInfo &ServerSocketFd::get_poll_info() const {
return fd_; return impl_->get_poll_info();
}
Fd &ServerSocketFd::get_fd() {
return fd_;
}
int32 ServerSocketFd::get_flags() const {
return fd_.get_flags();
} }
Status ServerSocketFd::get_pending_error() { Status ServerSocketFd::get_pending_error() {
return fd_.get_pending_error(); return impl_->get_pending_error();
}
const NativeFd &ServerSocketFd::get_native_fd() const {
return impl_->get_native_fd();
} }
Result<SocketFd> ServerSocketFd::accept() { Result<SocketFd> ServerSocketFd::accept() {
#if TD_PORT_POSIX return impl_->accept();
sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
int native_fd = fd_.get_native_fd();
int r_fd = skip_eintr([&] { return ::accept(native_fd, reinterpret_cast<sockaddr *>(&addr), &addr_len); });
auto accept_errno = errno;
if (r_fd >= 0) {
return SocketFd::from_native_fd(r_fd);
}
if (accept_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| accept_errno == EWOULDBLOCK
#endif
) {
fd_.clear_flags(Fd::Read);
return Status::Error(-1, "Operation would block");
}
auto error = Status::PosixError(accept_errno, PSLICE() << "Accept from [fd = " << native_fd << "] has failed");
switch (accept_errno) {
case EBADF:
case EFAULT:
case EINVAL:
case ENOTSOCK:
case EOPNOTSUPP:
LOG(FATAL) << error;
UNREACHABLE();
break;
default:
LOG(ERROR) << error;
// fallthrough
case EMFILE:
case ENFILE:
case ECONNABORTED: //???
fd_.clear_flags(Fd::Read);
fd_.update_flags(Fd::Close);
return std::move(error);
}
#elif TD_PORT_WINDOWS
TRY_RESULT(socket_fd, fd_.accept());
return SocketFd(std::move(socket_fd));
#endif
} }
void ServerSocketFd::close() { void ServerSocketFd::close() {
fd_.close(); impl_.reset();
} }
bool ServerSocketFd::empty() const { bool ServerSocketFd::empty() const {
return fd_.empty(); return !impl_;
} }
Status ServerSocketFd::init(int32 port, CSlice addr) { Result<ServerSocketFd> ServerSocketFd::open(int32 port, CSlice addr) {
IPAddress address; IPAddress address;
TRY_STATUS(address.init_ipv4_port(addr, port)); TRY_STATUS(address.init_ipv4_port(addr, port));
auto fd = socket(address.get_address_family(), SOCK_STREAM, 0); NativeFd fd{socket(address.get_address_family(), SOCK_STREAM, 0)};
#if TD_PORT_POSIX if (!fd) {
if (fd == -1) {
#elif TD_PORT_WINDOWS
if (fd == INVALID_SOCKET) {
#endif
return OS_SOCKET_ERROR("Failed to create a socket"); return OS_SOCKET_ERROR("Failed to create a socket");
} }
auto fd_quard = ScopeExit() + [fd]() {
#if TD_PORT_POSIX
::close(fd);
#elif TD_PORT_WINDOWS
::closesocket(fd);
#endif
};
TRY_STATUS(detail::set_native_socket_is_blocking(fd, false)); TRY_STATUS(detail::set_native_socket_is_blocking(fd, false));
auto sock = fd.socket();
linger ling = {0, 0}; linger ling = {0, 0};
#if TD_PORT_POSIX #if TD_PORT_POSIX
int flags = 1; int flags = 1;
#ifdef SO_REUSEPORT #ifdef SO_REUSEPORT
setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast<const char *>(&flags), sizeof(flags));
#endif #endif
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
BOOL flags = TRUE; BOOL flags = TRUE;
#endif #endif
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(fd, SOL_SOCKET, SO_LINGER, reinterpret_cast<const char *>(&ling), sizeof(ling)); setsockopt(sock, SOL_SOCKET, SO_LINGER, reinterpret_cast<const char *>(&ling), sizeof(ling));
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flags), sizeof(flags));
int e_bind = bind(fd, address.get_sockaddr(), static_cast<socklen_t>(address.get_sockaddr_len())); int e_bind = bind(sock, address.get_sockaddr(), static_cast<socklen_t>(address.get_sockaddr_len()));
if (e_bind != 0) { if (e_bind != 0) {
return OS_SOCKET_ERROR("Failed to bind a socket"); return OS_SOCKET_ERROR("Failed to bind a socket");
} }
// TODO: magic constant // TODO: magic constant
int e_listen = listen(fd, 8192); int e_listen = listen(sock, 8192);
if (e_listen != 0) { if (e_listen != 0) {
return OS_SOCKET_ERROR("Failed to listen on a socket"); return OS_SOCKET_ERROR("Failed to listen on a socket");
} }
#if TD_PORT_POSIX #if TD_PORT_POSIX
fd_ = Fd(fd, Fd::Mode::Owner); auto impl = std::make_unique<detail::ServerSocketFdImpl>(std::move(fd));
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
fd_ = Fd::create_server_socket_fd(fd, address.get_address_family()); auto impl = std::make_unique<detail::ServerSocketFdImpl>(std::move(fd), address.get_address_family());
#endif #endif
fd_quard.dismiss(); return ServerSocketFd(std::move(impl));
return Status::OK();
} }
} // namespace td } // namespace td

View File

@ -13,20 +13,28 @@
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
namespace detail {
class ServerSocketFdImpl;
class ServerSocketFdImplDeleter {
public:
void operator()(ServerSocketFdImpl *impl);
};
} // namespace detail
class ServerSocketFd { class ServerSocketFd {
public: public:
ServerSocketFd() = default; ServerSocketFd();
ServerSocketFd(const ServerSocketFd &) = delete; ServerSocketFd(const ServerSocketFd &) = delete;
ServerSocketFd &operator=(const ServerSocketFd &) = delete; ServerSocketFd &operator=(const ServerSocketFd &) = delete;
ServerSocketFd(ServerSocketFd &&) = default; ServerSocketFd(ServerSocketFd &&);
ServerSocketFd &operator=(ServerSocketFd &&) = default; ServerSocketFd &operator=(ServerSocketFd &&);
~ServerSocketFd();
static Result<ServerSocketFd> open(int32 port, CSlice addr = CSlice("0.0.0.0")) TD_WARN_UNUSED_RESULT; static Result<ServerSocketFd> open(int32 port, CSlice addr = CSlice("0.0.0.0")) TD_WARN_UNUSED_RESULT;
const Fd &get_fd() const; PollableFdInfo &get_poll_info();
Fd &get_fd(); const PollableFdInfo &get_poll_info() const;
int32 get_flags() const;
Status get_pending_error() TD_WARN_UNUSED_RESULT; Status get_pending_error() TD_WARN_UNUSED_RESULT;
Result<SocketFd> accept() TD_WARN_UNUSED_RESULT; Result<SocketFd> accept() TD_WARN_UNUSED_RESULT;
@ -34,10 +42,10 @@ class ServerSocketFd {
void close(); void close();
bool empty() const; bool empty() const;
const NativeFd &get_native_fd() const;
private: private:
Fd fd_; std::unique_ptr<detail::ServerSocketFdImpl, detail::ServerSocketFdImplDeleter> impl_;
explicit ServerSocketFd(std::unique_ptr<detail::ServerSocketFdImpl> impl);
Status init(int32 port, CSlice addr) TD_WARN_UNUSED_RESULT;
}; };
} // namespace td } // namespace td

View File

@ -8,9 +8,7 @@
#include "td/utils/logging.h" #include "td/utils/logging.h"
#if TD_PORT_WINDOWS
#include "td/utils/misc.h" #include "td/utils/misc.h"
#endif
#if TD_PORT_POSIX #if TD_PORT_POSIX
#include <arpa/inet.h> #include <arpa/inet.h>
@ -22,118 +20,541 @@
#include <unistd.h> #include <unistd.h>
#endif #endif
namespace td { #if TD_PORT_WINDOWS
#include "td/utils/buffer.h"
#include "td/utils/port/detail/WineventPoll.h"
#include "td/utils/SpinLock.h"
#include "td/utils/VectorQueue.h"
#endif
Result<SocketFd> SocketFd::open(const IPAddress &address) { namespace td {
SocketFd socket; namespace detail {
TRY_STATUS(socket.init(address)); #if TD_PORT_WINDOWS
return std::move(socket); class SocketFdImpl : private IOCP::Callback {
public:
SocketFdImpl(NativeFd native_fd) : info(std::move(native_fd)) {
VLOG(fd) << get_native_fd().io_handle() << " create from native_fd";
get_poll_info().add_flags(PollFlags::Write());
IOCP::get()->subscribe(get_native_fd(), this);
is_read_active_ = true;
notify_iocp_connected();
}
SocketFdImpl(NativeFd native_fd, const IPAddress &addr) : info(std::move(native_fd)) {
VLOG(fd) << get_native_fd().io_handle() << " create from native_fd and connect";
get_poll_info().add_flags(PollFlags::Write());
IOCP::get()->subscribe(get_native_fd(), this);
LPFN_CONNECTEX ConnectExPtr = nullptr;
GUID guid = WSAID_CONNECTEX;
DWORD numBytes;
auto error =
::WSAIoctl(get_native_fd().socket(), SIO_GET_EXTENSION_FUNCTION_POINTER, static_cast<void *>(&guid),
sizeof(guid), static_cast<void *>(&ConnectExPtr), sizeof(ConnectExPtr), &numBytes, nullptr, nullptr);
if (error) {
on_error(OS_SOCKET_ERROR("WSAIoctl failed"));
return;
}
std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
inc_refcnt();
is_read_active_ = true;
auto status = ConnectExPtr(get_native_fd().socket(), addr.get_sockaddr(), narrow_cast<int>(addr.get_sockaddr_len()),
nullptr, 0, nullptr, &read_overlapped_);
if (!check_status(status, "connect")) {
is_read_active_ = false;
dec_refcnt();
}
}
void close() {
notify_iocp_close();
}
PollableFdInfo &get_poll_info() {
return info;
}
const PollableFdInfo &get_poll_info() const {
return info;
}
const NativeFd &get_native_fd() const {
return info.native_fd();
}
Result<size_t> write(Slice data) {
output_writer_.append(data);
if (is_write_waiting_) {
auto lock = lock_.lock();
is_write_waiting_ = false;
notify_iocp_write();
}
return data.size();
}
Result<size_t> read(MutableSlice slice) {
if (get_poll_info().get_flags().has_pending_error()) {
TRY_STATUS(get_pending_error());
}
input_reader_.sync_with_writer();
auto res = input_reader_.advance(min(slice.size(), input_reader_.size()), slice);
if (res == 0) {
get_poll_info().clear_flags(PollFlags::Read());
}
LOG(ERROR) << "GOT " << res;
return res;
}
Status get_pending_error() {
Status res;
{
auto lock = lock_.lock();
if (!pending_errors_.empty()) {
res = pending_errors_.pop();
}
if (res.is_ok()) {
get_poll_info().clear_flags(PollFlags::Error());
}
}
return res;
}
private:
PollableFdInfo info;
SpinLock lock_;
std::atomic<int> refcnt_{1};
bool close_flag_{false};
bool is_connected_{false};
bool is_read_active_{false};
ChainBufferWriter input_writer_;
ChainBufferReader input_reader_ = input_writer_.extract_reader();
OVERLAPPED read_overlapped_;
VectorQueue<Status> pending_errors_;
bool is_write_active_{false};
std::atomic<bool> is_write_waiting_{false};
ChainBufferWriter output_writer_;
ChainBufferReader output_reader_ = output_writer_.extract_reader();
OVERLAPPED write_overlapped_;
char close_overlapped_;
bool check_status(DWORD status, Slice message) {
if (status == 0) {
return true;
}
auto last_error = GetLastError();
if (last_error == ERROR_IO_PENDING) {
return true;
}
on_error(OS_SOCKET_ERROR(message));
return false;
}
void loop_read() {
CHECK(is_connected_);
CHECK(!is_read_active_);
if (close_flag_) {
return;
}
std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
auto dest = input_writer_.prepare_append();
auto status =
ReadFile(get_native_fd().io_handle(), dest.data(), narrow_cast<DWORD>(dest.size()), nullptr, &read_overlapped_);
if (check_status(status, "read")) {
inc_refcnt();
is_read_active_ = true;
}
}
void loop_write() {
CHECK(is_connected_);
CHECK(!is_write_active_);
output_reader_.sync_with_writer();
auto to_write = output_reader_.prepare_read();
if (to_write.empty()) {
auto lock = lock_.lock();
to_write = output_reader_.prepare_read();
if (to_write.empty()) {
is_write_waiting_ = true;
return;
}
}
if (to_write.empty()) {
return;
}
auto dest = output_reader_.prepare_read();
std::memset(&write_overlapped_, 0, sizeof(write_overlapped_));
auto status = WriteFile(get_native_fd().io_handle(), dest.data(), narrow_cast<DWORD>(dest.size()), nullptr,
&write_overlapped_);
if (check_status(status, "write")) {
inc_refcnt();
is_write_active_ = true;
}
}
void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) override {
// called from other thread
if (dec_refcnt() || close_flag_) {
VLOG(fd) << "ignore iocp (file is closing)";
return;
}
if (r_size.is_error()) {
return on_error(r_size.move_as_error());
}
if (!is_connected_ && overlapped == &read_overlapped_) {
return on_connected();
}
auto size = r_size.move_as_ok();
if (overlapped == &write_overlapped_) {
return on_write(size);
}
if (overlapped == nullptr) {
CHECK(size == 0);
return on_write(size);
}
if (overlapped == &read_overlapped_) {
return on_read(size);
}
if (overlapped == reinterpret_cast<OVERLAPPED *>(&close_overlapped_)) {
return on_close();
}
UNREACHABLE();
}
void on_error(Status status) {
VLOG(fd) << get_native_fd().io_handle() << " "
<< "on error " << status;
{
auto lock = lock_.lock();
pending_errors_.push(std::move(status));
}
get_poll_info().add_flags_from_poll(PollFlags::Error());
}
void on_connected() {
VLOG(fd) << get_native_fd().io_handle() << " on connected";
CHECK(!is_connected_);
CHECK(is_read_active_);
is_connected_ = true;
is_read_active_ = false;
loop_read();
loop_write();
}
void on_read(size_t size) {
VLOG(fd) << get_native_fd().io_handle() << " on read " << size;
CHECK(is_read_active_);
is_read_active_ = false;
input_writer_.confirm_append(size);
get_poll_info().add_flags_from_poll(PollFlags::Read());
loop_read();
}
void on_write(size_t size) {
VLOG(fd) << get_native_fd().io_handle() << " on write " << size;
if (size == 0) {
if (is_write_active_) {
return;
}
is_write_active_ = true;
}
CHECK(is_write_active_);
is_write_active_ = false;
output_reader_.confirm_read(size);
loop_write();
}
void on_close() {
VLOG(fd) << get_native_fd().io_handle() << " on close";
close_flag_ = true;
info.set_native_fd({});
}
bool dec_refcnt() {
if (--refcnt_ == 0) {
LOG(ERROR) << "DELETE";
delete this;
return true;
}
return false;
}
void inc_refcnt() {
CHECK(refcnt_ != 0);
refcnt_++;
}
void notify_iocp_write() {
inc_refcnt();
IOCP::get()->post(0, this, nullptr);
}
void notify_iocp_close() {
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&close_overlapped_));
}
void notify_iocp_connected() {
inc_refcnt();
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&read_overlapped_));
}
};
void SocketFdImplDeleter::operator()(SocketFdImpl *impl) {
impl->close();
}
class InitWSA {
public:
InitWSA() {
/* Use the MAKEWORD(lowbyte, highbyte) macro declared in Windef.h */
WORD wVersionRequested = MAKEWORD(2, 2);
WSADATA wsaData;
if (WSAStartup(wVersionRequested, &wsaData) != 0) {
auto error = OS_SOCKET_ERROR("Failed to init WSA");
LOG(FATAL) << error;
}
}
};
static InitWSA init_wsa;
#else
class SocketFdImpl {
public:
PollableFdInfo info;
SocketFdImpl(NativeFd fd) : info(std::move(fd)) {
}
PollableFdInfo &get_poll_info() {
return info;
}
const PollableFdInfo &get_poll_info() const {
return info;
}
const NativeFd &get_native_fd() const {
return info.native_fd();
}
Result<size_t> write(Slice slice) {
int native_fd = get_native_fd().fd();
auto write_res = skip_eintr([&] { return ::write(native_fd, slice.begin(), slice.size()); });
auto write_errno = errno;
if (write_res >= 0) {
return narrow_cast<size_t>(write_res);
}
if (write_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| write_errno == EWOULDBLOCK
#endif
) {
get_poll_info().clear_flags(PollFlags::Write());
return 0;
}
auto error = Status::PosixError(write_errno, PSLICE() << "Write to fd " << native_fd << " has failed");
switch (write_errno) {
case EBADF:
case ENXIO:
case EFAULT:
case EINVAL:
LOG(FATAL) << error;
UNREACHABLE();
default:
LOG(WARNING) << error;
// fallthrough
case ECONNRESET:
case EDQUOT:
case EFBIG:
case EIO:
case ENETDOWN:
case ENETUNREACH:
case ENOSPC:
case EPIPE:
get_poll_info().clear_flags(PollFlags::Write());
get_poll_info().add_flags(PollFlags::Close());
return std::move(error);
}
}
Result<size_t> read(MutableSlice slice) {
if (get_poll_info().get_flags().has_pending_error()) {
TRY_STATUS(get_pending_error());
}
int native_fd = get_native_fd().fd();
CHECK(slice.size() > 0);
auto read_res = skip_eintr([&] { return ::read(native_fd, slice.begin(), slice.size()); });
auto read_errno = errno;
if (read_res >= 0) {
if (read_res == 0) {
errno = 0;
get_poll_info().clear_flags(PollFlags::Read());
get_poll_info().add_flags(PollFlags::Close());
}
return narrow_cast<size_t>(read_res);
}
if (read_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| read_errno == EWOULDBLOCK
#endif
) {
get_poll_info().clear_flags(PollFlags::Read());
return 0;
}
auto error = Status::PosixError(read_errno, PSLICE() << "Read from fd " << native_fd << " has failed");
switch (read_errno) {
case EISDIR:
case EBADF:
case ENXIO:
case EFAULT:
case EINVAL:
LOG(FATAL) << error;
UNREACHABLE();
default:
LOG(WARNING) << error;
// fallthrough
case ENOTCONN:
case EIO:
case ENOBUFS:
case ENOMEM:
case ECONNRESET:
case ETIMEDOUT:
get_poll_info().clear_flags(PollFlags::Read());
get_poll_info().add_flags(PollFlags::Close());
return std::move(error);
}
}
Status get_pending_error() {
if (get_poll_info().get_flags().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
get_poll_info().clear_flags(PollFlags::Error());
return Status::OK();
}
};
void SocketFdImplDeleter::operator()(SocketFdImpl *impl) {
delete impl;
}
#endif
Status set_native_socket_is_blocking(const NativeFd &fd, bool is_blocking) {
#if TD_PORT_POSIX
if (fcntl(fd.fd(), F_SETFL, is_blocking ? 0 : O_NONBLOCK) == -1) {
#elif TD_PORT_WINDOWS
u_long mode = is_blocking;
if (ioctlsocket(fd.socket(), FIONBIO, &mode) != 0) {
#endif
return OS_SOCKET_ERROR("Failed to change socket flags");
}
return Status::OK();
} }
#if TD_PORT_POSIX #if TD_PORT_POSIX
Result<SocketFd> SocketFd::from_native_fd(int fd) { Status get_socket_pending_error(const NativeFd &fd) {
auto fd_guard = ScopeExit() + [fd]() { ::close(fd); }; int error = 0;
socklen_t errlen = sizeof(error);
if (getsockopt(fd.socket(), SOL_SOCKET, SO_ERROR, static_cast<void *>(&error), &errlen) == 0) {
if (error == 0) {
return Status::OK();
}
return Status::PosixError(error, PSLICE() << "Error on socket [fd_ = " << fd << "]");
}
auto status = OS_SOCKET_ERROR(PSLICE() << "Can't load error on socket [fd_ = " << fd << "]");
LOG(INFO) << "Can't load pending socket error: " << status;
return status;
}
#endif
} // namespace detail
SocketFd::SocketFd() = default;
SocketFd::SocketFd(SocketFd &&) = default;
SocketFd &SocketFd::operator=(SocketFd &&) = default;
SocketFd::~SocketFd() = default;
SocketFd::SocketFd(std::unique_ptr<detail::SocketFdImpl> impl) : impl_(impl.release()) {
}
Result<SocketFd> SocketFd::from_native_fd(NativeFd fd) {
TRY_STATUS(detail::set_native_socket_is_blocking(fd, false)); TRY_STATUS(detail::set_native_socket_is_blocking(fd, false));
auto sock = fd.socket();
// TODO remove copypaste // TODO remove copypaste
#if TD_PORT_POSIX
int flags = 1; int flags = 1;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags)); #elif TD_PORT_WINDOWS
setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags)); BOOL flags = TRUE;
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flags), sizeof(flags)); #endif
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flags), sizeof(flags));
// TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER // TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
fd_guard.dismiss(); return SocketFd(std::make_unique<detail::SocketFdImpl>(std::move(fd)));
SocketFd socket;
socket.fd_ = Fd(fd, Fd::Mode::Owner);
return std::move(socket);
} }
#endif
Status SocketFd::init(const IPAddress &address) { Result<SocketFd> SocketFd::open(const IPAddress &address) {
auto fd = socket(address.get_address_family(), SOCK_STREAM, 0); NativeFd native_fd{socket(address.get_address_family(), SOCK_STREAM, 0)};
#if TD_PORT_POSIX if (!native_fd) {
if (fd == -1) {
#elif TD_PORT_WINDOWS
if (fd == INVALID_SOCKET) {
#endif
return OS_SOCKET_ERROR("Failed to create a socket"); return OS_SOCKET_ERROR("Failed to create a socket");
} }
auto fd_quard = ScopeExit() + [fd]() { auto sock = native_fd.socket();
#if TD_PORT_POSIX
::close(fd);
#elif TD_PORT_WINDOWS
::closesocket(fd);
#endif
};
TRY_STATUS(detail::set_native_socket_is_blocking(fd, false)); TRY_STATUS(detail::set_native_socket_is_blocking(native_fd, false));
#if TD_PORT_POSIX #if TD_PORT_POSIX
int flags = 1; int flags = 1;
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
BOOL flags = TRUE; BOOL flags = TRUE;
#endif #endif
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flags), sizeof(flags)); setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char *>(&flags), sizeof(flags));
// TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER // TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
#if TD_PORT_POSIX #if TD_PORT_POSIX
int e_connect = connect(fd, address.get_sockaddr(), static_cast<socklen_t>(address.get_sockaddr_len())); int e_connect = connect(sock, address.get_sockaddr(), static_cast<socklen_t>(address.get_sockaddr_len()));
if (e_connect == -1) { if (e_connect == -1) {
auto connect_errno = errno; auto connect_errno = errno;
if (connect_errno != EINPROGRESS) { if (connect_errno != EINPROGRESS) {
return Status::PosixError(connect_errno, PSLICE() << "Failed to connect to " << address); return Status::PosixError(connect_errno, PSLICE() << "Failed to connect to " << address);
} }
} }
fd_ = Fd(fd, Fd::Mode::Owner); return SocketFd(std::make_unique<detail::SocketFdImpl>(std::move(native_fd)));
#elif TD_PORT_WINDOWS #elif TD_PORT_WINDOWS
auto bind_addr = address.get_any_addr(); auto bind_addr = address.get_any_addr();
auto e_bind = bind(fd, bind_addr.get_sockaddr(), narrow_cast<int>(bind_addr.get_sockaddr_len())); auto e_bind = bind(sock, bind_addr.get_sockaddr(), narrow_cast<int>(bind_addr.get_sockaddr_len()));
if (e_bind != 0) { if (e_bind != 0) {
return OS_SOCKET_ERROR("Failed to bind a socket"); return OS_SOCKET_ERROR("Failed to bind a socket");
} }
return SocketFd(std::make_unique<detail::SocketFdImpl>(std::move(native_fd), address));
fd_ = Fd::create_socket_fd(fd);
fd_.connect(address);
#endif #endif
fd_quard.dismiss();
return Status::OK();
}
const Fd &SocketFd::get_fd() const {
return fd_;
}
Fd &SocketFd::get_fd() {
return fd_;
} }
void SocketFd::close() { void SocketFd::close() {
fd_.close(); impl_.reset();
} }
bool SocketFd::empty() const { bool SocketFd::empty() const {
return fd_.empty(); return !impl_;
} }
int32 SocketFd::get_flags() const { PollableFdInfo &SocketFd::get_poll_info() {
return fd_.get_flags(); return impl_->get_poll_info();
}
const PollableFdInfo &SocketFd::get_poll_info() const {
return impl_->get_poll_info();
}
const NativeFd &SocketFd::get_native_fd() const {
return impl_->get_native_fd();
} }
Status SocketFd::get_pending_error() { Status SocketFd::get_pending_error() {
return fd_.get_pending_error(); return impl_->get_pending_error();
} }
Result<size_t> SocketFd::write(Slice slice) { Result<size_t> SocketFd::write(Slice slice) {
return fd_.write(slice); return impl_->write(slice);
} }
Result<size_t> SocketFd::read(MutableSlice slice) { Result<size_t> SocketFd::read(MutableSlice slice) {
return fd_.read(slice); return impl_->read(slice);
} }
} // namespace td } // namespace td

View File

@ -8,50 +8,65 @@
#include "td/utils/port/config.h" #include "td/utils/port/config.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/detail/PollableFd.h"
#include "td/utils/port/IPAddress.h" #include "td/utils/port/IPAddress.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
namespace detail {
class SocketFdImpl;
class SocketFdImplDeleter {
public:
void operator()(SocketFdImpl *impl);
};
class EventFdBsd;
} // namespace detail
class SocketFd { class SocketFd {
public: public:
SocketFd() = default; SocketFd();
SocketFd(const SocketFd &) = delete; SocketFd(const SocketFd &) = delete;
SocketFd &operator=(const SocketFd &) = delete; SocketFd &operator=(const SocketFd &) = delete;
SocketFd(SocketFd &&) = default; SocketFd(SocketFd &&);
SocketFd &operator=(SocketFd &&) = default; SocketFd &operator=(SocketFd &&);
~SocketFd();
static Result<SocketFd> open(const IPAddress &address) TD_WARN_UNUSED_RESULT; static Result<SocketFd> open(const IPAddress &address) TD_WARN_UNUSED_RESULT;
const Fd &get_fd() const; PollableFdInfo &get_poll_info();
Fd &get_fd(); const PollableFdInfo &get_poll_info() const;
int32 get_flags() const;
Status get_pending_error() TD_WARN_UNUSED_RESULT; Status get_pending_error() TD_WARN_UNUSED_RESULT;
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT; Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT;
Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT; Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT;
Result<size_t> write_message(Slice slice) TD_WARN_UNUSED_RESULT;
Result<size_t> read_message(MutableSlice slice) TD_WARN_UNUSED_RESULT;
const NativeFd &get_native_fd() const;
static Result<SocketFd> from_native_fd(NativeFd fd);
void close(); void close();
bool empty() const; bool empty() const;
private: private:
Fd fd_; std::unique_ptr<detail::SocketFdImpl, detail::SocketFdImplDeleter> impl_;
PollableFdInfo &poll_info();
friend class ServerSocketFd; friend class ServerSocketFd;
friend class detail::EventFdBsd;
Status init(const IPAddress &address) TD_WARN_UNUSED_RESULT; explicit SocketFd(std::unique_ptr<detail::SocketFdImpl> impl);
#if TD_PORT_POSIX
static Result<SocketFd> from_native_fd(int fd);
#endif
#if TD_PORT_WINDOWS
explicit SocketFd(Fd fd) : fd_(std::move(fd)) {
}
#endif
}; };
namespace detail {
Status set_native_socket_is_blocking(const NativeFd &fd, bool is_blocking);
#if TD_PORT_POSIX
Status get_socket_pending_error(const NativeFd &fd);
#endif
} // namespace detail
} // namespace td } // namespace td

View File

@ -170,7 +170,7 @@ Status update_atime(CSlice path) {
SCOPE_EXIT { SCOPE_EXIT {
file.close(); file.close();
}; };
return detail::update_atime(file.get_native_fd()); return detail::update_atime(file.get_native_fd().fd());
} }
Result<Stat> stat(CSlice path) { Result<Stat> stat(CSlice path) {

View File

@ -0,0 +1,846 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/port/UdpSocketFd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/logging.h"
#include "td/utils/format.h"
#include "td/utils/misc.h"
#if TD_PORT_POSIX
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#if TD_LINUX
#include <linux/errqueue.h>
#endif
#endif // TD_PORT_POSIX
#if TD_PORT_WINDOWS
#include <Mswsock.h>
#include "td/utils/port/Poll.h"
#include "td/utils/VectorQueue.h"
#endif
namespace td {
namespace detail {
#if TD_PORT_WINDOWS
class UdpSocketReceiveHelper {
public:
void to_native(const UdpMessage &message, WSAMSG &message_header) {
socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
message_header.name = reinterpret_cast<struct sockaddr *>(&addr_);
message_header.namelen = addr_len;
buf_.buf = const_cast<char *>(message.data.as_slice().begin());
buf_.len = narrow_cast<DWORD>(message.data.size());
message_header.lpBuffers = &buf_;
message_header.dwBufferCount = 1;
message_header.Control.buf = nullptr; // control_buf_.data();
message_header.Control.len = 0; // narrow_cast<decltype(message_header.Control.len)>(control_buf_.size());
message_header.dwFlags = 0;
}
void from_native(WSAMSG &message_header, size_t message_size, UdpMessage &message) {
message.address.init_sockaddr(reinterpret_cast<struct sockaddr *>(message_header.name), message_header.namelen)
.ignore();
message.error = Status::OK();
if ((message_header.dwFlags & (MSG_TRUNC | MSG_CTRUNC)) != 0) {
message.error = Status::Error(501, "message too long");
message.data = BufferSlice();
return;
}
CHECK(message_size <= message.data.size());
message.data.truncate(message_size);
CHECK(message_size == message.data.size());
if (message_size >= 1500) {
LOG(ERROR) << "received datagram of size " << message_size;
}
}
private:
std::array<char, 1024> control_buf_;
sockaddr_storage addr_;
WSABUF buf_;
};
class UdpSocketSendHelper {
public:
void to_native(const UdpMessage &message, WSAMSG &message_header) {
message_header.name = const_cast<struct sockaddr *>(message.address.get_sockaddr());
message_header.namelen = narrow_cast<socklen_t>(message.address.get_sockaddr_len());
buf_.buf = const_cast<char *>(message.data.as_slice().begin());
buf_.len = narrow_cast<DWORD>(message.data.size());
message_header.lpBuffers = &buf_;
message_header.dwBufferCount = 1;
message_header.Control.buf = nullptr;
message_header.Control.len = 0;
message_header.dwFlags = 0;
}
private:
WSABUF buf_;
};
class UdpSocketFdImpl : private IOCP::Callback {
public:
UdpSocketFdImpl(NativeFd fd) : info(std::move(fd)) {
get_poll_info().add_flags(PollFlags::Write());
IOCP::get()->subscribe(get_native_fd(), this);
is_receive_active_ = true;
notify_iocp_connected();
}
PollableFdInfo &get_poll_info() {
return info;
}
const PollableFdInfo &get_poll_info() const {
return info;
}
const NativeFd &get_native_fd() const {
return info.native_fd();
}
void close() {
notify_iocp_close();
}
Result<optional<UdpMessage>> receive() {
auto lock = lock_.lock();
if (!pending_errors_.empty()) {
auto status = pending_errors_.pop();
if (!UdpSocketFd::is_critical_read_error(status)) {
return UdpMessage{{}, {}, std::move(status)};
}
return std::move(status);
}
if (!receive_queue_.empty()) {
return receive_queue_.pop();
}
return optional<UdpMessage>{};
}
void send(UdpMessage message) {
auto lock = lock_.lock();
send_queue_.push(std::move(message));
}
Status flush_send() {
if (is_send_waiting_) {
auto lock = lock_.lock();
is_send_waiting_ = false;
notify_iocp_send();
}
return Status::OK();
}
private:
PollableFdInfo info;
SpinLock lock_;
std::atomic<int> refcnt_{1};
bool is_connected_{false};
bool close_flag_{false};
bool is_send_active_{false};
bool is_send_waiting_{false};
VectorQueue<UdpMessage> send_queue_;
OVERLAPPED send_overlapped_;
bool is_receive_active_{false};
VectorQueue<UdpMessage> receive_queue_;
VectorQueue<Status> pending_errors_;
UdpMessage to_receive_;
WSAMSG receive_message_;
UdpSocketReceiveHelper receive_helper_;
enum : size_t { MAX_PACKET_SIZE = 2048, RESERVED_SIZE = MAX_PACKET_SIZE * 8 };
BufferSlice receive_buffer_;
UdpMessage to_send_;
OVERLAPPED receive_overlapped_;
char close_overlapped_;
bool check_status(DWORD status, Slice message) {
if (status == 0) {
return true;
}
auto last_error = GetLastError();
if (last_error == ERROR_IO_PENDING) {
return true;
}
on_error(OS_SOCKET_ERROR(message));
return false;
}
void loop_receive() {
CHECK(!is_receive_active_);
if (close_flag_) {
return;
}
std::memset(&receive_overlapped_, 0, sizeof(receive_overlapped_));
if (receive_buffer_.size() < MAX_PACKET_SIZE) {
receive_buffer_ = BufferSlice(RESERVED_SIZE);
}
to_receive_.data = receive_buffer_.clone();
receive_helper_.to_native(to_receive_, receive_message_);
LPFN_WSARECVMSG WSARecvMsgPtr = nullptr;
GUID guid = WSAID_WSARECVMSG;
DWORD numBytes;
auto error = ::WSAIoctl(get_native_fd().socket(), SIO_GET_EXTENSION_FUNCTION_POINTER, static_cast<void *>(&guid),
sizeof(guid), static_cast<void *>(&WSARecvMsgPtr), sizeof(WSARecvMsgPtr), &numBytes,
nullptr, nullptr);
if (error) {
on_error(OS_SOCKET_ERROR("WSAIoctl failed"));
return;
}
auto status = WSARecvMsgPtr(get_native_fd().socket(), &receive_message_, nullptr, &receive_overlapped_, nullptr);
if (check_status(status, "receive")) {
inc_refcnt();
is_receive_active_ = true;
}
}
void loop_send() {
CHECK(!is_send_active_);
{
auto lock = lock_.lock();
if (send_queue_.empty()) {
is_send_waiting_ = true;
return;
}
to_send_ = send_queue_.pop();
}
std::memset(&send_overlapped_, 0, sizeof(send_overlapped_));
WSAMSG message;
UdpSocketSendHelper send_helper;
send_helper.to_native(to_send_, message);
auto status = WSASendMsg(get_native_fd().socket(), &message, 0, nullptr, &send_overlapped_, nullptr);
if (check_status(status, "send")) {
inc_refcnt();
is_send_active_ = true;
}
}
void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) override {
// called from other thread
if (dec_refcnt() || close_flag_) {
VLOG(fd) << "ignore iocp (file is closing)";
return;
}
if (r_size.is_error()) {
return on_error(r_size.move_as_error());
}
if (!is_connected_ && overlapped == &receive_overlapped_) {
return on_connected();
}
auto size = r_size.move_as_ok();
if (overlapped == &send_overlapped_) {
return on_send(size);
}
if (overlapped == nullptr) {
CHECK(size == 0);
return on_send(size);
}
if (overlapped == &receive_overlapped_) {
return on_receive(size);
}
if (overlapped == reinterpret_cast<OVERLAPPED *>(&close_overlapped_)) {
return on_close();
}
UNREACHABLE();
}
void on_error(Status status) {
VLOG(fd) << get_native_fd().io_handle() << " "
<< "on error " << status;
{
auto lock = lock_.lock();
pending_errors_.push(std::move(status));
}
get_poll_info().add_flags_from_poll(PollFlags::Error());
}
void on_connected() {
VLOG(fd) << get_native_fd().io_handle() << " on connected";
CHECK(!is_connected_);
CHECK(is_receive_active_);
is_connected_ = true;
is_receive_active_ = false;
loop_receive();
loop_send();
}
void on_receive(size_t size) {
VLOG(fd) << get_native_fd().io_handle() << " on receive " << size;
CHECK(is_receive_active_);
is_receive_active_ = false;
receive_helper_.from_native(receive_message_, size, to_receive_);
receive_buffer_.confirm_read((to_receive_.data.size() + 7) & ~7);
{
auto lock = lock_.lock();
LOG(ERROR) << format::escaped(to_receive_.data.as_slice());
receive_queue_.push(std::move(to_receive_));
}
get_poll_info().add_flags_from_poll(PollFlags::Read());
loop_receive();
}
void on_send(size_t size) {
VLOG(fd) << get_native_fd().io_handle() << " on send " << size;
if (size == 0) {
if (is_send_active_) {
return;
}
is_send_active_ = true;
}
CHECK(is_send_active_);
is_send_active_ = false;
loop_send();
}
void on_close() {
VLOG(fd) << get_native_fd().io_handle() << " on close";
close_flag_ = true;
info.set_native_fd({});
}
bool dec_refcnt() {
if (--refcnt_ == 0) {
LOG(ERROR) << "DELETE";
delete this;
return true;
}
return false;
}
void inc_refcnt() {
CHECK(refcnt_ != 0);
refcnt_++;
}
void notify_iocp_send() {
inc_refcnt();
IOCP::get()->post(0, this, nullptr);
}
void notify_iocp_close() {
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&close_overlapped_));
}
void notify_iocp_connected() {
inc_refcnt();
IOCP::get()->post(0, this, reinterpret_cast<OVERLAPPED *>(&receive_overlapped_));
}
};
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
impl->close();
}
#elif TD_PORT_POSIX
//struct iovec { [> Scatter/gather array items <]
// void *iov_base; [> Starting address <]
// size_t iov_len; [> Number of bytes to transfer <]
//};
//struct msghdr {
// void *msg_name; [> optional address <]
// socklen_t msg_namelen; [> size of address <]
// struct iovec *msg_iov; [> scatter/gather array <]
// size_t msg_iovlen; [> # elements in msg_iov <]
// void *msg_control; [> ancillary data, see below <]
// size_t msg_controllen; [> ancillary data buffer len <]
// int msg_flags; [> flags on received message <]
//};
class UdpSocketReceiveHelper {
public:
void to_native(const UdpSocketFd::InboundMessage &message, struct msghdr &message_header) {
socklen_t addr_len{narrow_cast<socklen_t>(sizeof(addr_))};
message_header.msg_name = &addr_;
message_header.msg_namelen = addr_len;
io_vec_.iov_base = message.data.begin();
io_vec_.iov_len = message.data.size();
message_header.msg_iov = &io_vec_;
message_header.msg_iovlen = 1;
message_header.msg_control = control_buf_.data();
message_header.msg_controllen = narrow_cast<decltype(message_header.msg_controllen)>(control_buf_.size());
message_header.msg_flags = 0;
}
void from_native(struct msghdr &message_header, size_t message_size, UdpSocketFd::InboundMessage &message) {
#if TD_LINUX
struct cmsghdr *cmsg;
struct sock_extended_err *ee = nullptr;
for (cmsg = CMSG_FIRSTHDR(&message_header); cmsg != nullptr; cmsg = CMSG_NXTHDR(&message_header, cmsg)) {
if (cmsg->cmsg_type == IP_PKTINFO && cmsg->cmsg_level == IPPROTO_IP) {
//auto *pi = reinterpret_cast<struct in_pktinfo *>(CMSG_DATA(cmsg));
} else if (cmsg->cmsg_type == IPV6_PKTINFO && cmsg->cmsg_level == IPPROTO_IPV6) {
//auto *pi = reinterpret_cast<struct in6_pktinfo *>(CMSG_DATA(cmsg));
} else if ((cmsg->cmsg_type == IP_RECVERR && cmsg->cmsg_level == IPPROTO_IP) ||
(cmsg->cmsg_type == IPV6_RECVERR && cmsg->cmsg_level == IPPROTO_IPV6)) {
ee = reinterpret_cast<struct sock_extended_err *>(CMSG_DATA(cmsg));
}
}
if (ee != nullptr) {
auto *addr = reinterpret_cast<struct sockaddr *>(SO_EE_OFFENDER(ee));
IPAddress address;
address.init_sockaddr(addr).ignore();
if (message.from != nullptr) {
*message.from = address;
}
if (message.error) {
*message.error = Status::PosixError(ee->ee_errno, "");
}
//message.data = MutableSlice();
message.data.truncate(0);
return;
}
#endif
if (message.from != nullptr) {
message.from
->init_sockaddr(reinterpret_cast<struct sockaddr *>(message_header.msg_name), message_header.msg_namelen)
.ignore();
}
if (message.error) {
*message.error = Status::OK();
}
if (message_header.msg_flags & MSG_TRUNC) {
if (message.error) {
*message.error = Status::Error(501, "message too long");
}
message.data.truncate(0);
return;
}
CHECK(message_size <= message.data.size());
message.data.truncate(message_size);
CHECK(message_size == message.data.size());
if (message_size >= 1500) {
LOG(ERROR) << "received datagram of size " << message_size;
}
}
private:
std::array<char, 1024> control_buf_;
sockaddr_storage addr_;
struct iovec io_vec_;
};
class UdpSocketSendHelper {
public:
void to_native(const UdpSocketFd::OutboundMessage &message, struct msghdr &message_header) {
CHECK(message.to != nullptr && message.to->is_valid());
message_header.msg_name = const_cast<struct sockaddr *>(message.to->get_sockaddr());
message_header.msg_namelen = narrow_cast<socklen_t>(message.to->get_sockaddr_len());
io_vec_.iov_base = const_cast<char *>(message.data.begin());
io_vec_.iov_len = message.data.size();
message_header.msg_iov = &io_vec_;
message_header.msg_iovlen = 1;
//TODO
message_header.msg_control = nullptr;
message_header.msg_controllen = 0;
message_header.msg_flags = 0;
}
private:
struct iovec io_vec_;
};
class UdpSocketFdImpl {
public:
UdpSocketFdImpl(NativeFd fd) : info(std::move(fd)) {
}
PollableFdInfo &get_poll_info() {
return info;
}
const PollableFdInfo &get_poll_info() const {
return info;
}
const NativeFd &get_native_fd() const {
return info.native_fd();
}
Status get_pending_error() {
if (get_poll_info().get_flags().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
get_poll_info().clear_flags(PollFlags::Error());
return Status::OK();
}
Status receive_message(UdpSocketFd::InboundMessage &message, bool &is_received) {
is_received = false;
int flags = 0;
if (get_poll_info().get_flags().has_pending_error()) {
#ifdef MSG_ERRQUEUE
flags = MSG_ERRQUEUE;
#else
return get_pending_error();
#endif
}
struct msghdr message_header;
detail::UdpSocketReceiveHelper helper;
helper.to_native(message, message_header);
auto native_fd = get_native_fd().fd();
auto recvmsg_res = skip_eintr([&] { return recvmsg(native_fd, &message_header, flags); });
auto recvmsg_errno = errno;
if (recvmsg_res >= 0) {
helper.from_native(message_header, recvmsg_res, message);
is_received = true;
return Status::OK();
}
return process_recvmsg_error(recvmsg_errno, is_received);
}
Status process_recvmsg_error(int recvmsg_errno, bool &is_received) {
is_received = false;
if (recvmsg_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| recvmsg_errno == EWOULDBLOCK
#endif
) {
if (get_poll_info().get_flags_local().has_pending_error()) {
get_poll_info().clear_flags(PollFlags::Error());
} else {
get_poll_info().clear_flags(PollFlags::Read());
}
return Status::OK();
}
auto error =
Status::PosixError(recvmsg_errno, PSLICE() << "Receive from [fd=" << get_native_fd() << "] has failed");
switch (recvmsg_errno) {
case EBADF:
case EFAULT:
case EINVAL:
case ENOTCONN:
case ECONNRESET:
case ETIMEDOUT:
LOG(FATAL) << error;
UNREACHABLE();
default:
LOG(WARNING) << "Unknown error: " << error;
// fallthrough
case ENOBUFS:
case ENOMEM:
#ifdef MSG_ERRQUEUE
get_poll_info().add_flags(PollFlags::Error());
#endif
return error;
}
}
Status send_message(const UdpSocketFd::OutboundMessage &message, bool &is_sent) {
is_sent = false;
struct msghdr message_header;
detail::UdpSocketSendHelper helper;
helper.to_native(message, message_header);
auto native_fd = get_native_fd().fd();
auto sendmsg_res = skip_eintr([&] { return sendmsg(native_fd, &message_header, 0); });
auto sendmsg_errno = errno;
if (sendmsg_res >= 0) {
is_sent = true;
return Status::OK();
}
return process_sendmsg_error(sendmsg_errno, is_sent);
}
Status process_sendmsg_error(int sendmsg_errno, bool &is_sent) {
if (sendmsg_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| sendmsg_errno == EWOULDBLOCK
#endif
) {
get_poll_info().clear_flags(PollFlags::Write());
return Status::OK();
}
auto error = Status::PosixError(sendmsg_errno, PSLICE() << "Send from [fd=" << get_native_fd() << "] has failed");
switch (sendmsg_errno) {
// Still may send some other packets, but there is no point to resend this particular message
case EACCES:
case EMSGSIZE:
case EPERM:
LOG(WARNING) << "Silently drop packet :( " << error;
//TODO: get errors from MSG_ERRQUEUE is possible
is_sent = true;
return error;
// Some general problems, wich may be fixed in future
case ENOMEM:
case EDQUOT:
case EFBIG:
case ENETDOWN:
case ENETUNREACH:
case ENOSPC:
case EHOSTUNREACH:
case ENOBUFS:
default:
#ifdef MSG_ERRQUEUE
get_poll_info().add_flags(PollFlags::Error());
#endif
return error;
case EBADF: // impossible
case ENOTSOCK: // impossible
case EPIPE: // impossible for udp
case ECONNRESET: // impossible for udp
case EDESTADDRREQ: // we checked that address is valid
case ENOTCONN: // we checked that address is valid
case EINTR: // we already skipped all EINTR
case EISCONN: // impossible for udp socket
case EOPNOTSUPP:
case ENOTDIR:
case EFAULT:
case EINVAL:
case EAFNOSUPPORT:
LOG(FATAL) << error;
UNREACHABLE();
return error;
}
}
Status send_messages(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
#if TD_HAS_MMSG
return send_messages_fast(messages, cnt);
#else
return send_messages_slow(messages, cnt);
#endif
}
Status receive_messages(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
#if TD_HAS_MMSG
return receive_messages_fast(messages, cnt);
#else
return receive_messages_slow(messages, cnt);
#endif
}
private:
PollableFdInfo info;
Status send_messages_slow(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
cnt = 0;
for (auto &message : messages) {
CHECK(!message.data.empty());
bool is_sent;
auto error = send_message(message, is_sent);
cnt += is_sent;
TRY_STATUS(std::move(error));
}
return Status::OK();
}
#if TD_HAS_MMSG
Status send_messages_fast(Span<UdpSocketFd::OutboundMessage> messages, size_t &cnt) {
//struct mmsghdr {
// struct msghdr msg_hdr; [> Message header <]
// unsigned int msg_len; [> Number of bytes transmitted <]
//};
struct std::array<detail::UdpSocketSendHelper, 16> helpers;
struct std::array<struct mmsghdr, 16> headers;
size_t to_send = std::min(messages.size(), headers.size());
for (size_t i = 0; i < to_send; i++) {
helpers[i].to_native(messages[i], headers[i].msg_hdr);
headers[i].msg_len = 0;
}
auto native_fd = get_native_fd().fd();
auto sendmmsg_res =
skip_eintr([&] { return sendmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_send), 0); });
auto sendmmsg_errno = errno;
if (sendmmsg_res >= 0) {
cnt = sendmmsg_res;
return Status::OK();
}
bool is_sent = false;
auto status = process_sendmsg_error(sendmmsg_errno, is_sent);
cnt = is_sent;
return status;
}
#endif
Status receive_messages_slow(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
cnt = 0;
while (cnt < messages.size() && get_poll_info().get_flags().can_read()) {
auto &message = messages[cnt];
CHECK(!message.data.empty());
bool is_received;
auto error = receive_message(message, is_received);
cnt += is_received;
TRY_STATUS(std::move(error));
}
return Status::OK();
}
#if TD_HAS_MMSG
Status receive_messages_fast(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
int flags = 0;
cnt = 0;
if (get_poll_info().get_flags().has_pending_error()) {
#ifdef MSG_ERRQUEUE
flags = MSG_ERRQUEUE;
#else
return fd_.get_pending_error();
#endif
}
//struct mmsghdr {
// struct msghdr msg_hdr; [> Message header <]
// unsigned int msg_len; [> Number of bytes transmitted <]
//};
struct std::array<detail::UdpSocketReceiveHelper, 16> helpers;
struct std::array<struct mmsghdr, 16> headers;
size_t to_receive = std::min(messages.size(), headers.size());
for (size_t i = 0; i < to_receive; i++) {
helpers[i].to_native(messages[i], headers[i].msg_hdr);
headers[i].msg_len = 0;
}
auto native_fd = get_native_fd().fd();
auto recvmmsg_res = skip_eintr(
[&] { return recvmmsg(native_fd, headers.data(), narrow_cast<unsigned int>(to_receive), flags, nullptr); });
auto recvmmsg_errno = errno;
if (recvmmsg_res >= 0) {
cnt = narrow_cast<size_t>(recvmmsg_res);
for (size_t i = 0; i < cnt; i++) {
helpers[i].from_native(headers[i].msg_hdr, headers[i].msg_len, messages[i]);
}
return Status::OK();
}
bool is_received;
auto status = process_recvmsg_error(recvmmsg_errno, is_received);
cnt = is_received;
return status;
}
#endif
};
void UdpSocketFdImplDeleter::operator()(UdpSocketFdImpl *impl) {
delete impl;
}
#endif
} // namespace detail
UdpSocketFd::UdpSocketFd() = default;
UdpSocketFd::UdpSocketFd(UdpSocketFd &&) = default;
UdpSocketFd &UdpSocketFd::operator=(UdpSocketFd &&) = default;
UdpSocketFd::~UdpSocketFd() = default;
PollableFdInfo &UdpSocketFd::get_poll_info() {
return impl_->get_poll_info();
}
const PollableFdInfo &UdpSocketFd::get_poll_info() const {
return impl_->get_poll_info();
}
//Result<UdpSocketFd> UdpSocketFd::from_native_fd(int fd) {
//auto fd_guard = ScopeExit() + [fd]() { ::close(fd); };
//TRY_STATUS(detail::set_native_socket_is_blocking(fd, false));
//// TODO remove copypaste
//int flags = 1;
//setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
//setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags));
//// TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
//fd_guard.dismiss();
//UdpSocketFd socket;
//socket.fd_ = Fd(fd, Fd::Mode::Owner);
//return std::move(socket);
//}
Result<UdpSocketFd> UdpSocketFd::open(const IPAddress &address) {
NativeFd native_fd{socket(address.get_address_family(), SOCK_DGRAM, IPPROTO_UDP)};
if (!native_fd) {
return OS_SOCKET_ERROR("Failed to create a socket");
}
TRY_STATUS(detail::set_native_socket_is_blocking(native_fd, false));
auto sock = native_fd.socket();
#if TD_PORT_POSIX
int flags = 1;
#elif TD_PORT_WINDOWS
BOOL flags = TRUE;
#endif
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&flags), sizeof(flags));
setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<const char *>(&flags), sizeof(flags));
// TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
auto bind_addr = address.get_any_addr();
bind_addr.set_port(address.get_port());
auto e_bind = bind(sock, bind_addr.get_sockaddr(), narrow_cast<int>(bind_addr.get_sockaddr_len()));
if (e_bind != 0) {
return OS_SOCKET_ERROR("Failed to bind a socket");
}
return UdpSocketFd(std::make_unique<detail::UdpSocketFdImpl>(std::move(native_fd)));
}
UdpSocketFd::UdpSocketFd(std::unique_ptr<detail::UdpSocketFdImpl> impl) : impl_(impl.release()) {
}
void UdpSocketFd::close() {
impl_.reset();
}
bool UdpSocketFd::empty() const {
return !impl_;
}
const NativeFd &UdpSocketFd::get_native_fd() const {
return get_poll_info().native_fd();
}
#if TD_PORT_POSIX
Status UdpSocketFd::send_message(const OutboundMessage &message, bool &is_sent) {
return impl_->send_message(message, is_sent);
}
Status UdpSocketFd::receive_message(InboundMessage &message, bool &is_received) {
return impl_->receive_message(message, is_received);
}
Status UdpSocketFd::send_messages(Span<OutboundMessage> messages, size_t &count) {
return impl_->send_messages(messages, count);
}
Status UdpSocketFd::receive_messages(MutableSpan<InboundMessage> messages, size_t &count) {
return impl_->receive_messages(messages, count);
}
#endif
#if TD_PORT_WINDOWS
Result<optional<UdpMessage>> UdpSocketFd::receive() {
return impl_->receive();
}
void UdpSocketFd::send(UdpMessage message) {
return impl_->send(std::move(message));
}
Status UdpSocketFd::flush_send() {
return impl_->flush_send();
}
#endif
bool UdpSocketFd::is_critical_read_error(const Status &status) {
return status.code() == ENOMEM || status.code() == ENOBUFS;
}
} // namespace td

View File

@ -0,0 +1,85 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/port/config.h"
#include "td/utils/port/detail/PollableFd.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
#include "td/utils/buffer.h"
#include "td/utils/Span.h"
#include "td/utils/optional.h"
namespace td {
// Udp and errors
namespace detail {
class UdpSocketFdImpl;
class UdpSocketFdImplDeleter {
public:
void operator()(UdpSocketFdImpl *impl);
};
} // namespace detail
struct UdpMessage {
IPAddress address;
BufferSlice data;
Status error;
};
class UdpSocketFd {
public:
UdpSocketFd();
UdpSocketFd(UdpSocketFd &&);
UdpSocketFd &operator=(UdpSocketFd &&);
~UdpSocketFd();
UdpSocketFd(const UdpSocketFd &) = delete;
UdpSocketFd &operator=(const UdpSocketFd &) = delete;
static Result<UdpSocketFd> open(const IPAddress &address) TD_WARN_UNUSED_RESULT;
PollableFdInfo &get_poll_info();
const PollableFdInfo &get_poll_info() const;
const NativeFd &get_native_fd() const;
void close();
bool empty() const;
static bool is_critical_read_error(const Status &status);
#if TD_PORT_POSIX
struct OutboundMessage {
const IPAddress *to;
Slice data;
};
struct InboundMessage {
IPAddress *from;
MutableSlice data;
Status *error;
};
Status send_message(const OutboundMessage &message, bool &is_sent) TD_WARN_UNUSED_RESULT;
Status receive_message(InboundMessage &message, bool &is_received) TD_WARN_UNUSED_RESULT;
Status send_messages(Span<OutboundMessage> messages, size_t &count) TD_WARN_UNUSED_RESULT;
Status receive_messages(MutableSpan<InboundMessage> messages, size_t &count) TD_WARN_UNUSED_RESULT;
#elif TD_PORT_WINDOWS
Result<optional<UdpMessage> > receive();
void send(UdpMessage message);
Status flush_send();
#endif
private:
std::unique_ptr<detail::UdpSocketFdImpl, detail::UdpSocketFdImplDeleter> impl_;
explicit UdpSocketFd(std::unique_ptr<detail::UdpSocketFdImpl> impl);
};
} // namespace td

View File

@ -43,4 +43,8 @@
#define TD_THREAD_STL 1 #define TD_THREAD_STL 1
#endif #endif
#if TD_LINUX
#define TD_HAS_MMSG 1
#endif
// clang-format on // clang-format on

View File

@ -35,37 +35,46 @@ void Epoll::clear() {
close(epoll_fd); close(epoll_fd);
epoll_fd = -1; epoll_fd = -1;
for (auto *list_node = list_root.next; list_node != &list_root;) {
auto pollable_fd = PollableFd::from_list_node(list_node);
list_node = list_node->next;
}
} }
void Epoll::subscribe(const Fd &fd, Fd::Flags flags) { void Epoll::subscribe(PollableFd fd, PollFlags flags) {
epoll_event event; epoll_event event;
event.events = EPOLLHUP | EPOLLERR | EPOLLET; event.events = EPOLLHUP | EPOLLERR | EPOLLET;
#ifdef EPOLLRDHUP #ifdef EPOLLRDHUP
event.events |= EPOLLRDHUP; event.events |= EPOLLRDHUP;
#endif #endif
if (flags & Fd::Read) { if (flags.can_read()) {
event.events |= EPOLLIN; event.events |= EPOLLIN;
} }
if (flags & Fd::Write) { if (flags.can_write()) {
event.events |= EPOLLOUT; event.events |= EPOLLOUT;
} }
auto native_fd = fd.get_native_fd(); auto native_fd = fd.native_fd().fd();
event.data.fd = native_fd; auto *list_node = fd.release_as_list_node();
list_root.put(list_node);
event.data.ptr = list_node;
int err = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, native_fd, &event); int err = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, native_fd, &event);
auto epoll_ctl_errno = errno; auto epoll_ctl_errno = errno;
LOG_IF(FATAL, err == -1) << Status::PosixError(epoll_ctl_errno, "epoll_ctl ADD failed") << ", epoll_fd = " << epoll_fd LOG_IF(FATAL, err == -1) << Status::PosixError(epoll_ctl_errno, "epoll_ctl ADD failed") << ", epoll_fd = " << epoll_fd
<< ", fd = " << native_fd; << ", fd = " << native_fd;
} }
void Epoll::unsubscribe(const Fd &fd) { void Epoll::unsubscribe(PollableFdRef fd_ref) {
auto native_fd = fd.get_native_fd(); auto fd = fd_ref.lock();
auto native_fd = fd.native_fd().fd();
int err = epoll_ctl(epoll_fd, EPOLL_CTL_DEL, native_fd, nullptr); int err = epoll_ctl(epoll_fd, EPOLL_CTL_DEL, native_fd, nullptr);
auto epoll_ctl_errno = errno; auto epoll_ctl_errno = errno;
LOG_IF(FATAL, err == -1) << Status::PosixError(epoll_ctl_errno, "epoll_ctl DEL failed") << ", epoll_fd = " << epoll_fd LOG_IF(FATAL, err == -1) << Status::PosixError(epoll_ctl_errno, "epoll_ctl DEL failed") << ", epoll_fd = " << epoll_fd
<< ", fd = " << native_fd; << ", fd = " << native_fd;
} }
void Epoll::unsubscribe_before_close(const Fd &fd) { void Epoll::unsubscribe_before_close(PollableFdRef fd) {
unsubscribe(fd); unsubscribe(fd);
} }
@ -76,15 +85,15 @@ void Epoll::run(int timeout_ms) {
<< Status::PosixError(epoll_wait_errno, "epoll_wait failed"); << Status::PosixError(epoll_wait_errno, "epoll_wait failed");
for (int i = 0; i < ready_n; i++) { for (int i = 0; i < ready_n; i++) {
Fd::Flags flags = 0; PollFlags flags;
epoll_event *event = &events[i]; epoll_event *event = &events[i];
if (event->events & EPOLLIN) { if (event->events & EPOLLIN) {
event->events &= ~EPOLLIN; event->events &= ~EPOLLIN;
flags |= Fd::Read; flags = flags | PollFlags::Read();
} }
if (event->events & EPOLLOUT) { if (event->events & EPOLLOUT) {
event->events &= ~EPOLLOUT; event->events &= ~EPOLLOUT;
flags |= Fd::Write; flags = flags | PollFlags::Write();
} }
#ifdef EPOLLRDHUP #ifdef EPOLLRDHUP
if (event->events & EPOLLRDHUP) { if (event->events & EPOLLRDHUP) {
@ -95,17 +104,19 @@ void Epoll::run(int timeout_ms) {
#endif #endif
if (event->events & EPOLLHUP) { if (event->events & EPOLLHUP) {
event->events &= ~EPOLLHUP; event->events &= ~EPOLLHUP;
flags |= Fd::Close; flags = flags | PollFlags::Close();
} }
if (event->events & EPOLLERR) { if (event->events & EPOLLERR) {
event->events &= ~EPOLLERR; event->events &= ~EPOLLERR;
flags |= Fd::Error; flags = flags | PollFlags::Error();
} }
if (event->events) { if (event->events) {
LOG(FATAL) << "Unsupported epoll events: " << event->events; LOG(FATAL) << "Unsupported epoll events: " << event->events;
} }
// LOG(DEBUG) << "Epoll event " << tag("fd", event->data.fd) << tag("flags", format::as_binary(flags)); //LOG(DEBUG) << "Epoll event " << tag("fd", event->data.fd) << tag("flags", format::as_binary(flags));
Fd(event->data.fd, Fd::Mode::Reference).update_flags_notify(flags); auto pollable_fd = PollableFd::from_list_node(static_cast<ListNode *>(event->data.ptr));
pollable_fd.add_flags(flags);
pollable_fd.release_as_list_node();
} }
} }
} // namespace detail } // namespace detail

View File

@ -32,17 +32,22 @@ class Epoll final : public PollBase {
void clear() override; void clear() override;
void subscribe(const Fd &fd, Fd::Flags flags) override; void subscribe(PollableFd fd, PollFlags flags) override;
void unsubscribe(const Fd &fd) override; void unsubscribe(PollableFdRef fd) override;
void unsubscribe_before_close(const Fd &fd) override; void unsubscribe_before_close(PollableFdRef fd) override;
void run(int timeout_ms) override; void run(int timeout_ms) override;
static bool is_edge_triggered() {
return true;
}
private: private:
int epoll_fd = -1; int epoll_fd = -1;
vector<struct epoll_event> events; vector<struct epoll_event> events;
ListNode list_root;
}; };
} // namespace detail } // namespace detail

View File

@ -12,6 +12,7 @@ char disable_linker_warning_about_empty_file_event_fd_bsd_cpp TD_UNUSED;
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include "td/utils/port/SocketFd.h"
#include <fcntl.h> #include <fcntl.h>
#include <sys/socket.h> #include <sys/socket.h>
@ -36,11 +37,13 @@ void EventFdBsd::init() {
#endif #endif
LOG_IF(FATAL, err == -1) << Status::PosixError(socketpair_errno, "socketpair failed"); LOG_IF(FATAL, err == -1) << Status::PosixError(socketpair_errno, "socketpair failed");
detail::set_native_socket_is_blocking(fds[0], false).ensure(); auto fd_a = NativeFd(fds[0]);
detail::set_native_socket_is_blocking(fds[1], false).ensure(); auto fd_b = NativeFd(fds[1]);
detail::set_native_socket_is_blocking(fd_a, false).ensure();
detail::set_native_socket_is_blocking(fd_b, false).ensure();
in_ = Fd(fds[0], Fd::Mode::Owner); in_ = SocketFd::from_native_fd(std::move(fd_a)).move_as_ok();
out_ = Fd(fds[1], Fd::Mode::Owner); out_ = SocketFd::from_native_fd(std::move(fd_b)).move_as_ok();
} }
bool EventFdBsd::empty() { bool EventFdBsd::empty() {
@ -56,12 +59,8 @@ Status EventFdBsd::get_pending_error() {
return Status::OK(); return Status::OK();
} }
const Fd &EventFdBsd::get_fd() const { PollableFdInfo &EventFdBsd::get_poll_info() {
return out_; return out_.get_poll_info();
}
Fd &EventFdBsd::get_fd() {
return out_;
} }
void EventFdBsd::release() { void EventFdBsd::release() {
@ -77,7 +76,7 @@ void EventFdBsd::release() {
} }
void EventFdBsd::acquire() { void EventFdBsd::acquire() {
out_.update_flags(Fd::Read); out_.get_poll_info().add_flags(PollFlags::Read());
while (can_read(out_)) { while (can_read(out_)) {
uint8 value[1024]; uint8 value[1024];
auto result = out_.read(MutableSlice(value, sizeof(value))); auto result = out_.read(MutableSlice(value, sizeof(value)));

View File

@ -12,15 +12,15 @@
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/port/EventFdBase.h" #include "td/utils/port/EventFdBase.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/SocketFd.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
namespace detail { namespace detail {
class EventFdBsd final : public EventFdBase { class EventFdBsd final : public EventFdBase {
Fd in_; SocketFd in_;
Fd out_; SocketFd out_;
public: public:
EventFdBsd() = default; EventFdBsd() = default;
@ -33,8 +33,7 @@ class EventFdBsd final : public EventFdBase {
Status get_pending_error() override TD_WARN_UNUSED_RESULT; Status get_pending_error() override TD_WARN_UNUSED_RESULT;
const Fd &get_fd() const override; PollableFdInfo &get_poll_info() override;
Fd &get_fd() override;
void release() override; void release() override;

View File

@ -6,50 +6,70 @@
// //
#include "td/utils/port/detail/EventFdLinux.h" #include "td/utils/port/detail/EventFdLinux.h"
#include "td/utils/misc.h"
char disable_linker_warning_about_empty_file_event_fd_linux_cpp TD_UNUSED; char disable_linker_warning_about_empty_file_event_fd_linux_cpp TD_UNUSED;
#ifdef TD_EVENTFD_LINUX #ifdef TD_EVENTFD_LINUX
#include "td/utils/port/detail/PollableFd.h"
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/Slice.h" #include "td/utils/Slice.h"
#include <sys/eventfd.h> #include <sys/eventfd.h>
#include <unistd.h>
namespace td { namespace td {
namespace detail { namespace detail {
class EventFdLinuxImpl {
public:
PollableFdInfo info;
};
EventFdLinux::EventFdLinux() = default;
EventFdLinux::EventFdLinux(EventFdLinux &&) = default;
EventFdLinux &EventFdLinux::operator=(EventFdLinux &&) = default;
EventFdLinux::~EventFdLinux() = default;
void EventFdLinux::init() { void EventFdLinux::init() {
int fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); auto fd = NativeFd(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
auto eventfd_errno = errno; auto eventfd_errno = errno;
LOG_IF(FATAL, fd == -1) << Status::PosixError(eventfd_errno, "eventfd call failed"); LOG_IF(FATAL, !fd) << Status::PosixError(eventfd_errno, "eventfd call failed");
impl_ = std::make_unique<EventFdLinuxImpl>();
fd_ = Fd(fd, Fd::Mode::Owner); impl_->info.set_native_fd(std::move(fd));
} }
bool EventFdLinux::empty() { bool EventFdLinux::empty() {
return fd_.empty(); return !impl_;
} }
void EventFdLinux::close() { void EventFdLinux::close() {
fd_.close(); impl_.reset();
} }
Status EventFdLinux::get_pending_error() { Status EventFdLinux::get_pending_error() {
return Status::OK(); return Status::OK();
} }
const Fd &EventFdLinux::get_fd() const { PollableFdInfo &EventFdLinux::get_poll_info() {
return fd_; return impl_->info;
}
Fd &EventFdLinux::get_fd() {
return fd_;
} }
// NB: will be called from multiple threads
void EventFdLinux::release() { void EventFdLinux::release() {
const uint64 value = 1; const uint64 value = 1;
// NB: write_unsafe is used, because release will be called from multiple threads auto slice = Slice(reinterpret_cast<const char *>(&value), sizeof(value));
auto result = fd_.write_unsafe(Slice(reinterpret_cast<const char *>(&value), sizeof(value))); auto native_fd = impl_->info.native_fd().fd();
auto result = [&]() -> Result<size_t> {
auto write_res = skip_eintr([&] { return ::write(native_fd, slice.begin(), slice.size()); });
auto write_errno = errno;
if (write_res >= 0) {
return narrow_cast<size_t>(write_res);
}
return Status::PosixError(write_errno, PSLICE() << "Write to fd " << native_fd << " has failed");
}();
if (result.is_error()) { if (result.is_error()) {
LOG(FATAL) << "EventFdLinux write failed: " << result.error(); LOG(FATAL) << "EventFdLinux write failed: " << result.error();
} }
@ -61,11 +81,29 @@ void EventFdLinux::release() {
void EventFdLinux::acquire() { void EventFdLinux::acquire() {
uint64 res; uint64 res;
auto result = fd_.read(MutableSlice(reinterpret_cast<char *>(&res), sizeof(res))); auto slice = MutableSlice(reinterpret_cast<char *>(&res), sizeof(res));
auto native_fd = impl_->info.native_fd().fd();
auto result = [&]() -> Result<size_t> {
CHECK(slice.size() > 0);
auto read_res = skip_eintr([&] { return ::read(native_fd, slice.begin(), slice.size()); });
auto read_errno = errno;
if (read_res >= 0) {
CHECK(read_res != 0);
return narrow_cast<size_t>(read_res);
}
if (read_errno == EAGAIN
#if EAGAIN != EWOULDBLOCK
|| read_errno == EWOULDBLOCK
#endif
) {
get_poll_info().clear_flags(PollFlags::Read());
return 0;
}
return Status::PosixError(read_errno, PSLICE() << "Read from fd " << native_fd << " has failed");
}();
if (result.is_error()) { if (result.is_error()) {
LOG(FATAL) << "EventFdLinux read failed: " << result.error(); LOG(FATAL) << "EventFdLinux read failed: " << result.error();
} }
fd_.clear_flags(Fd::Read);
} }
} // namespace detail } // namespace detail

View File

@ -12,16 +12,22 @@
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/port/EventFdBase.h" #include "td/utils/port/EventFdBase.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/SocketFd.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
namespace detail { namespace detail {
class EventFdLinuxImpl;
class EventFdLinux final : public EventFdBase { class EventFdLinux final : public EventFdBase {
Fd fd_; std::unique_ptr<EventFdLinuxImpl> impl_;
public: public:
EventFdLinux();
EventFdLinux(EventFdLinux &&);
EventFdLinux &operator=(EventFdLinux &&);
~EventFdLinux();
void init() override; void init() override;
bool empty() override; bool empty() override;
@ -30,8 +36,7 @@ class EventFdLinux final : public EventFdBase {
Status get_pending_error() override TD_WARN_UNUSED_RESULT; Status get_pending_error() override TD_WARN_UNUSED_RESULT;
const Fd &get_fd() const override; PollableFdInfo &get_poll_info() override;
Fd &get_fd() override;
void release() override; void release() override;

View File

@ -14,35 +14,36 @@ namespace td {
namespace detail { namespace detail {
void EventFdWindows::init() { void EventFdWindows::init() {
fd_ = Fd::create_event_fd(); event_ = NativeFd(CreateEventW(nullptr, true, false, nullptr));
} }
bool EventFdWindows::empty() { bool EventFdWindows::empty() {
return fd_.empty(); return !event_;
} }
void EventFdWindows::close() { void EventFdWindows::close() {
fd_.close(); event_.close();
} }
Status EventFdWindows::get_pending_error() { Status EventFdWindows::get_pending_error() {
return Status::OK(); return Status::OK();
} }
const Fd &EventFdWindows::get_fd() const { PollableFdInfo &EventFdWindows::get_poll_info() {
return fd_; UNREACHABLE();
}
Fd &EventFdWindows::get_fd() {
return fd_;
} }
void EventFdWindows::release() { void EventFdWindows::release() {
fd_.release(); SetEvent(event_.io_handle());
} }
void EventFdWindows::acquire() { void EventFdWindows::acquire() {
fd_.acquire(); ResetEvent(event_.io_handle());
}
void EventFdWindows::wait(int timeout_ms) {
WaitForSingleObject(event_.io_handle(), timeout_ms);
ResetEvent(event_.io_handle());
} }
} // namespace detail } // namespace detail

View File

@ -12,14 +12,14 @@
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/port/EventFdBase.h" #include "td/utils/port/EventFdBase.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/detail/PollableFd.h"
#include "td/utils/Status.h" #include "td/utils/Status.h"
namespace td { namespace td {
namespace detail { namespace detail {
class EventFdWindows final : public EventFdBase { class EventFdWindows final : public EventFdBase {
Fd fd_; NativeFd event_;
public: public:
EventFdWindows() = default; EventFdWindows() = default;
@ -32,12 +32,13 @@ class EventFdWindows final : public EventFdBase {
Status get_pending_error() override TD_WARN_UNUSED_RESULT; Status get_pending_error() override TD_WARN_UNUSED_RESULT;
const Fd &get_fd() const override; PollableFdInfo &get_poll_info() override;
Fd &get_fd() override;
void release() override; void release() override;
void acquire() override; void acquire() override;
void wait(int timeout_ms);
}; };
} // namespace detail } // namespace detail

View File

@ -43,6 +43,10 @@ void KQueue::clear() {
events.clear(); events.clear();
close(kq); close(kq);
kq = -1; kq = -1;
for (auto *list_node = list_root.next; list_node != &list_root;) {
auto pollable_fd = PollableFd::from_list_node(list_node);
list_node = list_node->next;
}
} }
int KQueue::update(int nevents, const timespec *timeout, bool may_fail) { int KQueue::update(int nevents, const timespec *timeout, bool may_fail) {
@ -85,18 +89,21 @@ void KQueue::add_change(std::uintptr_t ident, int16 filter, uint16 flags, uint32
changes_n++; changes_n++;
} }
void KQueue::subscribe(const Fd &fd, Fd::Flags flags) { void KQueue::subscribe(PollableFd fd, PollFlags flags) {
if (flags & Fd::Read) { auto native_fd = fd.native_fd().fd();
add_change(fd.get_native_fd(), EVFILT_READ, EV_ADD | EV_CLEAR, 0, 0, nullptr); auto list_node = fd.release_as_list_node();
list_root.put(list_node);
if (flags.can_read()) {
add_change(native_fd, EVFILT_READ, EV_ADD | EV_CLEAR, 0, 0, list_node);
} }
if (flags & Fd::Write) { if (flags.can_write()) {
add_change(fd.get_native_fd(), EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, nullptr); add_change(native_fd, EVFILT_WRITE, EV_ADD | EV_CLEAR, 0, 0, list_node);
} }
} }
void KQueue::invalidate(const Fd &fd) { void KQueue::invalidate(int native_fd) {
for (int i = 0; i < changes_n; i++) { for (int i = 0; i < changes_n; i++) {
if (events[i].ident == static_cast<std::uintptr_t>(fd.get_native_fd())) { if (events[i].ident == static_cast<std::uintptr_t>(native_fd)) {
changes_n--; changes_n--;
std::swap(events[i], events[changes_n]); std::swap(events[i], events[changes_n]);
i--; i--;
@ -104,17 +111,21 @@ void KQueue::invalidate(const Fd &fd) {
} }
} }
void KQueue::unsubscribe(const Fd &fd) { void KQueue::unsubscribe(PollableFdRef fd_ref) {
auto pollable_fd = fd_ref.lock();
auto native_fd = pollable_fd.native_fd().fd();
// invalidate(fd); // invalidate(fd);
flush_changes(); flush_changes();
add_change(fd.get_native_fd(), EVFILT_READ, EV_DELETE, 0, 0, nullptr); add_change(native_fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr);
flush_changes(true); flush_changes(true);
add_change(fd.get_native_fd(), EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); add_change(native_fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr);
flush_changes(true); flush_changes(true);
} }
void KQueue::unsubscribe_before_close(const Fd &fd) { void KQueue::unsubscribe_before_close(PollableFdRef fd_ref) {
invalidate(fd); auto pollable_fd = fd_ref.lock();
invalidate(pollable_fd.native_fd().fd());
// just to avoid O(changes_n ^ 2) // just to avoid O(changes_n ^ 2)
if (changes_n != 0) { if (changes_n != 0) {
@ -136,22 +147,24 @@ void KQueue::run(int timeout_ms) {
int n = update(static_cast<int>(events.size()), timeout_ptr); int n = update(static_cast<int>(events.size()), timeout_ptr);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
struct kevent *event = &events[i]; struct kevent *event = &events[i];
Fd::Flags flags = 0; PollFlags flags;
if (event->filter == EVFILT_WRITE) { if (event->filter == EVFILT_WRITE) {
flags |= Fd::Write; flags.add_flags(PollFlags::Write());
} }
if (event->filter == EVFILT_READ) { if (event->filter == EVFILT_READ) {
flags |= Fd::Read; flags.add_flags(PollFlags::Read());
} }
if (event->flags & EV_EOF) { if (event->flags & EV_EOF) {
flags |= Fd::Close; flags.add_flags(PollFlags::Close());
} }
if (event->fflags & EV_ERROR) { if (event->fflags & EV_ERROR) {
LOG(FATAL) << "EV_ERROR in kqueue is not supported"; LOG(FATAL) << "EV_ERROR in kqueue is not supported";
} }
VLOG(fd) << "Event [fd:" << event->ident << "] [filter:" << event->filter << "] [udata: " << event->udata << "]"; VLOG(fd) << "Event [fd:" << event->ident << "] [filter:" << event->filter << "] [udata: " << event->udata << "]";
// LOG(WARNING) << "event->ident = " << event->ident << "event->filter = " << event->filter; // LOG(WARNING) << "event->ident = " << event->ident << "event->filter = " << event->filter;
Fd(static_cast<int>(event->ident), Fd::Mode::Reference).update_flags_notify(flags); auto pollable_fd = PollableFd::from_list_node(static_cast<ListNode *>(event->udata));
pollable_fd.add_flags(flags);
pollable_fd.release_as_list_node();
} }
} }
} // namespace detail } // namespace detail

View File

@ -34,22 +34,27 @@ class KQueue final : public PollBase {
void clear() override; void clear() override;
void subscribe(const Fd &fd, Fd::Flags flags) override; void subscribe(PollableFd fd, PollFlags flags) override;
void unsubscribe(const Fd &fd) override; void unsubscribe(PollableFdRef fd) override;
void unsubscribe_before_close(const Fd &fd) override; void unsubscribe_before_close(PollableFdRef fd) override;
void run(int timeout_ms) override; void run(int timeout_ms) override;
static bool is_edge_triggered() {
return true;
}
private: private:
vector<struct kevent> events; vector<struct kevent> events;
int changes_n; int changes_n;
int kq; int kq;
ListNode list_root;
int update(int nevents, const timespec *timeout, bool may_fail = false); int update(int nevents, const timespec *timeout, bool may_fail = false);
void invalidate(const Fd &fd); void invalidate(int native_fd);
void flush_changes(bool may_fail = false); void flush_changes(bool may_fail = false);

View File

@ -0,0 +1,86 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/port/detail/NativeFd.h"
#include "td/utils/logging.h"
#include "td/utils/Status.h"
#include "td/utils/format.h"
#if TD_PORT_POSIX
#include <unistd.h>
#endif
namespace td {
NativeFd::NativeFd(NativeFd::Raw raw) : fd_(raw) {
VLOG(fd) << *this << " create";
}
NativeFd::NativeFd(NativeFd::Raw raw, bool nolog) : fd_(raw) {
}
#if TD_PORT_WINDOWS
NativeFd::NativeFd(SOCKET raw) : fd_(reinterpret_cast<HANDLE>(raw)), is_socket_(true) {
VLOG(fd) << *this << " create";
}
#endif
NativeFd::~NativeFd() {
close();
}
NativeFd::operator bool() const {
return fd_.get() != empty_raw();
}
constexpr NativeFd::Raw NativeFd::empty_raw() {
#if TD_PORT_POSIX
return -1;
#elif TD_PORT_WINDOWS
return INVALID_HANDLE_VALUE;
#endif
}
NativeFd::Raw NativeFd::raw() const {
return fd_.get();
}
NativeFd::Raw NativeFd::fd() const {
return raw();
}
#if TD_PORT_WINDOWS
NativeFd::Raw NativeFd::io_handle() const {
return raw();
}
SOCKET NativeFd::socket() const {
CHECK(is_socket_);
return reinterpret_cast<SOCKET>(fd_.get());
}
#elif TD_PORT_POSIX
NativeFd::Raw NativeFd::socket() const {
return raw();
}
#endif
void NativeFd::close() {
if (!*this) {
return;
}
VLOG(fd) << *this << " close";
#if TD_PORT_WINDOWS
if (!CloseHandle(io_handle())) {
#elif TD_PORT_POSIX
if (::close(fd()) < 0) {
#endif
auto error = OS_ERROR("Close fd");
LOG(ERROR) << error;
}
fd_ = {};
}
NativeFd::Raw NativeFd::release() {
VLOG(fd) << *this << " release";
auto res = fd_.get();
fd_ = {};
return res;
}
StringBuilder &operator<<(StringBuilder &sb, const NativeFd &fd) {
sb << tag("fd", fd.raw());
return sb;
}
} // namespace td

View File

@ -0,0 +1,56 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/port/config.h"
#include "td/utils/common.h"
#include "td/utils/MovableValue.h"
namespace td {
class StringBuilder;
}
namespace td {
class NativeFd {
public:
#if TD_PORT_POSIX
using Raw = int;
#elif TD_PORT_WINDOWS
using Raw = HANDLE;
#endif
NativeFd() = default;
NativeFd(NativeFd &&) = default;
NativeFd &operator=(NativeFd &&) = default;
explicit NativeFd(Raw raw);
NativeFd(Raw raw, bool nolog);
#if TD_PORT_WINDOWS
explicit NativeFd(SOCKET raw);
#endif
~NativeFd();
explicit operator bool() const;
static constexpr Raw empty_raw();
Raw raw() const;
Raw fd() const;
#if TD_PORT_WINDOWS
Raw io_handle() const;
SOCKET socket() const;
#elif TD_PORT_POSIX
Raw socket() const;
#endif
void close();
Raw release();
private:
#if TD_PORT_POSIX
MovableValue<Raw, -1> fd_;
#elif TD_PORT_WINDOWS
MovableValue<Raw, INVALID_HANDLE_VALUE> fd_;
bool is_socket_{false};
#endif
};
StringBuilder &operator<<(StringBuilder &sb, const NativeFd &fd);
} // namespace td

View File

@ -25,31 +25,37 @@ void Poll::clear() {
pollfds_.clear(); pollfds_.clear();
} }
void Poll::subscribe(const Fd &fd, Fd::Flags flags) { void Poll::subscribe(PollableFd fd, PollFlags flags) {
unsubscribe(fd); unsubscribe(fd.ref());
struct pollfd pollfd; struct pollfd pollfd;
pollfd.fd = fd.get_native_fd(); pollfd.fd = fd.native_fd().fd();
pollfd.events = 0; pollfd.events = 0;
if (flags & Fd::Read) { if (flags.can_read()) {
pollfd.events |= POLLIN; pollfd.events |= POLLIN;
} }
if (flags & Fd::Write) { if (flags.can_write()) {
pollfd.events |= POLLOUT; pollfd.events |= POLLOUT;
} }
pollfd.revents = 0; pollfd.revents = 0;
pollfds_.push_back(pollfd); pollfds_.push_back(pollfd);
fds_.push_back(std::move(fd));
} }
void Poll::unsubscribe(const Fd &fd) { void Poll::unsubscribe(PollableFdRef fd_ref) {
auto fd = fd_ref.lock();
SCOPE_EXIT {
fd.release_as_list_node();
};
for (auto it = pollfds_.begin(); it != pollfds_.end(); ++it) { for (auto it = pollfds_.begin(); it != pollfds_.end(); ++it) {
if (it->fd == fd.get_native_fd()) { if (it->fd == fd.native_fd().fd()) {
pollfds_.erase(it); pollfds_.erase(it);
fds_.erase(fds_.begin() + (it - pollfds_.begin()));
return; return;
} }
} }
} }
void Poll::unsubscribe_before_close(const Fd &fd) { void Poll::unsubscribe_before_close(PollableFdRef fd) {
unsubscribe(fd); unsubscribe(fd);
} }
@ -58,23 +64,26 @@ void Poll::run(int timeout_ms) {
auto poll_errno = errno; auto poll_errno = errno;
LOG_IF(FATAL, err == -1 && poll_errno != EINTR) << Status::PosixError(poll_errno, "poll failed"); LOG_IF(FATAL, err == -1 && poll_errno != EINTR) << Status::PosixError(poll_errno, "poll failed");
for (auto &pollfd : pollfds_) { for (size_t i = 0; i < pollfds_.size(); i++) {
Fd::Flags flags = 0; auto &pollfd = pollfds_[i];
auto &fd = fds_[i];
PollFlags flags;
if (pollfd.revents & POLLIN) { if (pollfd.revents & POLLIN) {
pollfd.revents &= ~POLLIN; pollfd.revents &= ~POLLIN;
flags |= Fd::Read; flags = flags | PollFlags::Read();
} }
if (pollfd.revents & POLLOUT) { if (pollfd.revents & POLLOUT) {
pollfd.revents &= ~POLLOUT; pollfd.revents &= ~POLLOUT;
flags |= Fd::Write; flags = flags | PollFlags::Write();
} }
if (pollfd.revents & POLLHUP) { if (pollfd.revents & POLLHUP) {
pollfd.revents &= ~POLLHUP; pollfd.revents &= ~POLLHUP;
flags |= Fd::Close; flags = flags | PollFlags::Close();
} }
if (pollfd.revents & POLLERR) { if (pollfd.revents & POLLERR) {
pollfd.revents &= ~POLLERR; pollfd.revents &= ~POLLERR;
flags |= Fd::Error; flags = flags | PollFlags::Error();
} }
if (pollfd.revents & POLLNVAL) { if (pollfd.revents & POLLNVAL) {
LOG(FATAL) << "Unexpected POLLNVAL " << tag("fd", pollfd.fd); LOG(FATAL) << "Unexpected POLLNVAL " << tag("fd", pollfd.fd);
@ -82,7 +91,7 @@ void Poll::run(int timeout_ms) {
if (pollfd.revents) { if (pollfd.revents) {
LOG(FATAL) << "Unsupported poll events: " << pollfd.revents; LOG(FATAL) << "Unsupported poll events: " << pollfd.revents;
} }
Fd(pollfd.fd, Fd::Mode::Reference).update_flags_notify(flags); fd.add_flags(flags);
} }
} }

View File

@ -32,16 +32,21 @@ class Poll final : public PollBase {
void clear() override; void clear() override;
void subscribe(const Fd &fd, Fd::Flags flags) override; void subscribe(PollableFd fd, PollFlags flags) override;
void unsubscribe(const Fd &fd) override; void unsubscribe(PollableFdRef fd) override;
void unsubscribe_before_close(const Fd &fd) override; void unsubscribe_before_close(PollableFdRef fd) override;
void run(int timeout_ms) override; void run(int timeout_ms) override;
static bool is_edge_triggered() {
return false;
}
private: private:
vector<pollfd> pollfds_; vector<pollfd> pollfds_;
vector<PollableFd> fds_;
}; };
} // namespace detail } // namespace detail

View File

@ -0,0 +1,51 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/port/detail/PollableFd.h"
#if TD_POSIX_PORT
#include <unistd.h>
#endif
namespace td {
bool PollFlagsSet::write_flags(PollFlags flags) {
if (flags.empty()) {
return false;
}
auto old_flags = to_write_.fetch_or(flags.raw(), std::memory_order_relaxed);
return (flags.raw() & ~old_flags) != 0;
}
bool PollFlagsSet::write_flags_local(PollFlags flags) {
return flags_.add_flags(flags);
}
bool PollFlagsSet::flush() const {
if (to_write_.load(std::memory_order_relaxed) == 0) {
return false;
}
auto to_write = to_write_.exchange(0, std::memory_order_relaxed);
auto old_flags = flags_;
flags_.add_flags(PollFlags::from_raw(to_write));
if (flags_.can_close()) {
flags_.remove_flags(PollFlags::Write());
}
return flags_ != old_flags;
}
PollFlags PollFlagsSet::read_flags() const {
flush();
return flags_;
}
PollFlags PollFlagsSet::read_flags_local() const {
return flags_;
}
void PollFlagsSet::clear_flags(PollFlags flags) {
flags_.remove_flags(flags);
}
void PollFlagsSet::clear() {
to_write_ = 0;
flags_ = {};
}
} // namespace td

View File

@ -0,0 +1,360 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/Status.h"
#include "td/utils/List.h"
#include "td/utils/Observer.h"
#include "td/utils/MovableValue.h"
#include "td/utils/SpinLock.h"
#include "td/utils/format.h"
#include "td/utils/port/detail/NativeFd.h"
#include <atomic>
namespace td {
class ObserverBase;
class PollFlags {
public:
using Raw = int32;
bool can_read() const {
return has_flags(Read());
}
bool can_write() const {
return has_flags(Write());
}
bool can_close() const {
return has_flags(Close());
}
bool has_pending_error() const {
return has_flags(Error());
}
void remove_flags(PollFlags flags) {
remove_flags(flags.raw());
}
bool add_flags(PollFlags flags) {
auto old_flags = flags_;
add_flags(flags.raw());
return old_flags != flags_;
}
bool has_flags(PollFlags flags) const {
return has_flags(flags.raw());
}
bool empty() const {
return flags_ == 0;
}
Raw raw() const {
return flags_;
}
static PollFlags from_raw(Raw raw) {
return PollFlags(raw);
}
PollFlags() = default;
bool operator==(const PollFlags &other) const {
return flags_ == other.flags_;
}
bool operator!=(const PollFlags &other) const {
return !(*this == other);
}
PollFlags operator|(const PollFlags other) const {
return from_raw(raw() | other.raw());
}
static PollFlags Write() {
return PollFlags(Flag::Write);
}
static PollFlags Error() {
return PollFlags(Flag::Error);
}
static PollFlags Close() {
return PollFlags(Flag::Close);
}
static PollFlags Read() {
return PollFlags(Flag::Read);
}
static PollFlags ReadWrite() {
return Read() | Write();
}
private:
enum class Flag : Raw { Write = 0x001, Read = 0x002, Close = 0x004, Error = 0x008, None = 0 };
Raw flags_{static_cast<Raw>(Flag::None)};
explicit PollFlags(Raw raw) : flags_(raw) {
}
explicit PollFlags(Flag flag) : PollFlags(static_cast<Raw>(flag)) {
}
PollFlags &add_flags(Raw flags) {
flags_ |= flags;
return *this;
}
PollFlags &remove_flags(Raw flags) {
flags_ &= ~flags;
return *this;
}
bool has_flags(Raw flags) const {
return (flags_ & flags) == flags;
}
};
inline StringBuilder &operator<<(StringBuilder &sb, PollFlags flags) {
sb << "[";
if (flags.can_read()) {
sb << "R";
}
if (flags.can_write()) {
sb << "W";
}
if (flags.can_close()) {
sb << "C";
}
if (flags.has_pending_error()) {
sb << "E";
}
return sb << "]";
}
class PollFlagsSet {
public:
// write flags from any thread
// this is the only function that should be called from other threads
bool write_flags(PollFlags flags);
bool write_flags_local(PollFlags flags);
bool flush() const;
PollFlags read_flags() const;
PollFlags read_flags_local() const;
void clear_flags(PollFlags flags);
void clear();
private:
mutable std::atomic<PollFlags::Raw> to_write_{0};
mutable PollFlags flags_;
};
class PollableFdInfo;
class PollableFdInfoUnlock {
public:
void operator()(PollableFdInfo *ptr);
};
class PollableFd;
class PollableFdRef {
public:
explicit PollableFdRef(ListNode *list_node) : list_node_(list_node) {
}
PollableFd lock();
private:
ListNode *list_node_;
};
class PollableFd {
public:
// Interface for kqueue, epoll and e.t.c.
const NativeFd &native_fd() const;
ListNode *release_as_list_node();
PollableFdRef ref();
static PollableFd from_list_node(ListNode *node);
void add_flags(PollFlags flags);
PollFlags get_flags_unsafe() const;
private:
std::unique_ptr<PollableFdInfo, PollableFdInfoUnlock> fd_info_;
friend class PollableFdInfo;
explicit PollableFd(std::unique_ptr<PollableFdInfo, PollableFdInfoUnlock> fd_info) : fd_info_(std::move(fd_info)) {
}
};
inline PollableFd PollableFdRef::lock() {
return PollableFd::from_list_node(list_node_);
}
class PollableFdInfo : private ListNode {
public:
PollableFdInfo() = default;
PollableFd extract_pollable_fd(ObserverBase *observer) {
VLOG(fd) << native_fd() << " extract pollable fd " << tag("observer", observer);
CHECK(!empty());
bool was_locked = lock_.test_and_set(std::memory_order_acquire);
CHECK(!was_locked);
set_observer(observer);
return PollableFd{std::unique_ptr<PollableFdInfo, PollableFdInfoUnlock>{this}};
}
PollableFdRef get_pollable_fd_ref() {
CHECK(!empty());
bool was_locked = lock_.test_and_set(std::memory_order_acquire);
CHECK(was_locked);
return PollableFdRef{as_list_node()};
}
void add_flags(PollFlags flags) {
flags_.write_flags_local(flags);
}
void clear_flags(PollFlags flags) {
flags_.clear_flags(flags);
}
PollFlags get_flags() const {
return flags_.read_flags();
}
PollFlags get_flags_local() const {
return flags_.read_flags_local();
}
bool empty() const {
return !fd_;
}
void set_native_fd(NativeFd new_native_fd) {
if (fd_) {
CHECK(!new_native_fd);
bool was_locked = lock_.test_and_set(std::memory_order_acquire);
CHECK(!was_locked);
lock_.clear(std::memory_order_release);
}
fd_ = std::move(new_native_fd);
}
explicit PollableFdInfo(NativeFd native_fd) {
set_native_fd(std::move(native_fd));
}
const NativeFd &native_fd() const {
//CHECK(!empty());
return fd_;
}
NativeFd move_as_native_fd() {
return std::move(fd_);
}
~PollableFdInfo() {
VLOG(fd) << native_fd() << " destroy PollableFdInfo";
bool was_locked = lock_.test_and_set(std::memory_order_acquire);
CHECK(!was_locked);
}
void add_flags_from_poll(PollFlags flags) {
VLOG(fd) << native_fd() << " add flags from poll " << flags;
if (flags_.write_flags(flags)) {
notify_observer();
}
}
private:
NativeFd fd_{};
std::atomic_flag lock_{false};
PollFlagsSet flags_;
#if TD_PORT_WINDOWS
SpinLock observer_lock_;
#endif
ObserverBase *observer_{nullptr};
friend class PollableFd;
friend class PollableFdInfoUnlock;
void set_observer(ObserverBase *observer) {
#if TD_PORT_WINDOWS
auto lock = observer_lock_.lock();
#endif
CHECK(!observer_);
observer_ = observer;
}
void clear_observer() {
#if TD_PORT_WINDOWS
auto lock = observer_lock_.lock();
#endif
observer_ = nullptr;
}
void notify_observer() {
#if TD_PORT_WINDOWS
auto lock = observer_lock_.lock();
#endif
VLOG(fd) << native_fd() << " notify " << tag("observer", observer_);
if (observer_) {
observer_->notify();
}
}
void unlock() {
lock_.clear(std::memory_order_release);
as_list_node()->remove();
}
ListNode *as_list_node() {
return static_cast<ListNode *>(this);
}
static PollableFdInfo *from_list_node(ListNode *list_node) {
return static_cast<PollableFdInfo *>(list_node);
}
};
inline void PollableFdInfoUnlock::operator()(PollableFdInfo *ptr) {
ptr->unlock();
}
inline ListNode *PollableFd::release_as_list_node() {
return fd_info_.release()->as_list_node();
}
inline PollableFdRef PollableFd::ref() {
return PollableFdRef{fd_info_->as_list_node()};
}
inline PollableFd PollableFd::from_list_node(ListNode *node) {
return PollableFd(std::unique_ptr<PollableFdInfo, PollableFdInfoUnlock>(PollableFdInfo::from_list_node(node)));
}
inline void PollableFd::add_flags(PollFlags flags) {
fd_info_->add_flags_from_poll(flags);
}
inline PollFlags PollableFd::get_flags_unsafe() const {
return fd_info_->get_flags_local();
}
inline const NativeFd &PollableFd::native_fd() const {
return fd_info_->native_fd();
}
#if TD_PORT_POSIX
template <class F>
auto skip_eintr(F &&f) {
decltype(f()) res;
static_assert(std::is_integral<decltype(res)>::value, "integral type expected");
do {
errno = 0; // just in case
res = f();
} while (res < 0 && errno == EINTR);
return res;
}
template <class F>
auto skip_eintr_cstr(F &&f) {
char *res;
do {
errno = 0; // just in case
res = f();
} while (res == nullptr && errno == EINTR);
return res;
}
#endif
template <class FdT>
bool can_read(const FdT &fd) {
return fd.get_poll_info().get_flags().can_read() || fd.get_poll_info().get_flags().has_pending_error();
}
template <class FdT>
bool can_write(const FdT &fd) {
return fd.get_poll_info().get_flags().can_write();
}
template <class FdT>
bool can_close(const FdT &fd) {
return fd.get_poll_info().get_flags().can_close();
}
} // namespace td

View File

@ -29,12 +29,12 @@ void Select::clear() {
fds_.clear(); fds_.clear();
} }
void Select::subscribe(const Fd &fd, Fd::Flags flags) { void Select::subscribe(PollableFd fd, PollFlags flags) {
int native_fd = fd.get_native_fd(); int native_fd = fd.native_fd().fd();
for (auto &it : fds_) { for (auto &it : fds_) {
CHECK(it.fd_ref.get_native_fd() != native_fd); CHECK(it.fd.native_fd().fd() != native_fd);
} }
fds_.push_back(FdInfo{Fd(native_fd, Fd::Mode::Reference), flags}); fds_.push_back(FdInfo{std::move(fd), flags});
CHECK(0 <= native_fd && native_fd < FD_SETSIZE) << native_fd << " " << FD_SETSIZE; CHECK(0 <= native_fd && native_fd < FD_SETSIZE) << native_fd << " " << FD_SETSIZE;
FD_SET(native_fd, &all_fd_); FD_SET(native_fd, &all_fd_);
if (native_fd > max_fd_) { if (native_fd > max_fd_) {
@ -42,8 +42,11 @@ void Select::subscribe(const Fd &fd, Fd::Flags flags) {
} }
} }
void Select::unsubscribe(const Fd &fd) { void Select::unsubscribe(PollableFdRef fd) {
int native_fd = fd.get_native_fd(); auto fd_locked = fd.lock();
int native_fd = fd_locked.native_fd().fd();
fd_locked.release_as_list_node();
CHECK(0 <= native_fd && native_fd < FD_SETSIZE) << native_fd << " " << FD_SETSIZE; CHECK(0 <= native_fd && native_fd < FD_SETSIZE) << native_fd << " " << FD_SETSIZE;
FD_CLR(native_fd, &all_fd_); FD_CLR(native_fd, &all_fd_);
FD_CLR(native_fd, &read_fd_); FD_CLR(native_fd, &read_fd_);
@ -53,7 +56,7 @@ void Select::unsubscribe(const Fd &fd) {
max_fd_--; max_fd_--;
} }
for (auto it = fds_.begin(); it != fds_.end();) { for (auto it = fds_.begin(); it != fds_.end();) {
if (it->fd_ref.get_native_fd() == native_fd) { if (it->fd.native_fd().fd() == native_fd) {
std::swap(*it, fds_.back()); std::swap(*it, fds_.back());
fds_.pop_back(); fds_.pop_back();
break; break;
@ -63,7 +66,7 @@ void Select::unsubscribe(const Fd &fd) {
} }
} }
void Select::unsubscribe_before_close(const Fd &fd) { void Select::unsubscribe_before_close(PollableFdRef fd) {
unsubscribe(fd); unsubscribe(fd);
} }
@ -79,14 +82,14 @@ void Select::run(int timeout_ms) {
} }
for (auto &it : fds_) { for (auto &it : fds_) {
int native_fd = it.fd_ref.get_native_fd(); int native_fd = it.fd.native_fd().fd();
Fd::Flags fd_flags = it.fd_ref.get_flags(); PollFlags fd_flags = it.fd.get_flags_unsafe(); // concurrent calls are UB
if ((it.flags & Fd::Write) && !(fd_flags & Fd::Write)) { if (it.flags.can_write() && !fd_flags.can_write()) {
FD_SET(native_fd, &write_fd_); FD_SET(native_fd, &write_fd_);
} else { } else {
FD_CLR(native_fd, &write_fd_); FD_CLR(native_fd, &write_fd_);
} }
if ((it.flags & Fd::Read) && !(fd_flags & Fd::Read)) { if (it.flags.can_read() && !fd_flags.can_read()) {
FD_SET(native_fd, &read_fd_); FD_SET(native_fd, &read_fd_);
} else { } else {
FD_CLR(native_fd, &read_fd_); FD_CLR(native_fd, &read_fd_);
@ -96,20 +99,18 @@ void Select::run(int timeout_ms) {
select(max_fd_ + 1, &read_fd_, &write_fd_, &except_fd_, timeout_ptr); select(max_fd_ + 1, &read_fd_, &write_fd_, &except_fd_, timeout_ptr);
for (auto &it : fds_) { for (auto &it : fds_) {
int native_fd = it.fd_ref.get_native_fd(); int native_fd = it.fd.native_fd().fd();
Fd::Flags flags = 0; PollFlags flags;
if (FD_ISSET(native_fd, &read_fd_)) { if (FD_ISSET(native_fd, &read_fd_)) {
flags |= Fd::Read; flags = flags | PollFlags::Read();
} }
if (FD_ISSET(native_fd, &write_fd_)) { if (FD_ISSET(native_fd, &write_fd_)) {
flags |= Fd::Write; flags = flags | PollFlags::Write();
} }
if (FD_ISSET(native_fd, &except_fd_)) { if (FD_ISSET(native_fd, &except_fd_)) {
flags |= Fd::Error; flags = flags | PollFlags::Error();
}
if (flags != 0) {
it.fd_ref.update_flags_notify(flags);
} }
it.fd.add_flags(flags);
} }
} }

View File

@ -11,7 +11,6 @@
#ifdef TD_POLL_SELECT #ifdef TD_POLL_SELECT
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/port/Fd.h"
#include "td/utils/port/PollBase.h" #include "td/utils/port/PollBase.h"
#include <sys/select.h> #include <sys/select.h>
@ -32,18 +31,22 @@ class Select final : public PollBase {
void clear() override; void clear() override;
void subscribe(const Fd &fd, Fd::Flags flags) override; void subscribe(PollableFd fd, PollFlags flags) override;
void unsubscribe(const Fd &fd) override; void unsubscribe(PollableFdRef fd) override;
void unsubscribe_before_close(const Fd &fd) override; void unsubscribe_before_close(PollableFdRef fd) override;
void run(int timeout_ms) override; void run(int timeout_ms) override;
static bool is_edge_triggered() {
return false;
}
private: private:
struct FdInfo { struct FdInfo {
Fd fd_ref; PollableFd fd;
Fd::Flags flags; PollFlags flags;
}; };
vector<FdInfo> fds_; vector<FdInfo> fds_;
fd_set all_fd_; fd_set all_fd_;

View File

@ -22,75 +22,93 @@ char disable_linker_warning_about_empty_file_wineventpoll_cpp TD_UNUSED;
namespace td { namespace td {
namespace detail { namespace detail {
IOCP::~IOCP() {
void WineventPoll::init() {
clear(); clear();
} }
void WineventPoll::clear() { void IOCP::loop() {
fds_.clear(); IOCP::Guard guard(this);
} while (true) {
DWORD bytes = 0;
void WineventPoll::subscribe(const Fd &fd, Fd::Flags flags) { ULONG_PTR key = 0;
for (auto &it : fds_) { OVERLAPPED *overlapped = nullptr;
if (it.fd_ref.get_key() == fd.get_key()) { bool ok = GetQueuedCompletionStatus(iocp_handle_.io_handle(), &bytes, &key, &overlapped, 1000);
it.flags = flags; if (bytes || key || overlapped) {
return; LOG(ERROR) << "Got iocp " << bytes << " " << key << " " << overlapped;
} }
} if (ok) {
fds_.push_back({fd.clone(), flags}); auto callback = reinterpret_cast<IOCP::Callback *>(key);
} if (callback == nullptr) {
LOG(ERROR) << "Interrupt IOCP loop";
void WineventPoll::unsubscribe(const Fd &fd) { return;
for (auto it = fds_.begin(); it != fds_.end(); ++it) { }
if (it->fd_ref.get_key() == fd.get_key()) { callback->on_iocp(bytes, overlapped);
std::swap(*it, fds_.back()); } else {
fds_.pop_back(); if (overlapped != nullptr) {
return; auto error = OS_ERROR("from iocp");
} auto callback = reinterpret_cast<IOCP::Callback *>(key);
} CHECK(callback != nullptr);
} callback->on_iocp(std::move(error), overlapped);
void WineventPoll::unsubscribe_before_close(const Fd &fd) {
unsubscribe(fd);
}
void WineventPoll::run(int timeout_ms) {
vector<std::pair<size_t, Fd::Flag>> events_desc;
vector<HANDLE> events;
for (size_t i = 0; i < fds_.size(); i++) {
auto &fd_info = fds_[i];
if (fd_info.flags & Fd::Flag::Write) {
events_desc.emplace_back(i, Fd::Flag::Write);
events.push_back(fd_info.fd_ref.get_write_event());
}
if (fd_info.flags & Fd::Flag::Read) {
events_desc.emplace_back(i, Fd::Flag::Read);
events.push_back(fd_info.fd_ref.get_read_event());
}
}
if (events.empty()) {
usleep_for(timeout_ms * 1000);
return;
}
auto status = WaitForMultipleObjects(narrow_cast<DWORD>(events.size()), events.data(), false, timeout_ms);
if (status == WAIT_FAILED) {
auto error = OS_ERROR("WaitForMultipleObjects failed");
LOG(FATAL) << events.size() << " " << timeout_ms << " " << error;
}
for (size_t i = 0; i < events.size(); i++) {
if (WaitForSingleObject(events[i], 0) == WAIT_OBJECT_0) {
auto &fd = fds_[events_desc[i].first].fd_ref;
if (events_desc[i].second == Fd::Flag::Read) {
fd.on_read_event();
} else {
fd.on_write_event();
} }
} }
} }
} }
void IOCP::interrupt_loop() {
post(0, nullptr, nullptr);
}
void IOCP::init() {
CHECK(!iocp_handle_);
auto res = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0);
if (res == nullptr) {
auto error = OS_ERROR("Iocp creation failed");
LOG(FATAL) << error;
}
iocp_handle_ = NativeFd(res);
}
void IOCP::clear() {
iocp_handle_.close();
}
void IOCP::subscribe(const NativeFd &native_fd, Callback *callback) {
CHECK(iocp_handle_);
auto iocp_handle =
CreateIoCompletionPort(native_fd.io_handle(), iocp_handle_.io_handle(), reinterpret_cast<ULONG_PTR>(callback), 0);
if (iocp_handle == INVALID_HANDLE_VALUE) {
auto error = OS_ERROR("CreateIoCompletionPort");
LOG(FATAL) << error;
}
CHECK(iocp_handle == iocp_handle_.io_handle()) << iocp_handle << " " << iocp_handle_.io_handle();
}
void IOCP::post(size_t size, Callback *callback, OVERLAPPED *overlapped) {
PostQueuedCompletionStatus(iocp_handle_.io_handle(), DWORD(size), reinterpret_cast<ULONG_PTR>(callback), overlapped);
}
void WineventPoll::init() {
}
void WineventPoll::clear() {
}
void WineventPoll::subscribe(PollableFd fd, PollFlags flags) {
fd.release_as_list_node();
}
void WineventPoll::unsubscribe(PollableFdRef fd) {
fd.lock();
}
void WineventPoll::unsubscribe_before_close(PollableFdRef fd) {
unsubscribe(std::move(fd));
}
void WineventPoll::run(int timeout_ms) {
UNREACHABLE();
}
} // namespace detail } // namespace detail
} // namespace td } // namespace td

View File

@ -11,12 +11,41 @@
#ifdef TD_POLL_WINEVENT #ifdef TD_POLL_WINEVENT
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/Context.h"
#include "td/utils/port/Fd.h" #include "td/utils/port/Fd.h"
#include "td/utils/port/PollBase.h" #include "td/utils/port/PollBase.h"
#include "td/utils/port/thread.h"
namespace td { namespace td {
namespace detail { namespace detail {
class IOCP final : public Context<IOCP> {
public:
IOCP() = default;
IOCP(const IOCP &) = delete;
IOCP &operator=(const IOCP &) = delete;
IOCP(IOCP &&) = delete;
IOCP &operator=(IOCP &&) = delete;
~IOCP();
class Callback {
public:
virtual ~Callback() = default;
virtual void on_iocp(Result<size_t> r_size, OVERLAPPED *overlapped) = 0;
};
void init();
void subscribe(const NativeFd &fd, Callback *callback);
void post(size_t size, Callback *callback, OVERLAPPED *overlapped);
void loop();
void interrupt_loop();
void clear();
private:
NativeFd iocp_handle_;
std::vector<td::thread> workers_;
};
class WineventPoll final : public PollBase { class WineventPoll final : public PollBase {
public: public:
WineventPoll() = default; WineventPoll() = default;
@ -30,20 +59,19 @@ class WineventPoll final : public PollBase {
void clear() override; void clear() override;
void subscribe(const Fd &fd, Fd::Flags flags) override; void subscribe(PollableFd fd, PollFlags flags) override;
void unsubscribe(const Fd &fd) override; void unsubscribe(PollableFdRef fd) override;
void unsubscribe_before_close(const Fd &fd) override; void unsubscribe_before_close(PollableFdRef fd) override;
void run(int timeout_ms) override; void run(int timeout_ms) override;
static bool is_edge_triggered() {
return true;
}
private: private:
struct FdInfo {
Fd fd_ref;
Fd::Flags flags;
};
vector<FdInfo> fds_;
}; };
} // namespace detail } // namespace detail

View File

@ -66,6 +66,16 @@ Status mkpath(CSlice path, int32 mode) {
return Status::OK(); return Status::OK();
} }
Status rmrf(CSlice path) {
return walk_path(path, [&](CSlice path, bool is_dir) {
if (is_dir) {
return rmdir(path);
} else {
return unlink(path);
}
});
}
#if TD_PORT_POSIX #if TD_PORT_POSIX
Status mkdir(CSlice dir, int32 mode) { Status mkdir(CSlice dir, int32 mode) {

View File

@ -43,6 +43,7 @@ Result<string> realpath(CSlice slice, bool ignore_access_denied = false) TD_WARN
Status chdir(CSlice dir) TD_WARN_UNUSED_RESULT; Status chdir(CSlice dir) TD_WARN_UNUSED_RESULT;
Status rmdir(CSlice dir) TD_WARN_UNUSED_RESULT; Status rmdir(CSlice dir) TD_WARN_UNUSED_RESULT;
Status unlink(CSlice path) TD_WARN_UNUSED_RESULT; Status unlink(CSlice path) TD_WARN_UNUSED_RESULT;
Status rmrf(CSlice path) TD_WARN_UNUSED_RESULT;
Status set_temporary_dir(CSlice dir) TD_WARN_UNUSED_RESULT; Status set_temporary_dir(CSlice dir) TD_WARN_UNUSED_RESULT;
CSlice get_temporary_dir(); CSlice get_temporary_dir();
Result<std::pair<FileFd, string>> mkstemp(CSlice dir) TD_WARN_UNUSED_RESULT; Result<std::pair<FileFd, string>> mkstemp(CSlice dir) TD_WARN_UNUSED_RESULT;
@ -119,12 +120,13 @@ Status walk_path_dir(string &path, DIR *subdir, Func &&func) {
template <class Func> template <class Func>
Status walk_path_dir(string &path, FileFd fd, Func &&func) { Status walk_path_dir(string &path, FileFd fd, Func &&func) {
auto *subdir = fdopendir(fd.get_fd().move_as_native_fd()); auto native_fd = fd.move_as_native_fd();
auto *subdir = fdopendir(native_fd.fd());
if (subdir == nullptr) { if (subdir == nullptr) {
auto error = OS_ERROR("fdopendir"); auto error = OS_ERROR("fdopendir");
fd.close();
return error; return error;
} }
native_fd.release();
return walk_path_dir(path, subdir, std::forward<Func>(func)); return walk_path_dir(path, subdir, std::forward<Func>(func));
} }

View File

@ -13,11 +13,11 @@ namespace td {
namespace detail { namespace detail {
static TD_THREAD_LOCAL int32 thread_id_; static TD_THREAD_LOCAL int32 thread_id_;
static TD_THREAD_LOCAL std::vector<std::unique_ptr<Guard>> *thread_local_destructors; static TD_THREAD_LOCAL std::vector<std::unique_ptr<Destructor>> *thread_local_destructors;
void add_thread_local_destructor(std::unique_ptr<Guard> destructor) { void add_thread_local_destructor(std::unique_ptr<Destructor> destructor) {
if (thread_local_destructors == nullptr) { if (thread_local_destructors == nullptr) {
thread_local_destructors = new std::vector<std::unique_ptr<Guard>>(); thread_local_destructors = new std::vector<std::unique_ptr<Destructor>>();
} }
thread_local_destructors->push_back(std::move(destructor)); thread_local_destructors->push_back(std::move(destructor));
} }

View File

@ -9,7 +9,7 @@
#include "td/utils/port/config.h" #include "td/utils/port/config.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/ScopeGuard.h" #include "td/utils/Destructor.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -39,18 +39,18 @@ void set_thread_id(int32 id);
int32 get_thread_id(); int32 get_thread_id();
namespace detail { namespace detail {
void add_thread_local_destructor(std::unique_ptr<Guard> destructor); void add_thread_local_destructor(std::unique_ptr<Destructor> destructor);
template <class T, class P, class... ArgsT> template <class T, class P, class... ArgsT>
void do_init_thread_local(P &raw_ptr, ArgsT &&... args) { void do_init_thread_local(P &raw_ptr, ArgsT &&... args) {
auto ptr = std::make_unique<T>(std::forward<ArgsT>(args)...); auto ptr = std::make_unique<T>(std::forward<ArgsT>(args)...);
raw_ptr = ptr.get(); raw_ptr = ptr.get();
detail::add_thread_local_destructor(create_lambda_guard([ptr = std::move(ptr), &raw_ptr]() mutable { detail::add_thread_local_destructor(create_destructor([ptr = std::move(ptr), &raw_ptr]() mutable {
ptr.reset(); ptr.reset();
raw_ptr = nullptr; raw_ptr = nullptr;
})); }));
} } // namespace detail
} // namespace detail } // namespace detail
template <class T, class P, class... ArgsT> template <class T, class P, class... ArgsT>

View File

@ -37,11 +37,6 @@ class TlParser {
public: public:
explicit TlParser(Slice slice) { explicit TlParser(Slice slice) {
if (slice.size() % sizeof(int32) != 0) {
set_error("Wrong length");
return;
}
data_len = left_len = slice.size(); data_len = left_len = slice.size();
if (is_aligned_pointer<4>(slice.begin())) { if (is_aligned_pointer<4>(slice.begin())) {
data = slice.ubegin(); data = slice.ubegin();
@ -51,7 +46,7 @@ class TlParser {
buf = &small_data_array[0]; buf = &small_data_array[0];
} else { } else {
LOG(ERROR) << "Unexpected big unaligned data pointer of length " << slice.size() << " at " << slice.begin(); LOG(ERROR) << "Unexpected big unaligned data pointer of length " << slice.size() << " at " << slice.begin();
data_buf = make_unique<int32[]>(data_len / sizeof(int32)); data_buf = make_unique<int32[]>(1 + data_len / sizeof(int32));
buf = data_buf.get(); buf = data_buf.get();
} }
std::memcpy(static_cast<void *>(buf), static_cast<const void *>(slice.begin()), slice.size()); std::memcpy(static_cast<void *>(buf), static_cast<const void *>(slice.begin()), slice.size());
@ -91,7 +86,8 @@ class TlParser {
} }
int32 fetch_int_unsafe() { int32 fetch_int_unsafe() {
int32 result = *reinterpret_cast<const int32 *>(data); int32 result;
std::memcpy(reinterpret_cast<unsigned char *>(&result), data, sizeof(int32));
data += sizeof(int32); data += sizeof(int32);
return result; return result;
} }
@ -136,7 +132,7 @@ class TlParser {
template <class T> template <class T>
T fetch_binary() { T fetch_binary() {
static_assert(sizeof(T) <= sizeof(empty_data), "too big fetch_binary"); static_assert(sizeof(T) <= sizeof(empty_data), "too big fetch_binary");
static_assert(sizeof(T) % sizeof(int32) == 0, "wrong call to fetch_binary"); //static_assert(sizeof(T) % sizeof(int32) == 0, "wrong call to fetch_binary");
check_len(sizeof(T)); check_len(sizeof(T));
return fetch_binary_unsafe<T>(); return fetch_binary_unsafe<T>();
} }
@ -165,7 +161,7 @@ class TlParser {
template <class T> template <class T>
T fetch_string_raw(const size_t size) { T fetch_string_raw(const size_t size) {
CHECK(size % sizeof(int32) == 0); //CHECK(size % sizeof(int32) == 0);
check_len(size); check_len(size);
const char *result = reinterpret_cast<const char *>(data); const char *result = reinterpret_cast<const char *>(data);
data += size; data += size;

View File

@ -34,8 +34,7 @@ class TlStorerUnsafe {
} }
void store_int(int32 x) { void store_int(int32 x) {
*reinterpret_cast<int32 *>(buf_) = x; store_binary<int32>(x);
buf_ += sizeof(int32);
} }
void store_long(int64 x) { void store_long(int64 x) {

View File

@ -11,12 +11,20 @@ namespace td {
template <class FunctionT> template <class FunctionT>
struct member_function_class; struct member_function_class;
template <class ReturnType, class Type> template <class ReturnType, class Type, class... Args>
struct member_function_class<ReturnType Type::*> { struct member_function_class<ReturnType (Type::*)(Args...)> {
using type = Type; using type = Type;
constexpr static size_t arguments_count() {
return sizeof...(Args);
}
}; };
template <class FunctionT> template <class FunctionT>
using member_function_class_t = typename member_function_class<FunctionT>::type; using member_function_class_t = typename member_function_class<FunctionT>::type;
template <class FunctionT>
constexpr size_t member_function_arguments_count() {
return member_function_class<FunctionT>::arguments_count();
}
} // namespace td } // namespace td

58
tdutils/test/buffer.cpp Normal file
View File

@ -0,0 +1,58 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/tests.h"
#include "td/utils/buffer.h"
using namespace td;
TEST(Buffer, buffer_builder) {
{
BufferBuilder builder;
builder.append("b");
builder.prepend("a");
builder.append("c");
ASSERT_EQ(builder.extract().as_slice(), "abc");
}
{
BufferBuilder builder{"hello", 0, 0};
ASSERT_EQ(builder.extract().as_slice(), "hello");
}
{
BufferBuilder builder{"hello", 1, 1};
builder.prepend("A ");
builder.append(" B");
ASSERT_EQ(builder.extract().as_slice(), "A hello B");
}
//Slice slice(std::string());
{
std::string str = rand_string('a', 'z', 10000);
auto splitted_str = rand_split(str);
int l = Random::fast(0, static_cast<int32>(splitted_str.size() - 1));
int r = l;
BufferBuilder builder(splitted_str[l], 123, 1000);
while (l != 0 || r != static_cast<int32>(splitted_str.size()) - 1) {
if (l == 0 || (Random::fast(0, 1) == 1 && r != static_cast<int32>(splitted_str.size() - 1))) {
r++;
if (Random::fast(0, 1) == 1) {
builder.append(splitted_str[r]);
} else {
builder.append(BufferSlice(splitted_str[r]));
}
} else {
l--;
if (Random::fast(0, 1) == 1) {
builder.prepend(splitted_str[l]);
} else {
builder.prepend(BufferSlice(splitted_str[l]));
}
}
}
ASSERT_EQ(builder.extract().as_slice(), str);
}
}

View File

@ -9,6 +9,7 @@
#include "td/utils/HttpUrl.h" #include "td/utils/HttpUrl.h"
#include "td/utils/logging.h" #include "td/utils/logging.h"
#include "td/utils/misc.h" #include "td/utils/misc.h"
#include "td/utils/invoke.h"
#include "td/utils/port/EventFd.h" #include "td/utils/port/EventFd.h"
#include "td/utils/port/FileFd.h" #include "td/utils/port/FileFd.h"
#include "td/utils/port/IPAddress.h" #include "td/utils/port/IPAddress.h"
@ -125,6 +126,21 @@ TEST(Misc, errno_tls_bug) {
#endif #endif
} }
TEST(Misc, get_last_argument) {
auto a = std::make_unique<int>(5);
ASSERT_EQ(*get_last_argument(std::move(a)), 5);
ASSERT_EQ(*get_last_argument(1, 2, 3, 4, a), 5);
ASSERT_EQ(*get_last_argument(a), 5);
auto b = get_last_argument(1, 2, 3, std::move(a));
ASSERT_TRUE(!a);
ASSERT_EQ(*b, 5);
}
TEST(Misc, call_n_arguments) {
auto f = [](int, int) {};
call_n_arguments<2>(f, 1, 3, 4);
}
TEST(Misc, base64) { TEST(Misc, base64) {
ASSERT_TRUE(is_base64("dGVzdA==") == true); ASSERT_TRUE(is_base64("dGVzdA==") == true);
ASSERT_TRUE(is_base64("dGVzdB==") == false); ASSERT_TRUE(is_base64("dGVzdB==") == false);

42
tdutils/test/port.cpp Normal file
View File

@ -0,0 +1,42 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/tests.h"
#include "td/utils/port/FileFd.h"
#include "td/utils/port/path.h"
using namespace td;
TEST(Port, files) {
CSlice main_dir = "test_dir";
rmrf(main_dir).ignore();
mkdir(main_dir).ensure();
mkdir(PSLICE() << main_dir << TD_DIR_SLASH << "A").ensure();
mkdir(PSLICE() << main_dir << TD_DIR_SLASH << "B").ensure();
mkdir(PSLICE() << main_dir << TD_DIR_SLASH << "C").ensure();
ASSERT_TRUE(FileFd::open(main_dir, FileFd::Write).is_error());
std::string fd_path = PSTRING() << main_dir << TD_DIR_SLASH << "t.txt";
auto fd = FileFd::open(fd_path, FileFd::Write | FileFd::CreateNew).move_as_ok();
ASSERT_EQ(0u, fd.get_size());
ASSERT_EQ(12u, fd.write("Hello world!").move_as_ok());
ASSERT_EQ(4u, fd.pwrite("abcd", 1).move_as_ok());
char buf[100];
MutableSlice buf_slice(buf, sizeof(buf));
ASSERT_TRUE(fd.pread(buf_slice.substr(0, 4), 2).is_error());
fd.seek(11).ensure();
ASSERT_EQ(2u, fd.write("?!").move_as_ok());
ASSERT_TRUE(FileFd::open(main_dir, FileFd::Read | FileFd::CreateNew).is_error());
fd = FileFd::open(fd_path, FileFd::Read | FileFd::Create).move_as_ok();
ASSERT_EQ(13u, fd.get_size());
ASSERT_EQ(4u, fd.pread(buf_slice.substr(0, 4), 1).move_as_ok());
ASSERT_STREQ("abcd", buf_slice.substr(0, 4));
fd.seek(0).ensure();
ASSERT_EQ(13u, fd.read(buf_slice.substr(0, 13)).move_as_ok());
ASSERT_STREQ("Habcd world?!", buf_slice.substr(0, 13));
}