Move FloodControlGlobal to tdutils.

This commit is contained in:
levlam 2022-06-09 17:12:59 +03:00
parent a30ac1c277
commit 1b5798393d
8 changed files with 83 additions and 71 deletions

View File

@ -346,33 +346,5 @@ Status AuthKeyHandshake::on_message(Slice message, Callback *connection, AuthKey
return status;
}
GlobalFloodControl::GlobalFloodControl(uint64 limit) : limit_(limit) {
}
void GlobalFloodControl::finish() {
// TODO: instead of decrementing count, we can wake up some pending request
auto old_value = active_count_.fetch_sub(1, std::memory_order_relaxed);
CHECK(old_value > 0);
}
Result<GlobalFloodControl::Guard> GlobalFloodControl::try_start() {
auto old_value = active_count_.fetch_add(1, std::memory_order_relaxed);
if (old_value >= limit_) {
finish();
return Status::Error("Handshake limit reached");
}
return Guard(this);
}
GlobalFloodControl *GlobalFloodControl::get_handshake_flood() {
constexpr uint64 MAX_CONCURRENT_HANDSHAKES = 250;
static GlobalFloodControl flood{MAX_CONCURRENT_HANDSHAKES};
return &flood;
}
void GlobalFloodControl::Finish::operator()(GlobalFloodControl *ctrl) const {
ctrl->finish();
}
} // namespace mtproto
} // namespace td

View File

@ -16,9 +16,6 @@
#include "td/utils/StorerBase.h"
#include "td/utils/UInt.h"
#include <atomic>
#include <memory>
namespace td {
namespace mtproto_api {
@ -29,31 +26,11 @@ namespace mtproto {
class DhCallback;
class GlobalFloodControl {
public:
explicit GlobalFloodControl(uint64 limit);
struct Finish {
void operator()(GlobalFloodControl *ctrl) const;
};
using Guard = std::unique_ptr<GlobalFloodControl, Finish>;
Result<Guard> try_start();
static GlobalFloodControl *get_handshake_flood();
private:
std::atomic<uint64> active_count_{0};
uint64 limit_{0};
void finish();
};
class AuthKeyHandshakeContext {
public:
virtual ~AuthKeyHandshakeContext() = default;
virtual DhCallback *get_dh_callback() = 0;
virtual PublicRsaKeyInterface *get_public_rsa_key_interface() = 0;
virtual Status try_start() = 0;
};
class AuthKeyHandshake {

View File

@ -500,9 +500,6 @@ class TestProxyRequest final : public RequestOnceActor {
mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final {
return &public_rsa_key;
}
Status try_start() final {
return Status::OK();
}
private:
PublicRsaKeyShared public_rsa_key{DcId::empty(), false};

View File

@ -28,6 +28,7 @@
#include "td/utils/algorithm.h"
#include "td/utils/as.h"
#include "td/utils/FloodControlGlobal.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
@ -79,7 +80,13 @@ class GenAuthKeyActor final : public Actor {
CancellationTokenSource cancellation_token_source_;
ActorOwn<mtproto::HandshakeActor> child_;
Status alarm_error_;
FloodControlGlobal::Guard guard_;
FloodControlGlobal *get_handshake_flood() {
constexpr uint64 MAX_CONCURRENT_HANDSHAKES = 500;
static FloodControlGlobal flood{MAX_CONCURRENT_HANDSHAKES};
return &flood;
}
void start_up() final {
// Bug in Android clang and MSVC
@ -89,13 +96,12 @@ class GenAuthKeyActor final : public Actor {
//
// TODO: we may want to use a blocking wait - semaphore but for actors.
// (problem is - multiple schedulers may want to uses this semaphore)
auto status = context_->try_start();
if (status.is_error()) {
alarm_error_ = std::move(status);
guard_ = get_handshake_flood()->try_start();
if (guard_ == nullptr) {
// Set timeout because otherwise this actor will be recreated immediately.
// Sadly, it is still O(clients_count^2) time, because all clients will keep waking up.
// Still much better than creating new connection each time.
set_timeout_in(5);
set_timeout_in(1);
return;
}
@ -134,8 +140,8 @@ class GenAuthKeyActor final : public Actor {
}
void timeout_expired() override {
CHECK(alarm_error_.is_error());
connection_promise_.set_error(std::move(alarm_error_));
CHECK(guard_ == nullptr);
connection_promise_.set_error(Status::Error(1, "Handshake limit reached"));
handshake_promise_.set_value(std::move(handshake_));
}
};
@ -1315,7 +1321,7 @@ void Session::on_handshake_ready(Result<unique_ptr<mtproto::AuthKeyHandshake>> r
} else {
auto handshake = r_handshake.move_as_ok();
if (!handshake->is_ready_for_finish()) {
LOG(WARNING) << "Handshake is not yet ready";
LOG(INFO) << "Handshake is not yet ready";
info.handshake_ = std::move(handshake);
} else {
if (is_main) {
@ -1372,15 +1378,9 @@ void Session::create_gen_auth_key_actor(HandshakeId handshake_id) {
return public_rsa_key_.get();
}
Status try_start() final {
TRY_RESULT_ASSIGN(guard_, mtproto::GlobalFloodControl::get_handshake_flood()->try_start());
return Status::OK();
}
private:
mtproto::DhCallback *dh_callback_;
std::shared_ptr<mtproto::PublicRsaKeyInterface> public_rsa_key_;
mtproto::GlobalFloodControl::Guard guard_;
};
info.actor_ = create_actor<detail::GenAuthKeyActor>(

View File

@ -99,6 +99,7 @@ set(TDUTILS_SOURCE
td/utils/filesystem.cpp
td/utils/find_boundary.cpp
td/utils/FlatHashTable.cpp
td/utils/FloodControlGlobal.cpp
td/utils/Gzip.cpp
td/utils/GzipByteFlow.cpp
td/utils/Hints.cpp
@ -213,6 +214,7 @@ set(TDUTILS_SOURCE
td/utils/FlatHashSet.h
td/utils/FlatHashTable.h
td/utils/FloodControlFast.h
td/utils/FloodControlGlobal.h
td/utils/FloodControlStrict.h
td/utils/format.h
td/utils/Gzip.h

View File

@ -0,0 +1,32 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/utils/FloodControlGlobal.h"
namespace td {
FloodControlGlobal::FloodControlGlobal(uint64 limit) : limit_(limit) {
}
void FloodControlGlobal::finish() {
auto old_value = active_count_.fetch_sub(1, std::memory_order_relaxed);
CHECK(old_value > 0);
}
FloodControlGlobal::Guard FloodControlGlobal::try_start() {
auto old_value = active_count_.fetch_add(1, std::memory_order_relaxed);
if (old_value >= limit_) {
finish();
return nullptr;
}
return Guard(this);
}
void FloodControlGlobal::Finish::operator()(FloodControlGlobal *ctrl) const {
ctrl->finish();
}
} // namespace td

View File

@ -0,0 +1,35 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2022
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/common.h"
#include <atomic>
#include <memory>
namespace td {
// Restricts the total number of events
class FloodControlGlobal {
public:
explicit FloodControlGlobal(uint64 limit);
struct Finish {
void operator()(FloodControlGlobal *ctrl) const;
};
using Guard = std::unique_ptr<FloodControlGlobal, Finish>;
Guard try_start();
private:
std::atomic<uint64> active_count_{0};
uint64 limit_{0};
void finish();
};
} // namespace td

View File

@ -305,9 +305,6 @@ class HandshakeContext final : public td::mtproto::AuthKeyHandshakeContext {
td::mtproto::PublicRsaKeyInterface *get_public_rsa_key_interface() final {
return &public_rsa_key;
}
td::Status try_start() final {
return td::Status::OK();
}
private:
td::PublicRsaKeyShared public_rsa_key{td::DcId::empty(), true};