ByteFlow: backpressure draft

GitOrigin-RevId: 09adce82dd88fcf84f41e525e45b07da03acc0f6
This commit is contained in:
Arseny Smirnov 2020-07-22 21:52:00 +03:00
parent 50da604d65
commit 9ea1bc824f
15 changed files with 279 additions and 150 deletions

View File

@ -12,30 +12,28 @@
namespace td {
namespace mtproto {
void TlsReaderByteFlow::loop() {
while (true) {
if (input_->size() < 5) {
set_need_size(5);
return;
}
auto it = input_->clone();
uint8 buf[5];
it.advance(5, MutableSlice(buf, 5));
if (Slice(buf, 3) != Slice("\x17\x03\x03")) {
close_input(Status::Error("Invalid bytes at the beginning of a packet (emulated tls)"));
return;
}
size_t len = (buf[3] << 8) | buf[4];
if (it.size() < len) {
set_need_size(5 + len);
return;
}
output_.append(it.cut_head(len));
*input_ = std::move(it);
on_output_updated();
// Extracts the payload of one emulated-TLS record from input_ into output_.
// Returns true when a full record was forwarded; returns false when more
// input is needed (after calling set_need_size) or the flow was closed.
bool TlsReaderByteFlow::loop() {
// A record header is 5 bytes: 3 fixed prefix bytes + 2-byte big-endian length.
if (input_->size() < 5) {
set_need_size(5);
return false;
}
auto it = input_->clone();
uint8 buf[5];
it.advance(5, MutableSlice(buf, 5));
// "\x17\x03\x03" is the TLS application-data record prefix (content type 23,
// version 3.3) that every packet of the emulated-TLS transport must carry.
if (Slice(buf, 3) != Slice("\x17\x03\x03")) {
close_input(Status::Error("Invalid bytes at the beginning of a packet (emulated tls)"));
return false;
}
size_t len = (buf[3] << 8) | buf[4];
if (it.size() < len) {
// Wait until header + payload are fully buffered before consuming anything.
set_need_size(5 + len);
return false;
}
output_.append(it.cut_head(len));
// Commit the consumed bytes by replacing the reader with the advanced clone.
*input_ = std::move(it);
return true;
}
} // namespace mtproto

View File

@ -13,7 +13,7 @@ namespace mtproto {
class TlsReaderByteFlow final : public ByteFlowBase {
public:
void loop() override;
bool loop() override;
};
} // namespace mtproto

View File

@ -43,6 +43,7 @@ ActorContext *&Scheduler::context() {
// Refreshes the thread-local log tag whenever the active ActorContext changes.
void Scheduler::on_context_updated() {
LOG_TAG = context_->tag_;
}
void Scheduler::set_scheduler(Scheduler *scheduler) {

View File

@ -14,25 +14,26 @@
namespace td {
void HttpChunkedByteFlow::loop() {
bool was_updated = false;
size_t need_size;
while (true) {
bool HttpChunkedByteFlow::loop() {
bool result = false;
do {
if (state_ == State::ReadChunkLength) {
bool ok = find_boundary(input_->clone(), "\r\n", len_);
if (len_ > 10) {
return finish(Status::Error(PSLICE() << "Too long length in chunked "
<< input_->cut_head(len_).move_as_buffer_slice().as_slice()));
finish(Status::Error(PSLICE() << "Too long length in chunked "
<< input_->cut_head(len_).move_as_buffer_slice().as_slice()));
return false;
}
if (!ok) {
need_size = input_->size() + 1;
set_need_size(input_->size() + 1);
break;
}
auto s_len = input_->cut_head(len_).move_as_buffer_slice();
input_->advance(2);
len_ = hex_to_integer<size_t>(s_len.as_slice());
if (len_ > MAX_CHUNK_SIZE) {
return finish(Status::Error(PSLICE() << "Invalid chunk size " << tag("size", len_)));
finish(Status::Error(PSLICE() << "Invalid chunk size " << tag("size", len_)));
return false;
}
save_len_ = len_;
state_ = State::ReadChunkContent;
@ -40,21 +41,23 @@ void HttpChunkedByteFlow::loop() {
auto size = input_->size();
auto ready = min(len_, size);
need_size = min(MIN_UPDATE_SIZE, len_ + 2);
auto need_size = min(MIN_UPDATE_SIZE, len_ + 2);
if (size < need_size) {
set_need_size(need_size);
break;
}
total_size_ += ready;
uncommited_size_ += ready;
if (total_size_ > MAX_SIZE) {
return finish(Status::Error(PSLICE() << "Too big query " << tag("size", input_->size())));
finish(Status::Error(PSLICE() << "Too big query " << tag("size", input_->size())));
return false;
}
output_.append(input_->cut_head(ready));
result = true;
len_ -= ready;
if (uncommited_size_ >= MIN_UPDATE_SIZE) {
uncommited_size_ = 0;
was_updated = true;
}
if (len_ == 0) {
@ -65,19 +68,17 @@ void HttpChunkedByteFlow::loop() {
input_->advance(2);
total_size_ += 2;
if (save_len_ == 0) {
return finish(Status::OK());
finish(Status::OK());
return false;
}
state_ = State::ReadChunkLength;
len_ = 0;
}
} while (0);
if (!is_input_active_ && !result) {
finish(Status::Error("Unexpected end of stream"));
}
if (was_updated) {
on_output_updated();
}
if (!is_input_active_) {
return finish(Status::Error("Unexpected end of stream"));
}
set_need_size(need_size);
return result;
}
} // namespace td

View File

@ -12,7 +12,7 @@ namespace td {
class HttpChunkedByteFlow final : public ByteFlowBase {
public:
void loop() override;
bool loop() override;
private:
static constexpr int MAX_CHUNK_SIZE = 15 << 20; // some reasonable limit

View File

@ -10,7 +10,7 @@
namespace td {
void HttpContentLengthByteFlow::loop() {
bool HttpContentLengthByteFlow::loop() {
auto ready_size = input_->size();
if (ready_size > len_) {
ready_size = len_;
@ -18,17 +18,19 @@ void HttpContentLengthByteFlow::loop() {
auto need_size = min(MIN_UPDATE_SIZE, len_);
if (ready_size < need_size) {
set_need_size(need_size);
return;
return false;
}
output_.append(input_->cut_head(ready_size));
len_ -= ready_size;
if (len_ == 0) {
return finish(Status::OK());
finish(Status::OK());
return false;
}
if (!is_input_active_) {
return finish(Status::Error("Unexpected end of stream"));
finish(Status::Error("Unexpected end of stream"));
return false;
}
on_output_updated();
return true;
}
} // namespace td

View File

@ -15,7 +15,7 @@ class HttpContentLengthByteFlow final : public ByteFlowBase {
HttpContentLengthByteFlow() = default;
explicit HttpContentLengthByteFlow(size_t len) : len_(len) {
}
void loop() override;
bool loop() override;
private:
static constexpr size_t MIN_UPDATE_SIZE = 1 << 14;

View File

@ -68,6 +68,7 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) {
size_t need_size = input_->size() + 1;
while (true) {
if (state_ != State::ReadHeaders) {
gzip_flow_.wakeup();
flow_source_.wakeup();
if (flow_sink_.is_ready() && flow_sink_.status().is_error()) {
if (!temp_file_.empty()) {
@ -108,7 +109,11 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) {
if (content_encoding_.empty()) {
} else if (content_encoding_ == "gzip" || content_encoding_ == "deflate") {
gzip_flow_ = GzipByteFlow(Gzip::Mode::Decode);
gzip_flow_.set_max_output_size(MAX_CONTENT_SIZE);
GzipByteFlow::Options options;
options.write_watermark.low = 0;
options.write_watermark.hight = max_post_size_ + 10;
gzip_flow_.set_options(options);
//gzip_flow_.set_max_output_size(MAX_CONTENT_SIZE);
*source >> gzip_flow_;
source = &gzip_flow_;
} else {
@ -170,6 +175,10 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) {
case State::ReadContent: {
if (content_->size() > max_post_size_) {
state_ = State::ReadContentToFile;
GzipByteFlow::Options options;
options.write_watermark.low = 4 << 20;
options.write_watermark.hight = 8 << 20;
gzip_flow_.set_options(options);
continue;
}
if (flow_sink_.is_ready()) {
@ -191,14 +200,19 @@ Result<size_t> HttpReader::read_next(HttpQuery *query) {
}
auto size = content_->size();
if (size) {
bool restart = false;
if (size > (1 << 20) || flow_sink_.is_ready()) {
TRY_STATUS(save_file_part(content_->cut_head(size).move_as_buffer_slice()));
restart = true;
}
if (flow_sink_.is_ready()) {
query_->files_.emplace_back("file", "", content_type_.str(), file_size_, temp_file_name_);
close_temp_file();
break;
}
if (restart) {
continue;
}
return need_size;
}

View File

@ -387,24 +387,19 @@ class SslStreamImpl {
public:
explicit SslReadByteFlow(SslStreamImpl *stream) : stream_(stream) {
}
void loop() override {
bool was_append = false;
while (true) {
auto to_read = output_.prepare_append();
auto r_size = stream_->read(to_read);
if (r_size.is_error()) {
return finish(r_size.move_as_error());
}
auto size = r_size.move_as_ok();
if (size == 0) {
break;
}
output_.confirm_append(size);
was_append = true;
bool loop() override {
auto to_read = output_.prepare_append();
auto r_size = stream_->read(to_read);
if (r_size.is_error()) {
finish(r_size.move_as_error());
return false;
}
if (was_append) {
on_output_updated();
auto size = r_size.move_as_ok();
if (size == 0) {
return false;
}
output_.confirm_append(size);
return true;
}
size_t read(MutableSlice data) {
@ -419,34 +414,28 @@ class SslStreamImpl {
public:
explicit SslWriteByteFlow(SslStreamImpl *stream) : stream_(stream) {
}
void loop() override {
while (!input_->empty()) {
auto to_write = input_->prepare_read();
auto r_size = stream_->write(to_write);
if (r_size.is_error()) {
return finish(r_size.move_as_error());
}
auto size = r_size.move_as_ok();
if (size == 0) {
break;
}
input_->confirm_read(size);
bool loop() override {
auto to_write = input_->prepare_read();
auto r_size = stream_->write(to_write);
if (r_size.is_error()) {
finish(r_size.move_as_error());
return false;
}
if (output_updated_) {
output_updated_ = false;
on_output_updated();
auto size = r_size.move_as_ok();
if (size == 0) {
return false;
}
input_->confirm_read(size);
return true;
}
size_t write(Slice data) {
output_.append(data);
output_updated_ = true;
return data.size();
}
private:
SslStreamImpl *stream_;
bool output_updated_{false};
};
SslReadByteFlow read_flow_{this};

View File

@ -27,25 +27,20 @@ class AesCtrByteFlow : public ByteFlowInplaceBase {
AesCtrState move_aes_ctr_state() {
return std::move(state_);
}
void loop() override {
bool was_updated = false;
while (true) {
auto ready = input_->prepare_read();
if (ready.empty()) {
break;
}
bool loop() override {
bool result = false;
auto ready = input_->prepare_read();
if (!ready.empty()) {
state_.encrypt(ready, MutableSlice(const_cast<char *>(ready.data()), ready.size()));
input_->confirm_read(ready.size());
output_.advance_end(ready.size());
was_updated = true;
}
if (was_updated) {
on_output_updated();
result = true;
}
if (!is_input_active_) {
finish(Status::OK()); // End of input stream.
}
set_need_size(1);
return result;
}
private:

View File

@ -9,6 +9,7 @@
#include "td/utils/buffer.h"
#include "td/utils/common.h"
#include "td/utils/Status.h"
#include <limits>
namespace td {
@ -19,6 +20,8 @@ class ByteFlowInterface {
virtual void set_parent(ByteFlowInterface &other) = 0;
virtual void set_input(ChainBufferReader *input) = 0;
virtual size_t get_need_size() = 0;
virtual size_t get_read_size() = 0;
virtual size_t get_write_size() = 0;
ByteFlowInterface() = default;
ByteFlowInterface(const ByteFlowInterface &) = delete;
ByteFlowInterface &operator=(const ByteFlowInterface &) = delete;
@ -45,32 +48,92 @@ class ByteFlowBaseCommon : public ByteFlowInterface {
return;
}
input_->sync_with_writer();
if (waiting_flag_) {
if (!is_input_active_) {
finish(Status::OK());
}
return;
}
if (is_input_active_) {
if (need_size_ != 0 && input_->size() < need_size_) {
return;
while (true) {
if (stop_flag_) {
break;
}
// update can_read
if (is_input_active_) {
auto read_size = get_read_size();
if (read_size < min(need_size_, options_.read_watermark.low)) {
can_read = false;
}
if (read_size >= max(need_size_, options_.read_watermark.hight)) {
can_read = true;
}
} else {
//Always can read when input is closed
can_read = true;
}
// update can_write
{
auto write_size = get_write_size();
if (write_size > options_.write_watermark.hight) {
can_write = false;
}
if (write_size <= options_.write_watermark.low) {
can_write = true;
}
}
if (!can_read || !can_write) {
break;
}
need_size_ = 0;
if (!loop()) {
if (need_size_ <= get_read_size()) {
need_size_ = get_read_size() + 1;
}
}
}
need_size_ = 0;
loop();
on_output_updated();
}
size_t get_need_size() final {
return need_size_;
}
size_t get_read_size() override {
input_->sync_with_writer();
return input_->size();
}
size_t get_write_size() override {
CHECK(parent_);
return parent_->get_read_size();
}
virtual void loop() = 0;
struct Watermark {
size_t low{std::numeric_limits<size_t>::max()};
size_t hight{0};
};
struct Options {
Watermark write_watermark;
Watermark read_watermark;
};
void set_options(Options options) {
options_ = options;
}
virtual bool loop() = 0;
protected:
bool waiting_flag_ = false;
ChainBufferReader *input_ = nullptr;
bool is_input_active_ = true;
size_t need_size_ = 0;
bool can_read{true};
bool can_write{true};
Options options_;
void finish(Status status) {
stop_flag_ = true;
need_size_ = 0;
@ -114,7 +177,7 @@ class ByteFlowBase : public ByteFlowBaseCommon {
parent_ = &other;
parent_->set_input(&output_reader_);
}
void loop() override = 0;
bool loop() override = 0;
// ChainBufferWriter &get_output() {
// return output_;
@ -137,7 +200,7 @@ class ByteFlowInplaceBase : public ByteFlowBaseCommon {
parent_ = &other;
parent_->set_input(&output_);
}
void loop() override = 0;
bool loop() override = 0;
ChainBufferReader &get_output() {
return output_;
@ -195,6 +258,14 @@ class ByteFlowSource : public ByteFlowInterface {
}
return parent_->get_need_size();
}
size_t get_read_size() final {
UNREACHABLE();
return 0;
}
size_t get_write_size() final {
CHECK(parent_);
return parent_->get_read_size();
}
private:
ChainBufferReader *buffer_ = nullptr;
@ -223,6 +294,14 @@ class ByteFlowSink : public ByteFlowInterface {
UNREACHABLE();
return 0;
}
size_t get_read_size() final {
buffer_->sync_with_writer();
return buffer_->size();
}
size_t get_write_size() final {
UNREACHABLE();
return 0;
}
bool is_ready() {
return !active_;
}
@ -270,6 +349,15 @@ class ByteFlowMoveSink : public ByteFlowInterface {
UNREACHABLE();
return 0;
}
size_t get_read_size() final {
input_->sync_with_writer();
//TODO: must be input_->size() + output_->size()
return input_->size();
}
size_t get_write_size() final {
UNREACHABLE();
return 0;
}
void set_output(ChainBufferWriter *output) {
CHECK(!output_);
output_ = output;

View File

@ -14,57 +14,54 @@ char disable_linker_warning_about_empty_file_gzipbyteflow_cpp TD_UNUSED;
namespace td {
void GzipByteFlow::loop() {
while (true) {
if (gzip_.need_input()) {
auto slice = input_->prepare_read();
if (slice.empty()) {
if (!is_input_active_) {
gzip_.close_input();
} else {
break;
}
bool GzipByteFlow::loop() {
bool result = false;
if (gzip_.need_input()) {
auto slice = input_->prepare_read();
if (slice.empty()) {
if (!is_input_active_) {
gzip_.close_input();
} else {
gzip_.set_input(input_->prepare_read());
return result;
}
} else {
gzip_.set_input(input_->prepare_read());
}
if (gzip_.need_output()) {
auto slice = output_.prepare_append();
CHECK(!slice.empty());
gzip_.set_output(slice);
}
auto r_state = gzip_.run();
auto output_size = gzip_.flush_output();
if (output_size) {
uncommited_size_ += output_size;
total_output_size_ += output_size;
if (total_output_size_ > max_output_size_) {
return finish(Status::Error("Max output size limit exceeded"));
}
output_.confirm_append(output_size);
}
if (gzip_.need_output()) {
auto slice = output_.prepare_append();
CHECK(!slice.empty());
gzip_.set_output(slice);
}
auto r_state = gzip_.run();
auto output_size = gzip_.flush_output();
if (output_size) {
uncommited_size_ += output_size;
total_output_size_ += output_size;
if (total_output_size_ > max_output_size_) {
finish(Status::Error("Max output size limit exceeded"));
return result;
}
output_.confirm_append(output_size);
result = true;
}
auto input_size = gzip_.flush_input();
if (input_size) {
input_->confirm_read(input_size);
}
if (r_state.is_error()) {
return finish(r_state.move_as_error());
}
auto state = r_state.ok();
if (state == Gzip::State::Done) {
on_output_updated();
return consume_input();
}
auto input_size = gzip_.flush_input();
if (input_size) {
input_->confirm_read(input_size);
}
if (uncommited_size_ >= MIN_UPDATE_SIZE) {
uncommited_size_ = 0;
on_output_updated();
if (r_state.is_error()) {
finish(r_state.move_as_error());
return false;
}
auto state = r_state.ok();
if (state == Gzip::State::Done) {
consume_input();
return false;
}
return result;
}
constexpr size_t GzipByteFlow::MIN_UPDATE_SIZE;
} // namespace td
#endif

View File

@ -34,14 +34,13 @@ class GzipByteFlow final : public ByteFlowBase {
max_output_size_ = max_output_size;
}
void loop() override;
bool loop() override;
private:
Gzip gzip_;
size_t uncommited_size_ = 0;
size_t total_output_size_ = 0;
size_t max_output_size_ = std::numeric_limits<size_t>::max();
static constexpr size_t MIN_UPDATE_SIZE = 1 << 14;
};
#endif

View File

@ -372,3 +372,48 @@ TEST(Http, gzip_chunked_flow) {
ASSERT_TRUE(sink.status().is_ok());
ASSERT_EQ(str, sink.result()->move_as_buffer_slice().as_slice().str());
}
// Verifies that an HttpReader survives a highly compressible ("gzip bomb")
// request body without error when decompression is bounded by the
// write-watermark backpressure introduced in this change.
TEST(Http, gzip_bomb_with_limit) {
std::string gzip_bomb_str;
{
// Compress 2000 x 1 MiB of 'a' bytes into a small gzip payload.
ChainBufferWriter input_writer;
auto input = input_writer.extract_reader();
GzipByteFlow gzip_flow(Gzip::Mode::Encode);
ByteFlowSource source(&input);
ByteFlowSink sink;
source >> gzip_flow >> sink;
std::string s(1 << 20, 'a');
for (int i = 0; i < 2000; i++) {
input_writer.append(s);
source.wakeup();
}
source.close_input(Status::OK());
ASSERT_TRUE(sink.is_ready());
LOG_IF(ERROR, sink.status().is_error()) << sink.status();
ASSERT_TRUE(sink.status().is_ok());
gzip_bomb_str = sink.result()->move_as_buffer_slice().as_slice().str();
}
// Feed the bomb to the reader in randomly split chunks, simulating a
// network stream; the reader is initialized with a 1 MB post-size limit.
auto query = make_http_query("", false, true, 0.01, gzip_bomb_str);
auto parts = rand_split(query);
td::ChainBufferWriter input_writer;
auto input = input_writer.extract_reader();
HttpReader reader;
HttpQuery q;
reader.init(&input, 1000000);
bool ok = false;
for (auto &part : parts) {
input_writer.append(part);
input.sync_with_writer();
auto r_state = reader.read_next(&q);
if (r_state.is_error()) {
LOG(FATAL) << r_state.error();
return;
} else if (r_state.ok() == 0) {
// read_next() returning 0 means the whole query has been parsed.
LOG(ERROR) << q;
ok = true;
}
}
ASSERT_TRUE(ok);
}

View File

@ -897,7 +897,7 @@ class Master : public Actor {
ActorShared<NetQueryCallback> callback) {
BufferSlice answer(8);
answer.as_slice().fill(0);
as<int32>(answer.as_slice().begin()) = my_api::messages_sentEncryptedMessage::ID;
as<int32>(answer.as_slice().begin()) = static_cast<int32>(my_api::messages_sentEncryptedMessage::ID);
net_query->set_ok(std::move(answer));
send_closure(std::move(callback), &NetQueryCallback::on_result, std::move(net_query));