// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/utils/port/StdStreams.h" #include "td/utils/logging.h" #include "td/utils/misc.h" #include "td/utils/port/detail/Iocp.h" #include "td/utils/port/detail/NativeFd.h" #include "td/utils/port/detail/PollableFd.h" #include "td/utils/port/PollFlags.h" #include "td/utils/port/thread.h" #include "td/utils/ScopeGuard.h" #include "td/utils/Slice.h" #include "td/utils/SliceBuilder.h" #include <atomic> namespace td { #if TD_PORT_POSIX template <int id> static FileFd &get_file_fd() { static FileFd result = FileFd::from_native_fd(NativeFd(id, true)); static auto guard = ScopeExit() + [&] { result.move_as_native_fd().release(); }; return result; } FileFd &Stdin() { return get_file_fd<0>(); } FileFd &Stdout() { return get_file_fd<1>(); } FileFd &Stderr() { return get_file_fd<2>(); } #elif TD_PORT_WINDOWS template <DWORD id> static FileFd &get_file_fd() { #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM) static auto handle = GetStdHandle(id); LOG_IF(FATAL, handle == INVALID_HANDLE_VALUE) << "Failed to GetStdHandle " << id; static FileFd result = handle == nullptr ? FileFd() : FileFd::from_native_fd(NativeFd(handle, true)); static auto guard = ScopeExit() + [&] { if (handle != nullptr) { result.move_as_native_fd().release(); } }; #else static FileFd result; #endif return result; } FileFd &Stdin() { return get_file_fd<STD_INPUT_HANDLE>(); } FileFd &Stdout() { return get_file_fd<STD_OUTPUT_HANDLE>(); } FileFd &Stderr() { return get_file_fd<STD_ERROR_HANDLE>(); } #endif #if TD_PORT_WINDOWS namespace detail { class BufferedStdinImpl final : private Iocp::Callback { public: #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM) BufferedStdinImpl() : info_(NativeFd(GetStdHandle(STD_INPUT_HANDLE), true)) { iocp_ref_ = Iocp::get()->get_ref(); read_thread_ = thread([this] { this->read_loop(); }); } #else BufferedStdinImpl() { close(); } #endif BufferedStdinImpl(const BufferedStdinImpl &) = delete; BufferedStdinImpl &operator=(const BufferedStdinImpl &) = delete; BufferedStdinImpl(BufferedStdinImpl &&) = delete; BufferedStdinImpl &operator=(BufferedStdinImpl &&) = delete; ~BufferedStdinImpl() { info_.move_as_native_fd().release(); } void close() { close_flag_ = true; } ChainBufferReader &input_buffer() { return reader_; } PollableFdInfo &get_poll_info() { return info_; } const PollableFdInfo &get_poll_info() const { return info_; } Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT { info_.sync_with_poll(); info_.clear_flags(PollFlags::Read()); reader_.sync_with_writer(); return reader_.size(); } private: PollableFdInfo info_; ChainBufferWriter writer_; ChainBufferReader reader_ = writer_.extract_reader(); thread read_thread_; std::atomic<bool> close_flag_{false}; IocpRef iocp_ref_; std::atomic<int> refcnt_{1}; void read_loop() { while (!close_flag_) { auto slice = writer_.prepare_append(); auto r_size = read(slice); if (r_size.is_error()) { LOG(ERROR) << "Stop read stdin loop: " << r_size.error(); break; } writer_.confirm_append(r_size.ok()); inc_refcnt(); if (!iocp_ref_.post(0, this, nullptr)) { dec_refcnt(); } } if (!iocp_ref_.post(0, this, nullptr)) { read_thread_.detach(); dec_refcnt(); } } void on_iocp(Result<size_t> r_size, WSAOVERLAPPED *overlapped) final { info_.add_flags_from_poll(PollFlags::Read()); dec_refcnt(); } bool dec_refcnt() { if (--refcnt_ == 0) { delete this; return true; } return false; } void inc_refcnt() { CHECK(refcnt_ != 0); refcnt_++; } Result<size_t> read(MutableSlice slice) { auto native_fd = info_.native_fd().fd(); DWORD bytes_read = 0; auto res = ReadFile(native_fd, slice.data(), narrow_cast<DWORD>(slice.size()), &bytes_read, nullptr); if (res) { return static_cast<size_t>(bytes_read); } return OS_ERROR(PSLICE() << "Read from " << info_.native_fd() << " has failed"); } }; void BufferedStdinImplDeleter::operator()(BufferedStdinImpl *impl) { // LOG(ERROR) << "Close"; impl->close(); } } // namespace detail #elif TD_PORT_POSIX namespace detail { class BufferedStdinImpl { public: BufferedStdinImpl() { file_fd_ = FileFd::from_native_fd(NativeFd(Stdin().get_native_fd().fd())); file_fd_.get_native_fd().set_is_blocking(false); } BufferedStdinImpl(const BufferedStdinImpl &) = delete; BufferedStdinImpl &operator=(const BufferedStdinImpl &) = delete; BufferedStdinImpl(BufferedStdinImpl &&) = delete; BufferedStdinImpl &operator=(BufferedStdinImpl &&) = delete; ~BufferedStdinImpl() { file_fd_.get_native_fd().set_is_blocking(true); file_fd_.move_as_native_fd().release(); } ChainBufferReader &input_buffer() { return reader_; } PollableFdInfo &get_poll_info() { return file_fd_.get_poll_info(); } const PollableFdInfo &get_poll_info() const { return file_fd_.get_poll_info(); } Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT { size_t result = 0; ::td::sync_with_poll(*this); while (::td::can_read_local(*this) && max_read) { MutableSlice slice = writer_.prepare_append(); slice.truncate(max_read); TRY_RESULT(x, file_fd_.read(slice)); slice.truncate(x); writer_.confirm_append(x); result += x; max_read -= x; } if (result) { reader_.sync_with_writer(); } return result; } private: FileFd file_fd_; ChainBufferWriter writer_; ChainBufferReader reader_ = writer_.extract_reader(); }; void BufferedStdinImplDeleter::operator()(BufferedStdinImpl *impl) { delete impl; } } // namespace detail #endif BufferedStdin::BufferedStdin() : impl_(make_unique<detail::BufferedStdinImpl>().release()) { } BufferedStdin::BufferedStdin(BufferedStdin &&) noexcept = default; BufferedStdin &BufferedStdin::operator=(BufferedStdin &&) noexcept = default; BufferedStdin::~BufferedStdin() = default; ChainBufferReader &BufferedStdin::input_buffer() { return impl_->input_buffer(); } PollableFdInfo &BufferedStdin::get_poll_info() { return impl_->get_poll_info(); } const PollableFdInfo &BufferedStdin::get_poll_info() const { return impl_->get_poll_info(); } Result<size_t> BufferedStdin::flush_read(size_t max_read) { return impl_->flush_read(max_read); } } // namespace td