Client: share scheduler between different clients

GitOrigin-RevId: 6bddeaf1938a1bb35dc9a7670c10b044419089fe
This commit is contained in:
Arseny Smirnov 2019-03-21 22:59:20 +13:00
parent 75f323e2f6
commit 18900e9d69
3 changed files with 151 additions and 28 deletions

View File

@ -106,16 +106,103 @@ class Client::Impl final {
#else
using OutputQueue = MpscPollableQueue<Client::Response>;
class MultiTd : public Actor {
public:
void create(int td_id, unique_ptr<TdCallback> callback) {
auto &td = tds_[td_id];
CHECK(td.empty());
string name = "Td";
if (td_id != 0) {
name += PSTRING() << "#" << td_id;
}
td = create_actor<Td>(name, std::move(callback));
}
void send(int td_id, Client::Request request) {
auto &td = tds_[td_id];
CHECK(!td.empty());
send_closure(td, &Td::request, request.id, std::move(request.function));
}
void destroy(int td_id) {
auto size = tds_.erase(td_id);
CHECK(size == 1);
}
private:
std::unordered_map<int, ActorOwn<Td> > tds_;
};
class MultiImpl {
public:
static std::shared_ptr<MultiImpl> get() {
static std::mutex mutex;
static std::weak_ptr<MultiImpl> impl;
std::unique_lock<std::mutex> lock(mutex);
auto res = impl.lock();
if (!res) {
res = std::make_shared<MultiImpl>();
impl = res;
}
return res;
}
MultiImpl() {
concurrent_scheduler_ = std::make_shared<ConcurrentScheduler>();
concurrent_scheduler_->init(3);
concurrent_scheduler_->start();
{
auto guard = concurrent_scheduler_->get_main_guard();
multi_td_ = create_actor<MultiTd>("MultiTd");
}
scheduler_thread_ = thread([concurrent_scheduler = concurrent_scheduler_] {
while (concurrent_scheduler->run_main(10)) {
}
concurrent_scheduler->finish();
});
}
int32 create_id() {
return id_.fetch_add(1) + 1;
}
void create(int32 td_id, td::unique_ptr<TdCallback> callback) {
auto guard = concurrent_scheduler_->get_send_guard();
send_closure(multi_td_, &MultiTd::create, td_id, std::move(callback));
}
void send(int32 td_id, Client::Request request) {
auto guard = concurrent_scheduler_->get_send_guard();
send_closure(multi_td_, &MultiTd::send, td_id, std::move(request));
}
void destroy(int32 td_id) {
auto guard = concurrent_scheduler_->get_send_guard();
send_closure(multi_td_, &MultiTd::destroy, td_id);
}
~MultiImpl() {
{
auto guard = concurrent_scheduler_->get_send_guard();
multi_td_.reset();
Scheduler::instance()->finish();
}
scheduler_thread_.join();
}
private:
std::shared_ptr<ConcurrentScheduler> concurrent_scheduler_;
td::thread scheduler_thread_;
td::ActorOwn<MultiTd> multi_td_;
std::atomic<int32> id_{0};
};
/*** Client::Impl ***/
class Client::Impl final {
public:
using OutputQueue = MpscPollableQueue<Client::Response>;
Impl() {
multi_impl_ = MultiImpl::get();
td_id_ = multi_impl_->create_id();
output_queue_ = std::make_shared<OutputQueue>();
output_queue_->init();
concurrent_scheduler_ = std::make_shared<ConcurrentScheduler>();
concurrent_scheduler_->init(3);
class Callback : public TdCallback {
public:
explicit Callback(std::shared_ptr<OutputQueue> output_queue) : output_queue_(std::move(output_queue)) {
@ -131,33 +218,26 @@ class Client::Impl final {
Callback(Callback &&) = delete;
Callback &operator=(Callback &&) = delete;
~Callback() override {
Scheduler::instance()->finish();
output_queue_->writer_put({0, nullptr});
}
private:
std::shared_ptr<OutputQueue> output_queue_;
};
td_ = concurrent_scheduler_->create_actor_unsafe<Td>(0, "Td", td::make_unique<Callback>(output_queue_));
concurrent_scheduler_->start();
scheduler_thread_ = thread([concurrent_scheduler = concurrent_scheduler_] {
while (concurrent_scheduler->run_main(10)) {
}
concurrent_scheduler->finish();
});
multi_impl_->create(td_id_, td::make_unique<Callback>(output_queue_));
}
void send(Request request) {
void send(Client::Request request) {
if (request.id == 0 || request.function == nullptr) {
LOG(ERROR) << "Drop wrong request " << request.id;
return;
}
auto guard = concurrent_scheduler_->get_send_guard();
send_closure(td_, &Td::request, request.id, std::move(request.function));
multi_impl_->send(td_id_, std::move(request));
}
Response receive(double timeout) {
Client::Response receive(double timeout) {
VLOG(td_requests) << "Begin to wait for updates with timeout " << timeout;
auto is_locked = receive_lock_.exchange(true);
CHECK(!is_locked);
@ -173,26 +253,32 @@ class Client::Impl final {
Impl(Impl &&) = delete;
Impl &operator=(Impl &&) = delete;
~Impl() {
auto guard = concurrent_scheduler_->get_send_guard();
td_.reset();
scheduler_thread_.join();
multi_impl_->destroy(td_id_);
while (!is_closed_) {
receive(10);
}
}
private:
std::shared_ptr<OutputQueue> output_queue_;
std::shared_ptr<ConcurrentScheduler> concurrent_scheduler_;
int output_queue_ready_cnt_{0};
thread scheduler_thread_;
std::atomic<bool> receive_lock_{false};
ActorOwn<Td> td_;
std::shared_ptr<MultiImpl> multi_impl_;
Response receive_unlocked(double timeout) {
std::shared_ptr<OutputQueue> output_queue_;
int output_queue_ready_cnt_{0};
std::atomic<bool> receive_lock_{false};
bool is_closed_{false};
int32 td_id_;
Client::Response receive_unlocked(double timeout) {
if (output_queue_ready_cnt_ == 0) {
output_queue_ready_cnt_ = output_queue_->reader_wait_nonblock();
}
if (output_queue_ready_cnt_ > 0) {
output_queue_ready_cnt_--;
return output_queue_->reader_get_unsafe();
auto res = output_queue_->reader_get_unsafe();
if (!res.object) {
is_closed_ = true;
}
return res;
}
if (timeout != 0) {
output_queue_->reader_get_event_fd().wait(static_cast<int>(timeout * 1000));

View File

@ -27,7 +27,7 @@ set(TESTS_MAIN
add_library(all_tests STATIC ${TD_TEST_SOURCE})
target_include_directories(all_tests PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>)
target_link_libraries(all_tests PRIVATE tdactor tddb tdcore tdnet tdutils)
target_link_libraries(all_tests PRIVATE tdactor tddb tdcore tdnet tdutils tdclient)
if (NOT CMAKE_CROSSCOMPILING OR EMSCRIPTEN)
#Tests
@ -37,7 +37,7 @@ if (NOT CMAKE_CROSSCOMPILING OR EMSCRIPTEN)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=undefined -fno-sanitize=vptr")
endif()
target_include_directories(run_all_tests PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>)
target_link_libraries(run_all_tests PRIVATE tdactor tddb tdcore tdnet tdutils)
target_link_libraries(run_all_tests PRIVATE tdactor tddb tdcore tdnet tdutils tdclient)
if (CLANG)
# add_executable(fuzz_url fuzz_url.cpp)

View File

@ -10,6 +10,7 @@
#include "td/actor/PromiseFuture.h"
#include "td/telegram/ClientActor.h"
#include "td/telegram/Client.h"
#include "td/telegram/td_api.h"
@ -832,4 +833,40 @@ class Tdclient_login : public Test {
};
//RegisterTest<Tdclient_login> Tdclient_login("Tdclient_login");
TEST(Client, Simple) {
td::Client client;
//client.execute({1, td::td_api::make_object<td::td_api::setLogTagVerbosityLevel>("actor", 1)});
client.send({3, td::make_tl_object<td::td_api::testSquareInt>(3)});
while (true) {
auto result = client.receive(10);
if (result.id == 3) {
auto test_int = td::td_api::move_object_as<td::td_api::testInt>(result.object);
ASSERT_EQ(test_int->value_, 9);
break;
}
}
}
TEST(Client, SimpleMulti) {
std::vector<td::Client> clients(50);
//for (auto &client : clients) {
//client.execute({1, td::td_api::make_object<td::td_api::setLogTagVerbosityLevel>("td_requests", 1)});
//}
for (auto &client : clients) {
client.send({3, td::make_tl_object<td::td_api::testSquareInt>(3)});
}
for (auto &client : clients) {
while (true) {
auto result = client.receive(10);
if (result.id == 3) {
auto test_int = td::td_api::move_object_as<td::td_api::testInt>(result.object);
ASSERT_EQ(test_int->value_, 9);
break;
}
}
}
}
} // namespace td