// // Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2023 // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #include "td/mtproto/RawConnection.h" #include "td/mtproto/AuthKey.h" #include "td/mtproto/IStreamTransport.h" #include "td/mtproto/ProxySecret.h" #include "td/mtproto/Transport.h" #if TD_DARWIN_WATCH_OS #include "td/net/DarwinHttp.h" #endif #include "td/utils/FlatHashMap.h" #include "td/utils/format.h" #include "td/utils/logging.h" #include "td/utils/misc.h" #include "td/utils/MpscPollableQueue.h" #include "td/utils/port/EventFd.h" #include "td/utils/Slice.h" #include "td/utils/SliceBuilder.h" #include "td/utils/Status.h" #include "td/utils/StorerBase.h" #include #include namespace td { namespace mtproto { class RawConnectionDefault final : public RawConnection { public: RawConnectionDefault(BufferedFd buffered_socket_fd, TransportType transport_type, unique_ptr stats_callback) : socket_fd_(std::move(buffered_socket_fd)) , transport_(create_transport(std::move(transport_type))) , stats_callback_(std::move(stats_callback)) { transport_->init(&socket_fd_.input_buffer(), &socket_fd_.output_buffer()); } void set_connection_token(ConnectionManager::ConnectionToken connection_token) final { connection_token_ = std::move(connection_token); } bool can_send() const final { return transport_->can_write(); } TransportType get_transport_type() const final { return transport_->get_type(); } size_t send_crypto(const Storer &storer, int64 session_id, int64 salt, const AuthKey &auth_key, uint64 quick_ack_token) final { PacketInfo info; info.version = 2; info.no_crypto_flag = false; info.salt = salt; info.session_id = session_id; info.use_random_padding = transport_->use_random_padding(); auto packet = BufferWriter{Transport::write(storer, auth_key, &info), transport_->max_prepend_size(), transport_->max_append_size()}; Transport::write(storer, auth_key, &info, packet.as_mutable_slice()); bool use_quick_ack = false; if (quick_ack_token != 0 && transport_->support_quick_ack()) { CHECK(info.message_ack & (1u << 31)); auto tmp = quick_ack_to_token_.emplace(info.message_ack, quick_ack_token); if (tmp.second) { use_quick_ack = true; } else { LOG(ERROR) << "Quick ack " << info.message_ack << " collision"; } } auto packet_size = packet.size(); transport_->write(std::move(packet), use_quick_ack); return packet_size; } uint64 send_no_crypto(const Storer &storer) final { PacketInfo info; info.no_crypto_flag = true; auto packet = BufferWriter{Transport::write(storer, AuthKey(), &info), transport_->max_prepend_size(), transport_->max_append_size()}; Transport::write(storer, AuthKey(), &info, packet.as_mutable_slice()); LOG(INFO) << "Send handshake packet: " << format::as_hex_dump<4>(packet.as_slice()); transport_->write(std::move(packet), false); return info.message_id; } PollableFdInfo &get_poll_info() final { return socket_fd_.get_poll_info(); } StatsCallback *stats_callback() final { return stats_callback_.get(); } // NB: After first returned error, all subsequent calls will return error too. Status flush(const AuthKey &auth_key, Callback &callback) final { auto status = do_flush(auth_key, callback); if (status.is_error()) { if (stats_callback_ && status.code() != 2) { stats_callback_->on_error(); } has_error_ = true; } return status; } bool has_error() const final { return has_error_; } void close() final { transport_.reset(); socket_fd_.close(); } PublicFields &extra() final { return extra_; } const PublicFields &extra() const final { return extra_; } private: PublicFields extra_; BufferedFd socket_fd_; unique_ptr transport_; FlatHashMap quick_ack_to_token_; bool has_error_{false}; unique_ptr stats_callback_; ConnectionManager::ConnectionToken connection_token_; Status flush_read(const AuthKey &auth_key, Callback &callback) { auto r = socket_fd_.flush_read(); if (r.is_ok()) { if (stats_callback_) { stats_callback_->on_read(r.ok()); } callback.on_read(r.ok()); } while (transport_->can_read()) { BufferSlice packet; uint32 quick_ack = 0; TRY_RESULT(wait_size, transport_->read_next(&packet, &quick_ack)); if (!is_aligned_pointer<4>(packet.as_slice().ubegin())) { BufferSlice new_packet(packet.size()); new_packet.as_mutable_slice().copy_from(packet.as_slice()); packet = std::move(new_packet); } LOG_CHECK(is_aligned_pointer<4>(packet.as_slice().ubegin())) << packet.as_slice().ubegin() << ' ' << packet.size() << ' ' << wait_size; if (wait_size != 0) { constexpr size_t MAX_PACKET_SIZE = (1 << 22) + 1024; if (wait_size > MAX_PACKET_SIZE) { return Status::Error(PSLICE() << "Expected packet size is too big: " << wait_size); } break; } if (quick_ack != 0) { TRY_STATUS(on_quick_ack(quick_ack, callback)); continue; } PacketInfo info; info.version = 2; TRY_RESULT(read_result, Transport::read(packet.as_mutable_slice(), auth_key, &info)); switch (read_result.type()) { case Transport::ReadResult::Quickack: { TRY_STATUS(on_quick_ack(read_result.quick_ack(), callback)); break; } case Transport::ReadResult::Error: { TRY_STATUS(on_read_mtproto_error(read_result.error())); break; } case Transport::ReadResult::Packet: { // If a packet was successfully decrypted, then it is ok to assume that the connection is alive if (!auth_key.empty()) { if (stats_callback_) { stats_callback_->on_pong(); } } TRY_STATUS(callback.on_raw_packet(info, packet.from_slice(read_result.packet()))); break; } case Transport::ReadResult::Nop: break; default: UNREACHABLE(); } } TRY_STATUS(std::move(r)); return Status::OK(); } Status on_read_mtproto_error(int32 error_code) { if (error_code == -429) { if (stats_callback_) { stats_callback_->on_mtproto_error(); } return Status::Error(500, PSLICE() << "MTProto error: " << error_code); } if (error_code == -404) { return Status::Error(-404, PSLICE() << "MTProto error: " << error_code); } return Status::Error(PSLICE() << "MTProto error: " << error_code); } Status on_quick_ack(uint32 quick_ack, Callback &callback) { if ((quick_ack & (1u << 31)) == 0) { LOG(ERROR) << "Receive invalid quick_ack " << quick_ack; return Status::OK(); } auto it = quick_ack_to_token_.find(quick_ack); if (it == quick_ack_to_token_.end()) { LOG(WARNING) << "Receive unknown quick_ack " << quick_ack; return Status::OK(); } auto token = it->second; quick_ack_to_token_.erase(it); callback.on_quick_ack(token).ignore(); return Status::OK(); } Status flush_write() { TRY_RESULT(size, socket_fd_.flush_write()); if (size > 0 && stats_callback_) { stats_callback_->on_write(size); } return Status::OK(); } Status do_flush(const AuthKey &auth_key, Callback &callback) TD_WARN_UNUSED_RESULT { if (has_error_) { return Status::Error("Connection has already failed"); } sync_with_poll(socket_fd_); // read/write // EINVAL can be returned in Linux kernel < 2.6.28. And on some new kernels too. // just close connection and hope that read or write will not return this error too. TRY_STATUS(socket_fd_.get_pending_error()); TRY_STATUS(flush_read(auth_key, callback)); TRY_STATUS(callback.before_write()); TRY_STATUS(flush_write()); if (can_close_local(socket_fd_)) { return Status::Error("Connection closed"); } return Status::OK(); } }; #if TD_DARWIN_WATCH_OS class RawConnectionHttp final : public RawConnection { public: RawConnectionHttp(IPAddress ip_address, unique_ptr stats_callback) : ip_address_(std::move(ip_address)), stats_callback_(std::move(stats_callback)) { answers_ = std::make_shared>>(); answers_->init(); } void set_connection_token(ConnectionManager::ConnectionToken connection_token) final { connection_token_ = std::move(connection_token); } bool can_send() const final { return mode_ == Send; } TransportType get_transport_type() const final { return mtproto::TransportType{mtproto::TransportType::Http, 0, mtproto::ProxySecret()}; } size_t send_crypto(const Storer &storer, int64 session_id, int64 salt, const AuthKey &auth_key, uint64 quick_ack_token) final { PacketInfo info; info.version = 2; info.no_crypto_flag = false; info.salt = salt; info.session_id = session_id; info.use_random_padding = false; auto packet = BufferWriter{Transport::write(storer, auth_key, &info), 0, 0}; Transport::write(storer, auth_key, &info, packet.as_slice()); auto packet_size = packet.size(); send_packet(packet.as_buffer_slice()); return packet_size; } uint64 send_no_crypto(const Storer &storer) final { PacketInfo info; info.no_crypto_flag = true; auto packet = BufferWriter{Transport::write(storer, AuthKey(), &info), 0, 0}; Transport::write(storer, AuthKey(), &info, packet.as_slice()); LOG(INFO) << "Send handshake packet: " << format::as_hex_dump<4>(packet.as_slice()); send_packet(packet.as_buffer_slice()); return info.message_id; } PollableFdInfo &get_poll_info() final { return answers_->reader_get_event_fd().get_poll_info(); } StatsCallback *stats_callback() final { return stats_callback_.get(); } // NB: After first returned error, all subsequent calls will return error too. Status flush(const AuthKey &auth_key, Callback &callback) final { auto status = do_flush(auth_key, callback); if (status.is_error()) { if (stats_callback_ && status.code() != 2) { stats_callback_->on_error(); } has_error_ = true; } return status; } bool has_error() const final { return has_error_; } void close() final { } PublicFields &extra() final { return extra_; } const PublicFields &extra() const final { return extra_; } private: PublicFields extra_; IPAddress ip_address_; bool has_error_{false}; EventFd event_fd_; enum Mode { Send, Receive } mode_{Send}; unique_ptr stats_callback_; ConnectionManager::ConnectionToken connection_token_; std::shared_ptr>> answers_; std::vector to_send_; void send_packet(BufferSlice packet) { CHECK(mode_ == Send); mode_ = Receive; to_send_.push_back(std::move(packet)); } Status flush_read(const AuthKey &auth_key, Callback &callback) { while (true) { auto packets_n = answers_->reader_wait_nonblock(); if (packets_n == 0) { break; } for (int i = 0; i < packets_n; i++) { TRY_RESULT(packet, answers_->reader_get_unsafe()); if (stats_callback_) { stats_callback_->on_read(packet.size()); } callback.on_read(packet.size()); CHECK(mode_ == Receive); mode_ = Send; PacketInfo info; info.version = 2; TRY_RESULT(read_result, Transport::read(packet.as_mutable_slice(), auth_key, &info)); switch (read_result.type()) { case Transport::ReadResult::Quickack: { break; } case Transport::ReadResult::Error: { TRY_STATUS(on_read_mtproto_error(read_result.error())); break; } case Transport::ReadResult::Packet: { // If a packet was successfully decrypted, then it is ok to assume that the connection is alive if (!auth_key.empty()) { if (stats_callback_) { stats_callback_->on_pong(); } } TRY_STATUS(callback.on_raw_packet(info, packet.from_slice(read_result.packet()))); break; } case Transport::ReadResult::Nop: break; default: UNREACHABLE(); } } } return Status::OK(); } Status on_read_mtproto_error(int32 error_code) { if (error_code == -429) { if (stats_callback_) { stats_callback_->on_mtproto_error(); } return Status::Error(500, PSLICE() << "MTProto error: " << error_code); } if (error_code == -404) { return Status::Error(-404, PSLICE() << "MTProto error: " << error_code); } return Status::Error(PSLICE() << "MTProto error: " << error_code); } Status flush_write() { for (auto &packet : to_send_) { TRY_STATUS(do_send(packet.as_slice())); if (packet.size() > 0 && stats_callback_) { stats_callback_->on_write(packet.size()); } } to_send_.clear(); return Status::OK(); } Status do_send(Slice data) { DarwinHttp::post(PSLICE() << "http://" << ip_address_.get_ip_str() << ":" << ip_address_.get_port() << "/api", data, [answers = answers_](auto res) { answers->writer_put(std::move(res)); }); return Status::OK(); } Status do_flush(const AuthKey &auth_key, Callback &callback) TD_WARN_UNUSED_RESULT { if (has_error_) { return Status::Error("Connection has already failed"); } TRY_STATUS(flush_read(auth_key, callback)); TRY_STATUS(callback.before_write()); TRY_STATUS(flush_write()); return Status::OK(); } }; #endif unique_ptr RawConnection::create(IPAddress ip_address, BufferedFd buffered_socket_fd, TransportType transport_type, unique_ptr stats_callback) { #if TD_DARWIN_WATCH_OS return td::make_unique(std::move(ip_address), std::move(stats_callback)); #else return td::make_unique(std::move(buffered_socket_fd), std::move(transport_type), std::move(stats_callback)); #endif } } // namespace mtproto } // namespace td