diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index a23be873..34db4387 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -71,7 +71,7 @@ class GenAuthKeyActor : public Actor { Promise> connection_promise_; Promise> handshake_promise_; std::shared_ptr callback_; - CancellationToken cancellation_token_{true}; + CancellationTokenSource cancellation_token_source_; ActorOwn child_; @@ -80,7 +80,8 @@ class GenAuthKeyActor : public Actor { // std::tuple> b(std::forward_as_tuple(Result())); callback_->request_raw_connection(PromiseCreator::cancellable_lambda( - cancellation_token_, [actor_id = actor_id(this)](Result> r_raw_connection) { + cancellation_token_source_.get_cancellation_token(), + [actor_id = actor_id(this)](Result> r_raw_connection) { send_closure(actor_id, &GenAuthKeyActor::on_connection, std::move(r_raw_connection), false); })); } @@ -878,10 +879,10 @@ void Session::connection_open(ConnectionInfo *info, bool ask_info) { info->ask_info = ask_info; info->state = ConnectionInfo::State::Connecting; - info->cancellation_token_ = CancellationToken{true}; + info->cancellation_token_source_ = CancellationTokenSource{}; // NB: rely on constant location of info auto promise = PromiseCreator::cancellable_lambda( - info->cancellation_token_, + info->cancellation_token_source_.get_cancellation_token(), [actor_id = actor_id(this), info = info](Result> res) { send_closure(actor_id, &Session::connection_open_finish, info, std::move(res)); }); diff --git a/td/telegram/net/Session.h b/td/telegram/net/Session.h index 3b843b9b..67fa51d1 100644 --- a/td/telegram/net/Session.h +++ b/td/telegram/net/Session.h @@ -128,7 +128,7 @@ class Session final int8 connection_id; Mode mode; enum class State : int8 { Empty, Connecting, Ready } state = State::Empty; - CancellationToken cancellation_token_; + CancellationTokenSource cancellation_token_source_; unique_ptr connection; bool ask_info; double wakeup_at = 0; diff --git a/tdactor/td/actor/PromiseFuture.h b/tdactor/td/actor/PromiseFuture.h index b163ac7d..0d1c2c5e 100644 --- a/tdactor/td/actor/PromiseFuture.h +++ b/tdactor/td/actor/PromiseFuture.h @@ -8,6 +8,7 @@ #include "td/actor/actor.h" +#include "td/utils/CancellationToken.h" #include "td/utils/Closure.h" #include "td/utils/common.h" #include "td/utils/invoke.h" // for tuple_for_each @@ -170,42 +171,6 @@ Promise &Promise::operator=(SafePromise &&other) { return *this; } -class CancellationToken { - public: - explicit CancellationToken(bool init = false) { - if (init) { - ptr_ = std::make_shared>(false); - } - } - CancellationToken(const CancellationToken &other) = default; - CancellationToken &operator=(const CancellationToken &other) { - cancel(); - ptr_ = other.ptr_; - return *this; - } - CancellationToken(CancellationToken &&other) = default; - CancellationToken &operator=(CancellationToken &&other) { - cancel(); - ptr_ = std::move(other.ptr_); - return *this; - } - ~CancellationToken() { - cancel(); - } - bool is_canceled() const { - return !ptr_ || *ptr_; - } - void cancel() { - if (ptr_) { - ptr_->store(true, std::memory_order_relaxed); - ptr_.reset(); - } - } - - private: - std::shared_ptr> ptr_; -}; - namespace detail { class EventPromise : public PromiseInterface { @@ -287,7 +252,7 @@ class CancellablePromise : public PromiseT { return true; } virtual bool is_cancelled() const { - return cancellation_token_.is_canceled(); + return bool(cancellation_token_); } private: diff --git a/tdutils/CMakeLists.txt b/tdutils/CMakeLists.txt index ddd8c857..3de5c232 100644 --- a/tdutils/CMakeLists.txt +++ b/tdutils/CMakeLists.txt @@ -144,6 +144,7 @@ set(TDUTILS_SOURCE td/utils/BufferedReader.h td/utils/BufferedUdp.h td/utils/ByteFlow.h + td/utils/CancellationToken.h td/utils/ChangesProcessor.h td/utils/check.h td/utils/Closure.h diff --git a/tdutils/td/utils/CancellationToken.h b/tdutils/td/utils/CancellationToken.h new file mode 100644 index 00000000..f2a7ff84 --- /dev/null +++ b/tdutils/td/utils/CancellationToken.h @@ -0,0 +1,67 @@ +// +// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2019 +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +#pragma once + +#include +namespace td { +namespace detail { +struct RawCancellationToken { + std::atomic is_cancelled_{false}; +}; +} // namespace detail + +class CancellationToken { + public: + explicit operator bool() const { + return token_->is_cancelled_.load(std::memory_order_acquire); + } + explicit CancellationToken(std::shared_ptr token) : token_(std::move(token)) { + } + CancellationToken(CancellationToken &&other) = default; + CancellationToken(const CancellationToken &other) = default; + CancellationToken &operator=(CancellationToken &&other) = default; + CancellationToken &operator=(const CancellationToken &other) = default; + ~CancellationToken() = default; + + private: + std::shared_ptr token_; +}; + +class CancellationTokenSource { + public: + CancellationTokenSource() = default; + CancellationTokenSource(CancellationTokenSource &&other) : token_(std::move(other.token_)) { + } + CancellationTokenSource &operator=(CancellationTokenSource &&other) { + cancel(); + token_ = std::move(other.token_); + return *this; + } + ~CancellationTokenSource() { + cancel(); + } + CancellationTokenSource(const CancellationTokenSource &other) = delete; + CancellationTokenSource &operator=(const CancellationTokenSource &other) = delete; + + CancellationToken get_cancellation_token() { + if (!token_) { + token_ = std::make_shared(); + } + return CancellationToken(token_); + } + void cancel() { + if (!token_) { + return; + } + token_->is_cancelled_.store(true, std::memory_order_release); + token_.reset(); + } + + private: + std::shared_ptr token_; +}; +} // namespace td diff --git a/tdutils/test/misc.cpp b/tdutils/test/misc.cpp index 0cb87d47..055e462d 100644 --- a/tdutils/test/misc.cpp +++ b/tdutils/test/misc.cpp @@ -8,6 +8,7 @@ #include "td/utils/base64.h" #include "td/utils/BigNum.h" #include "td/utils/bits.h" +#include "td/utils/CancellationToken.h" #include "td/utils/common.h" #include "td/utils/HttpUrl.h" #include "td/utils/invoke.h" @@ -662,3 +663,23 @@ TEST(Misc, Bits) { ASSERT_EQ(4, count_bits32((1u << 31) | 7)); ASSERT_EQ(4, count_bits64((1ull << 63) | 7)); } + +TEST(Misc, CancellationToken) { + CancellationTokenSource source; + source.cancel(); + auto token1 = source.get_cancellation_token(); + auto token2 = source.get_cancellation_token(); + CHECK(!token1); + source.cancel(); + CHECK(token1); + CHECK(token2); + auto token3 = source.get_cancellation_token(); + CHECK(!token3); + source.cancel(); + CHECK(token3); + + auto token4 = source.get_cancellation_token(); + CHECK(!token4); + source = CancellationTokenSource{}; + CHECK(token4); +}