diff --git a/td/mtproto/Handshake.cpp b/td/mtproto/Handshake.cpp index 93dc13df4..10cab04e3 100644 --- a/td/mtproto/Handshake.cpp +++ b/td/mtproto/Handshake.cpp @@ -346,26 +346,33 @@ Status AuthKeyHandshake::on_message(Slice message, Callback *connection, AuthKey return status; } +GlobalFloodControl::GlobalFloodControl(uint64 limit) : limit_(limit) { +} + void GlobalFloodControl::finish() { - // TODO: instead of decrementing count, we may wake up some pending request - active_count_--; -} -GlobalFloodControl::GlobalFloodControl(uint64_t limit) : limit_(limit) { + // TODO: instead of decrementing count, we can wake up some pending request + auto old_value = active_count_.fetch_sub(1, std::memory_order_relaxed); + CHECK(old_value > 0); } + Result GlobalFloodControl::try_start() { - if (++active_count_ > limit_) { - active_count_--; - return td::Status::Error("Handshake limit reached"); + auto old_value = active_count_.fetch_add(1, std::memory_order_relaxed); + if (old_value >= limit_) { + finish(); + return Status::Error("Handshake limit reached"); } return Guard(this); } + GlobalFloodControl *GlobalFloodControl::get_handshake_flood() { - constexpr uint64_t MAX_CONCURRENT_HANDSHAKES = 50; + constexpr uint64 MAX_CONCURRENT_HANDSHAKES = 250; static GlobalFloodControl flood{MAX_CONCURRENT_HANDSHAKES}; return &flood; } + void GlobalFloodControl::Finish::operator()(GlobalFloodControl *ctrl) const { ctrl->finish(); } + } // namespace mtproto } // namespace td diff --git a/td/mtproto/Handshake.h b/td/mtproto/Handshake.h index 58867725c..6833555df 100644 --- a/td/mtproto/Handshake.h +++ b/td/mtproto/Handshake.h @@ -10,11 +10,15 @@ #include "td/mtproto/RSA.h" #include "td/utils/buffer.h" +#include "td/utils/common.h" #include "td/utils/Slice.h" #include "td/utils/Status.h" #include "td/utils/StorerBase.h" #include "td/utils/UInt.h" +#include +#include + namespace td { namespace mtproto_api { @@ -27,17 +31,19 @@ class DhCallback; class GlobalFloodControl { public: - explicit GlobalFloodControl(uint64_t limit); + explicit GlobalFloodControl(uint64 limit); + struct Finish { void operator()(GlobalFloodControl *ctrl) const; }; using Guard = std::unique_ptr; - td::Result try_start(); + Result try_start(); + static GlobalFloodControl *get_handshake_flood(); private: - std::atomic active_count_{0}; - uint64_t limit_{0}; + std::atomic active_count_{0}; + uint64 limit_{0}; void finish(); }; @@ -47,7 +53,7 @@ class AuthKeyHandshakeContext { virtual ~AuthKeyHandshakeContext() = default; virtual DhCallback *get_dh_callback() = 0; virtual PublicRsaKeyInterface *get_public_rsa_key_interface() = 0; - virtual td::Status try_start() = 0; + virtual Status try_start() = 0; }; class AuthKeyHandshake { diff --git a/td/telegram/Td.cpp b/td/telegram/Td.cpp index 5206dbf8e..50d89f73b 100644 --- a/td/telegram/Td.cpp +++ b/td/telegram/Td.cpp @@ -500,8 +500,8 @@ class TestProxyRequest final : public RequestOnceActor { mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final { return &public_rsa_key; } - td::Status try_start() final { - return td::Status::OK(); + Status try_start() final { + return Status::OK(); } private: diff --git a/td/telegram/net/Session.cpp b/td/telegram/net/Session.cpp index a9c521e97..e5eccf604 100644 --- a/td/telegram/net/Session.cpp +++ b/td/telegram/net/Session.cpp @@ -79,13 +79,13 @@ class GenAuthKeyActor final : public Actor { CancellationTokenSource cancellation_token_source_; ActorOwn child_; - td::Status alarm_error_; + Status alarm_error_; void start_up() final { // Bug in Android clang and MSVC // std::tuple> b(std::forward_as_tuple(Result())); - // Will sleep a little it there are too much active handshakes now + // Will sleep a little if there are too many active handshakes now // // TODO: we may want to use a blocking wait - semaphore but for actors. // (problem is - multiple schedulers may want to uses this semaphore) @@ -1372,9 +1372,9 @@ void Session::create_gen_auth_key_actor(HandshakeId handshake_id) { return public_rsa_key_.get(); } - td::Status try_start() { + Status try_start() final { TRY_RESULT_ASSIGN(guard_, mtproto::GlobalFloodControl::get_handshake_flood()->try_start()); - return td::Status::OK(); + return Status::OK(); } private: diff --git a/test/mtproto.cpp b/test/mtproto.cpp index 91c82abe5..9990c31ee 100644 --- a/test/mtproto.cpp +++ b/test/mtproto.cpp @@ -305,6 +305,9 @@ class HandshakeContext final : public td::mtproto::AuthKeyHandshakeContext { td::mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final { return &public_rsa_key; } + td::Status try_start() final { + return td::Status::OK(); + } private: td::PublicRsaKeyShared public_rsa_key{td::DcId::empty(), true};