diff --git a/td/telegram/Td.cpp b/td/telegram/Td.cpp index 6ac071583..c044b28da 100644 --- a/td/telegram/Td.cpp +++ b/td/telegram/Td.cpp @@ -3413,7 +3413,7 @@ void Td::request(uint64 id, tl_object_ptr function) { request_set_.insert(id); if (function == nullptr) { LOG(ERROR) << "Receive empty request"; - return send_error_raw(id, 400, "Request is empty"); + return send_error_impl(id, make_error(400, "Request is empty")); } VLOG(td_requests) << "Receive request " << id << ": " << to_string(function); @@ -3457,7 +3457,8 @@ void Td::request(uint64 id, tl_object_ptr function) { pending_preauthentication_requests_.emplace_back(id, std::move(function)); return; } - return send_error_raw(id, 400, "Initialization parameters are needed: call setTdlibParameters first"); + return send_error_impl( + id, make_error(400, "Initialization parameters are needed: call setTdlibParameters first")); } break; } @@ -3486,15 +3487,16 @@ void Td::request(uint64 id, tl_object_ptr function) { pending_preauthentication_requests_.emplace_back(id, std::move(function)); return; } - return send_error_raw(id, 400, "Database encryption key is needed: call checkDatabaseEncryptionKey first"); + return send_error_impl( + id, make_error(400, "Database encryption key is needed: call checkDatabaseEncryptionKey first")); } return answer_ok_query(id, init(as_db_key(encryption_key))); } case State::Close: if (destroy_flag_) { - return send_error_raw(id, 401, "Unauthorized"); + return send_error_impl(id, make_error(401, "Unauthorized")); } else { - return send_error_raw(id, 500, "Request aborted"); + return send_error_impl(id, make_error(500, "Request aborted")); } case State::Run: break; @@ -3502,7 +3504,7 @@ void Td::request(uint64 id, tl_object_ptr function) { if ((auth_manager_ == nullptr || !auth_manager_->is_authorized()) && !is_preauthentication_request(function_id) && !is_preinitialization_request(function_id) && !is_authentication_request(function_id)) { - return send_error_raw(id, 401, "Unauthorized"); + return send_error_impl(id, make_error(401, "Unauthorized")); } downcast_call(*function, [this, id](auto &request) { this->on_request(id, request); }); } diff --git a/tdnet/td/net/SslStream.cpp b/tdnet/td/net/SslStream.cpp index 00d3cb0a4..5f128c015 100644 --- a/tdnet/td/net/SslStream.cpp +++ b/tdnet/td/net/SslStream.cpp @@ -545,7 +545,8 @@ SslStream::SslStream(SslStream &&) = default; SslStream &SslStream::operator=(SslStream &&) = default; SslStream::~SslStream() = default; -Result SslStream::create(CSlice host, CSlice cert_file, VerifyPeer verify_peer) { +Result SslStream::create(CSlice host, CSlice cert_file, VerifyPeer verify_peer, + bool check_ip_address_as_host) { return Status::Error("Not supported in emscripten"); } diff --git a/test/tdclient.cpp b/test/tdclient.cpp index ac9257093..6ab0d533d 100644 --- a/test/tdclient.cpp +++ b/test/tdclient.cpp @@ -23,6 +23,7 @@ #include "td/utils/misc.h" #include "td/utils/port/FileFd.h" #include "td/utils/port/path.h" +#include "td/utils/port/sleep.h" #include "td/utils/port/thread.h" #include "td/utils/Random.h" #include "td/utils/Slice.h" @@ -949,6 +950,123 @@ TEST(Client, Manager) { } } } + +TEST(Client, Close) { + std::atomic stop_send{false}; + std::atomic can_stop_receive{false}; + std::atomic send_count{1}; + std::atomic receive_count{0}; + td::Client client; + + std::mutex request_ids_mutex; + std::set request_ids; + request_ids.insert(1); + td::thread send_thread([&] { + td::uint64 request_id = 2; + while (!stop_send.load()) { + { + std::unique_lock guard(request_ids_mutex); + request_ids.insert(request_id); + } + client.send({request_id++, td::make_tl_object(3)}); + send_count++; + } + can_stop_receive = true; + }); + + auto max_continue_send = td::Random::fast(0, 1) ? 0 : 1000; + td::thread receive_thread([&] { + while (true) { + auto response = client.receive(100.0); + if (stop_send && response.object == nullptr) { + return; + } + if (response.id > 0) { + if (!stop_send && response.object->get_id() == td::td_api::error::ID && + static_cast(*response.object).code_ == 500 && + td::Random::fast(0, max_continue_send) == 0) { + stop_send = true; + } + receive_count++; + { + std::unique_lock guard(request_ids_mutex); + size_t erase_count = request_ids.erase(response.id); + CHECK(erase_count > 0); + } + } + if (can_stop_receive && receive_count == send_count) { + break; + } + } + }); + + td::usleep_for((td::Random::fast(0, 1) ? 0 : 1000) * (td::Random::fast(0, 1) ? 1 : 50)); + client.send({1, td::make_tl_object()}); + + send_thread.join(); + receive_thread.join(); + ASSERT_EQ(send_count.load(), receive_count.load()); + ASSERT_TRUE(request_ids.empty()); +} + +TEST(Client, ManagerClose) { + std::atomic stop_send{false}; + std::atomic can_stop_receive{false}; + std::atomic send_count{1}; + std::atomic receive_count{0}; + td::ClientManager client_manager; + auto client_id = client_manager.create_client(); + + std::mutex request_ids_mutex; + std::set request_ids; + request_ids.insert(1); + td::thread send_thread([&] { + td::uint64 request_id = 2; + while (!stop_send.load()) { + { + std::unique_lock guard(request_ids_mutex); + request_ids.insert(request_id); + } + client_manager.send(client_id, request_id++, td::make_tl_object(3)); + send_count++; + } + can_stop_receive = true; + }); + + auto max_continue_send = td::Random::fast(0, 1) ? 0 : 1000; + td::thread receive_thread([&] { + while (true) { + auto response = client_manager.receive(100.0); + if (stop_send && response.object == nullptr) { + return; + } + if (response.request_id > 0) { + if (!stop_send && response.object->get_id() == td::td_api::error::ID && + static_cast(*response.object).code_ == 400 && + td::Random::fast(0, max_continue_send) == 0) { + stop_send = true; + } + receive_count++; + { + std::unique_lock guard(request_ids_mutex); + size_t erase_count = request_ids.erase(response.request_id); + CHECK(erase_count > 0); + } + } + if (can_stop_receive && receive_count == send_count) { + break; + } + } + }); + + td::usleep_for((td::Random::fast(0, 1) ? 0 : 1000) * (td::Random::fast(0, 1) ? 1 : 50)); + client_manager.send(client_id, 1, td::make_tl_object()); + + send_thread.join(); + receive_thread.join(); + ASSERT_EQ(send_count.load(), receive_count.load()); + ASSERT_TRUE(request_ids.empty()); +} #endif TEST(PartsManager, hands) {