simple CancellationToken

GitOrigin-RevId: 081b63eac0ac1e45153f2df4a2eea9fa825c9924
This commit is contained in:
Arseny Smirnov 2019-05-01 15:13:48 +02:00
parent 0264e2bbfd
commit e3e54b7a53
6 changed files with 97 additions and 42 deletions

View File

@ -71,7 +71,7 @@ class GenAuthKeyActor : public Actor {
Promise<unique_ptr<mtproto::RawConnection>> connection_promise_; Promise<unique_ptr<mtproto::RawConnection>> connection_promise_;
Promise<unique_ptr<mtproto::AuthKeyHandshake>> handshake_promise_; Promise<unique_ptr<mtproto::AuthKeyHandshake>> handshake_promise_;
std::shared_ptr<Session::Callback> callback_; std::shared_ptr<Session::Callback> callback_;
CancellationToken cancellation_token_{true}; CancellationTokenSource cancellation_token_source_;
ActorOwn<mtproto::HandshakeActor> child_; ActorOwn<mtproto::HandshakeActor> child_;
@ -80,7 +80,8 @@ class GenAuthKeyActor : public Actor {
// std::tuple<Result<int>> b(std::forward_as_tuple(Result<int>())); // std::tuple<Result<int>> b(std::forward_as_tuple(Result<int>()));
callback_->request_raw_connection(PromiseCreator::cancellable_lambda( callback_->request_raw_connection(PromiseCreator::cancellable_lambda(
cancellation_token_, [actor_id = actor_id(this)](Result<unique_ptr<mtproto::RawConnection>> r_raw_connection) { cancellation_token_source_.get_cancellation_token(),
[actor_id = actor_id(this)](Result<unique_ptr<mtproto::RawConnection>> r_raw_connection) {
send_closure(actor_id, &GenAuthKeyActor::on_connection, std::move(r_raw_connection), false); send_closure(actor_id, &GenAuthKeyActor::on_connection, std::move(r_raw_connection), false);
})); }));
} }
@ -878,10 +879,10 @@ void Session::connection_open(ConnectionInfo *info, bool ask_info) {
info->ask_info = ask_info; info->ask_info = ask_info;
info->state = ConnectionInfo::State::Connecting; info->state = ConnectionInfo::State::Connecting;
info->cancellation_token_ = CancellationToken{true}; info->cancellation_token_source_ = CancellationTokenSource{};
// NB: rely on constant location of info // NB: rely on constant location of info
auto promise = PromiseCreator::cancellable_lambda( auto promise = PromiseCreator::cancellable_lambda(
info->cancellation_token_, info->cancellation_token_source_.get_cancellation_token(),
[actor_id = actor_id(this), info = info](Result<unique_ptr<mtproto::RawConnection>> res) { [actor_id = actor_id(this), info = info](Result<unique_ptr<mtproto::RawConnection>> res) {
send_closure(actor_id, &Session::connection_open_finish, info, std::move(res)); send_closure(actor_id, &Session::connection_open_finish, info, std::move(res));
}); });

View File

@ -128,7 +128,7 @@ class Session final
int8 connection_id; int8 connection_id;
Mode mode; Mode mode;
enum class State : int8 { Empty, Connecting, Ready } state = State::Empty; enum class State : int8 { Empty, Connecting, Ready } state = State::Empty;
CancellationToken cancellation_token_; CancellationTokenSource cancellation_token_source_;
unique_ptr<mtproto::SessionConnection> connection; unique_ptr<mtproto::SessionConnection> connection;
bool ask_info; bool ask_info;
double wakeup_at = 0; double wakeup_at = 0;

View File

@ -8,6 +8,7 @@
#include "td/actor/actor.h" #include "td/actor/actor.h"
#include "td/utils/CancellationToken.h"
#include "td/utils/Closure.h" #include "td/utils/Closure.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/invoke.h" // for tuple_for_each #include "td/utils/invoke.h" // for tuple_for_each
@ -170,42 +171,6 @@ Promise<T> &Promise<T>::operator=(SafePromise<T> &&other) {
return *this; return *this;
} }
class CancellationToken {
public:
explicit CancellationToken(bool init = false) {
if (init) {
ptr_ = std::make_shared<std::atomic<bool>>(false);
}
}
CancellationToken(const CancellationToken &other) = default;
CancellationToken &operator=(const CancellationToken &other) {
cancel();
ptr_ = other.ptr_;
return *this;
}
CancellationToken(CancellationToken &&other) = default;
CancellationToken &operator=(CancellationToken &&other) {
cancel();
ptr_ = std::move(other.ptr_);
return *this;
}
~CancellationToken() {
cancel();
}
bool is_canceled() const {
return !ptr_ || *ptr_;
}
void cancel() {
if (ptr_) {
ptr_->store(true, std::memory_order_relaxed);
ptr_.reset();
}
}
private:
std::shared_ptr<std::atomic<bool>> ptr_;
};
namespace detail { namespace detail {
class EventPromise : public PromiseInterface<Unit> { class EventPromise : public PromiseInterface<Unit> {
@ -287,7 +252,7 @@ class CancellablePromise : public PromiseT {
return true; return true;
} }
virtual bool is_cancelled() const { virtual bool is_cancelled() const {
return cancellation_token_.is_canceled(); return bool(cancellation_token_);
} }
private: private:

View File

@ -144,6 +144,7 @@ set(TDUTILS_SOURCE
td/utils/BufferedReader.h td/utils/BufferedReader.h
td/utils/BufferedUdp.h td/utils/BufferedUdp.h
td/utils/ByteFlow.h td/utils/ByteFlow.h
td/utils/CancellationToken.h
td/utils/ChangesProcessor.h td/utils/ChangesProcessor.h
td/utils/check.h td/utils/check.h
td/utils/Closure.h td/utils/Closure.h

View File

@ -0,0 +1,67 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2019
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include <atomic>
namespace td {
namespace detail {
struct RawCancellationToken {
std::atomic<bool> is_cancelled_{false};
};
} // namespace detail
class CancellationToken {
public:
explicit operator bool() const {
return token_->is_cancelled_.load(std::memory_order_acquire);
}
explicit CancellationToken(std::shared_ptr<detail::RawCancellationToken> token) : token_(std::move(token)) {
}
CancellationToken(CancellationToken &&other) = default;
CancellationToken(const CancellationToken &other) = default;
CancellationToken &operator=(CancellationToken &&other) = default;
CancellationToken &operator=(const CancellationToken &other) = default;
~CancellationToken() = default;
private:
std::shared_ptr<detail::RawCancellationToken> token_;
};
class CancellationTokenSource {
public:
CancellationTokenSource() = default;
CancellationTokenSource(CancellationTokenSource &&other) : token_(std::move(other.token_)) {
}
CancellationTokenSource &operator=(CancellationTokenSource &&other) {
cancel();
token_ = std::move(other.token_);
return *this;
}
~CancellationTokenSource() {
cancel();
}
CancellationTokenSource(const CancellationTokenSource &other) = delete;
CancellationTokenSource &operator=(const CancellationTokenSource &other) = delete;
CancellationToken get_cancellation_token() {
if (!token_) {
token_ = std::make_shared<detail::RawCancellationToken>();
}
return CancellationToken(token_);
}
void cancel() {
if (!token_) {
return;
}
token_->is_cancelled_.store(true, std::memory_order_release);
token_.reset();
}
private:
std::shared_ptr<detail::RawCancellationToken> token_;
};
} // namespace td

View File

@ -8,6 +8,7 @@
#include "td/utils/base64.h" #include "td/utils/base64.h"
#include "td/utils/BigNum.h" #include "td/utils/BigNum.h"
#include "td/utils/bits.h" #include "td/utils/bits.h"
#include "td/utils/CancellationToken.h"
#include "td/utils/common.h" #include "td/utils/common.h"
#include "td/utils/HttpUrl.h" #include "td/utils/HttpUrl.h"
#include "td/utils/invoke.h" #include "td/utils/invoke.h"
@ -662,3 +663,23 @@ TEST(Misc, Bits) {
ASSERT_EQ(4, count_bits32((1u << 31) | 7)); ASSERT_EQ(4, count_bits32((1u << 31) | 7));
ASSERT_EQ(4, count_bits64((1ull << 63) | 7)); ASSERT_EQ(4, count_bits64((1ull << 63) | 7));
} }
TEST(Misc, CancellationToken) {
CancellationTokenSource source;
source.cancel();
auto token1 = source.get_cancellation_token();
auto token2 = source.get_cancellation_token();
CHECK(!token1);
source.cancel();
CHECK(token1);
CHECK(token2);
auto token3 = source.get_cancellation_token();
CHECK(!token3);
source.cancel();
CHECK(token3);
auto token4 = source.get_cancellation_token();
CHECK(!token4);
source = CancellationTokenSource{};
CHECK(token4);
}