Ssl refactoring

GitOrigin-RevId: f5916787608227b6914c10520dfe7a7039522ef9
This commit is contained in:
Arseny Smirnov 2018-08-15 15:41:42 +03:00
parent 7fc96ddff5
commit ab2b189722
15 changed files with 611 additions and 548 deletions

View File

@ -27,9 +27,9 @@ class HttpClient : public HttpOutboundConnection::Callback {
addr.init_ipv4_port("127.0.0.1", 8082).ensure();
auto fd = SocketFd::open(addr);
CHECK(fd.is_ok()) << fd.error();
connection_ =
create_actor<HttpOutboundConnection>("Connect", fd.move_as_ok(), std::numeric_limits<size_t>::max(), 0, 0,
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
connection_ = create_actor<HttpOutboundConnection>("Connect", fd.move_as_ok(), SslStream{},
std::numeric_limits<size_t>::max(), 0, 0,
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
yield();
cnt_ = 100000;
counter++;

View File

@ -18,7 +18,7 @@
#include <string>
int main(int argc, char *argv[]) {
SET_VERBOSITY_LEVEL(VERBOSITY_NAME(INFO));
SET_VERBOSITY_LEVEL(VERBOSITY_NAME(DEBUG));
td::VERBOSITY_NAME(fd) = VERBOSITY_NAME(INFO);
std::string url = (argc > 1 ? argv[1] : "https://telegram.org");

View File

@ -20,7 +20,7 @@
#if !TD_EMSCRIPTEN //FIXME
#include "td/net/HttpQuery.h"
#include "td/net/SslFd.h"
#include "td/net/SslStream.h"
#include "td/net/Wget.h"
#endif
@ -126,7 +126,7 @@ static ActorOwn<> get_simple_config_impl(Promise<SimpleConfig> promise, int32 sc
}());
}),
std::move(url), std::vector<std::pair<string, string>>({{"Host", std::move(host)}}), timeout, ttl, prefer_ipv6,
SslFd::VerifyPeer::Off));
SslStream::VerifyPeer::Off));
#endif
}
@ -186,7 +186,7 @@ ActorOwn<> get_simple_config_google_dns(Promise<SimpleConfig> promise, const Con
}),
PSTRING() << "https://www.google.com/resolve?name=" << url_encode(name) << "&type=16",
std::vector<std::pair<string, string>>({{"Host", "dns.google.com"}}), timeout, ttl, prefer_ipv6,
SslFd::VerifyPeer::Off));
SslStream::VerifyPeer::Off));
#endif
}

View File

@ -18,7 +18,7 @@ set(TDNET_SOURCE
td/net/HttpQuery.cpp
td/net/HttpReader.cpp
td/net/Socks5.cpp
td/net/SslFd.cpp
td/net/SslStream.cpp
td/net/TcpListener.cpp
td/net/TransparentProxy.cpp
td/net/Wget.cpp
@ -36,7 +36,7 @@ set(TDNET_SOURCE
td/net/HttpReader.h
td/net/NetStats.h
td/net/Socks5.h
td/net/SslFd.h
td/net/SslStream.h
td/net/TcpListener.h
td/net/TransparentProxy.h
td/net/Wget.h

View File

@ -14,14 +14,23 @@
namespace td {
namespace detail {
HttpConnectionBase::HttpConnectionBase(State state, FdProxy fd, size_t max_post_size, size_t max_files,
int32 idle_timeout)
HttpConnectionBase::HttpConnectionBase(State state, SocketFd fd, SslStream ssl_stream, size_t max_post_size,
size_t max_files, int32 idle_timeout)
: state_(state)
, stream_connection_(std::move(fd))
, fd_(std::move(fd))
, ssl_stream_(std::move(ssl_stream))
, max_post_size_(max_post_size)
, max_files_(max_files)
, idle_timeout_(idle_timeout) {
CHECK(state_ != State::Close);
if (ssl_stream_) {
read_source_ >> ssl_stream_.read_byte_flow() >> read_sink_;
write_source_ >> ssl_stream_.write_byte_flow() >> write_sink_;
} else {
read_source_ >> read_sink_;
write_source_ >> write_sink_;
}
}
void HttpConnectionBase::live_event() {
@ -31,9 +40,9 @@ void HttpConnectionBase::live_event() {
}
void HttpConnectionBase::start_up() {
stream_connection_.get_fd().set_observer(this);
subscribe(stream_connection_.get_fd());
reader_.init(&stream_connection_.input_buffer(), max_post_size_, max_files_);
fd_.get_fd().set_observer(this);
subscribe(fd_.get_fd());
reader_.init(read_sink_.get_output(), max_post_size_, max_files_);
if (state_ == State::Read) {
current_query_ = make_unique<HttpQuery>();
}
@ -41,13 +50,13 @@ void HttpConnectionBase::start_up() {
yield();
}
void HttpConnectionBase::tear_down() {
unsubscribe_before_close(stream_connection_.get_fd());
stream_connection_.close();
unsubscribe_before_close(fd_.get_fd());
fd_.close();
}
void HttpConnectionBase::write_next(BufferSlice buffer) {
CHECK(state_ == State::Write);
stream_connection_.output_buffer().append(std::move(buffer));
write_buffer_.append(std::move(buffer));
loop();
}
@ -69,7 +78,7 @@ void HttpConnectionBase::write_error(Status error) {
void HttpConnectionBase::timeout_expired() {
LOG(INFO) << "Idle timeout expired";
if (stream_connection_.need_flush_write()) {
if (fd_.need_flush_write()) {
on_error(Status::Error("Write timeout expired"));
} else if (state_ == State::Read) {
on_error(Status::Error("Read timeout expired"));
@ -78,9 +87,9 @@ void HttpConnectionBase::timeout_expired() {
stop();
}
void HttpConnectionBase::loop() {
if (can_read(stream_connection_)) {
if (can_read(fd_)) {
LOG(DEBUG) << "Can read from the connection";
auto r = stream_connection_.flush_read();
auto r = fd_.flush_read();
if (r.is_error()) {
if (!begins_with(r.error().message(), "SSL error {336134278")) { // if error is not yet outputed
LOG(INFO) << "flush_read error: " << r.error();
@ -89,6 +98,7 @@ void HttpConnectionBase::loop() {
return stop();
}
}
read_source_.wakeup();
// TODO: read_next even when state_ == State::Write
@ -102,7 +112,7 @@ void HttpConnectionBase::loop() {
HttpHeaderCreator hc;
hc.init_status_line(res.error().code());
hc.set_content_size(0);
stream_connection_.output_buffer().append(hc.finish().ok());
write_buffer_.append(hc.finish().ok());
close_after_write_ = true;
on_error(Status::Error(res.error().public_message()));
} else if (res.ok() == 0) {
@ -115,34 +125,45 @@ void HttpConnectionBase::loop() {
}
}
if (can_write(stream_connection_)) {
write_source_.wakeup();
if (can_write(fd_)) {
LOG(DEBUG) << "Can write to the connection";
auto r = stream_connection_.flush_write();
auto r = fd_.flush_write();
if (r.is_error()) {
LOG(INFO) << "flush_write error: " << r.error();
on_error(Status::Error(r.error().public_message()));
}
if (close_after_write_ && !stream_connection_.need_flush_write()) {
if (close_after_write_ && !fd_.need_flush_write()) {
return stop();
}
}
if (stream_connection_.get_fd().has_pending_error()) {
auto pending_error = stream_connection_.get_pending_error();
Status pending_error;
if (fd_.get_fd().has_pending_error()) {
pending_error = fd_.get_pending_error();
}
if (pending_error.is_ok() && write_sink_.status().is_error()) {
pending_error = std::move(write_sink_.status());
}
if (pending_error.is_ok() && read_sink_.status().is_error()) {
pending_error = std::move(read_sink_.status());
}
if (pending_error.is_error()) {
LOG(INFO) << pending_error;
if (!close_after_write_) {
on_error(Status::Error(pending_error.public_message()));
}
state_ = State::Close;
}
if (can_close(stream_connection_)) {
if (can_close(fd_)) {
LOG(DEBUG) << "Can close the connection";
state_ = State::Close;
}
if (state_ == State::Close) {
LOG_IF(INFO, stream_connection_.need_flush_write()) << "Close nonempty connection";
LOG_IF(INFO, want_read &&
(stream_connection_.input_buffer().size() > 0 || current_query_->type_ != HttpQuery::Type::EMPTY))
LOG_IF(INFO, fd_.need_flush_write()) << "Close nonempty connection";
LOG_IF(INFO, want_read && (fd_.input_buffer().size() > 0 || current_query_->type_ != HttpQuery::Type::EMPTY))
<< "Close connection while reading request/response";
return stop();
}

View File

@ -10,120 +10,16 @@
#include "td/net/HttpQuery.h"
#include "td/net/HttpReader.h"
#include "td/net/SslStream.h"
#include "td/utils/buffer.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/port/Fd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
namespace td {
class FdInterface {
public:
FdInterface() = default;
FdInterface(const FdInterface &) = delete;
FdInterface &operator=(const FdInterface &) = delete;
FdInterface(FdInterface &&) = default;
FdInterface &operator=(FdInterface &&) = default;
virtual ~FdInterface() = default;
virtual const Fd &get_fd() const = 0;
virtual Fd &get_fd() = 0;
virtual int32 get_flags() const = 0;
virtual Status get_pending_error() TD_WARN_UNUSED_RESULT = 0;
virtual Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT = 0;
virtual Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT = 0;
virtual void close() = 0;
virtual bool empty() const = 0;
};
template <class FdT>
class FdToInterface : public FdInterface {
public:
FdToInterface() = default;
explicit FdToInterface(FdT fd) : fd_(std::move(fd)) {
}
const Fd &get_fd() const final {
return fd_.get_fd();
}
Fd &get_fd() final {
return fd_.get_fd();
}
int32 get_flags() const final {
return fd_.get_flags();
}
Status get_pending_error() final TD_WARN_UNUSED_RESULT {
return fd_.get_pending_error();
}
Result<size_t> write(Slice slice) final TD_WARN_UNUSED_RESULT {
return fd_.write(slice);
}
Result<size_t> read(MutableSlice slice) final TD_WARN_UNUSED_RESULT {
return fd_.read(slice);
}
void close() final {
fd_.close();
}
bool empty() const final {
return fd_.empty();
}
private:
FdT fd_;
};
template <class FdT>
std::unique_ptr<FdInterface> make_fd_interface(FdT fd) {
return make_unique<FdToInterface<FdT>>(std::move(fd));
}
class FdProxy {
public:
FdProxy() = default;
explicit FdProxy(std::unique_ptr<FdInterface> fd) : fd_(std::move(fd)) {
}
const Fd &get_fd() const {
return fd_->get_fd();
}
Fd &get_fd() {
return fd_->get_fd();
}
int32 get_flags() const {
return fd_->get_flags();
}
Status get_pending_error() TD_WARN_UNUSED_RESULT {
return fd_->get_pending_error();
}
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT {
return fd_->write(slice);
}
Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT {
return fd_->read(slice);
}
void close() {
fd_->close();
}
bool empty() const {
return fd_->empty();
}
private:
std::unique_ptr<FdInterface> fd_;
};
template <class FdT>
FdProxy make_fd_proxy(FdT fd) {
return FdProxy(make_fd_interface(std::move(fd)));
}
namespace detail {
class HttpConnectionBase : public Actor {
public:
@ -133,16 +29,23 @@ class HttpConnectionBase : public Actor {
protected:
enum class State { Read, Write, Close };
template <class FdT>
HttpConnectionBase(State state, FdT fd, size_t max_post_size, size_t max_files, int32 idle_timeout)
: HttpConnectionBase(state, make_fd_proxy(std::move(fd)), max_post_size, max_files, idle_timeout) {
}
HttpConnectionBase(State state, FdProxy fd, size_t max_post_size, size_t max_files, int32 idle_timeout);
HttpConnectionBase(State state, SocketFd fd, SslStream ssl_stream, size_t max_post_size, size_t max_files,
int32 idle_timeout);
private:
using StreamConnection = BufferedFd<FdProxy>;
State state_;
StreamConnection stream_connection_;
BufferedFd<SocketFd> fd_;
SslStream ssl_stream_;
ByteFlowSource read_source_{&fd_.input_buffer()};
ByteFlowSink read_sink_;
ChainBufferWriter write_buffer_;
ChainBufferReader write_buffer_reader_ = write_buffer_.extract_reader();
ByteFlowSource write_source_{&write_buffer_reader_};
ByteFlowMoveSink write_sink_{&fd_.output_buffer()};
size_t max_post_size_;
size_t max_files_;
int32 idle_timeout_;

View File

@ -12,7 +12,7 @@ namespace td {
// HttpInboundConnection implementation
HttpInboundConnection::HttpInboundConnection(SocketFd fd, size_t max_post_size, size_t max_files, int32 idle_timeout,
ActorShared<Callback> callback)
: HttpConnectionBase(State::Read, std::move(fd), max_post_size, max_files, idle_timeout)
: HttpConnectionBase(State::Read, std::move(fd), SslStream(), max_post_size, max_files, idle_timeout)
, callback_(std::move(callback)) {
}

View File

@ -22,10 +22,10 @@ class HttpOutboundConnection final : public detail::HttpConnectionBase {
virtual void handle(HttpQueryPtr query) = 0;
virtual void on_connection_error(Status error) = 0; // TODO rename to on_error
};
template <class FdT>
HttpOutboundConnection(FdT fd, size_t max_post_size, size_t max_files, int32 idle_timeout,
HttpOutboundConnection(SocketFd fd, SslStream ssl_stream, size_t max_post_size, size_t max_files, int32 idle_timeout,
ActorShared<Callback> callback)
: HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), max_post_size, max_files, idle_timeout)
: HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), std::move(ssl_stream), max_post_size,
max_files, idle_timeout)
, callback_(std::move(callback)) {
}
// Inherited interface

View File

@ -1,279 +0,0 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/SslFd.h"
#include "td/utils/logging.h"
#include "td/utils/StackAllocator.h"
#include "td/utils/StringBuilder.h"
#include "td/utils/Time.h"
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <map>
#include <mutex>
namespace td {
#if !TD_WINDOWS
static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
if (!preverify_ok) {
char buf[256];
X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256);
int err = X509_STORE_CTX_get_error(ctx);
auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err)
<< ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << buf;
double now = Time::now();
static std::mutex warning_mutex;
{
std::lock_guard<std::mutex> lock(warning_mutex);
static std::map<std::string, double> next_warning_time;
double &next = next_warning_time[warning];
if (next <= now) {
next = now + 300; // one warning per 5 minutes
LOG(WARNING) << warning;
}
}
}
return preverify_ok;
}
#endif
namespace {
Status create_openssl_error(int code, Slice message) {
const int buf_size = 1 << 12;
auto buf = StackAllocator::alloc(buf_size);
StringBuilder sb(buf.as_slice());
sb << message;
while (unsigned long error_code = ERR_get_error()) {
sb << "{" << error_code << ", " << ERR_error_string(error_code, nullptr) << "}";
}
LOG_IF(ERROR, sb.is_error()) << "OPENSSL error buffer overflow";
return Status::Error(code, sb.as_cslice());
}
void openssl_clear_errors(Slice from) {
if (ERR_peek_error() != 0) {
LOG(ERROR) << from << ": " << create_openssl_error(0, "Unprocessed OPENSSL_ERROR");
}
errno = 0;
}
void do_ssl_shutdown(SSL *ssl_handle) {
if (!SSL_is_init_finished(ssl_handle)) {
return;
}
openssl_clear_errors("Before SSL_shutdown");
SSL_set_quiet_shutdown(ssl_handle, 1);
SSL_shutdown(ssl_handle);
openssl_clear_errors("After SSL_shutdown");
}
} // namespace
SslFd::SslFd(SocketFd &&fd, SSL *ssl_handle_, SSL_CTX *ssl_ctx_)
: fd_(std::move(fd)), ssl_handle_(ssl_handle_), ssl_ctx_(ssl_ctx_) {
}
Result<SslFd> SslFd::init(SocketFd fd, CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
#if TD_WINDOWS
return Status::Error("TODO");
#else
static bool init_openssl = [] {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
return OPENSSL_init_ssl(0, nullptr) != 0;
#else
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
return OpenSSL_add_ssl_algorithms() != 0;
#endif
}();
CHECK(init_openssl);
openssl_clear_errors("Before SslFd::init");
CHECK(!fd.empty());
auto ssl_method =
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
TLS_client_method();
#else
SSLv23_client_method();
#endif
if (ssl_method == nullptr) {
return create_openssl_error(-6, "Failed to create an SSL client method");
}
auto ssl_ctx = SSL_CTX_new(ssl_method);
if (ssl_ctx == nullptr) {
return create_openssl_error(-7, "Failed to create an SSL context");
}
auto ssl_ctx_guard = ScopeExit() + [&]() { SSL_CTX_free(ssl_ctx); };
long options = 0;
#ifdef SSL_OP_NO_SSLv2
options |= SSL_OP_NO_SSLv2;
#endif
#ifdef SSL_OP_NO_SSLv3
options |= SSL_OP_NO_SSLv3;
#endif
SSL_CTX_set_options(ssl_ctx, options);
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
if (cert_file.empty()) {
SSL_CTX_set_default_verify_paths(ssl_ctx);
} else {
if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) {
return create_openssl_error(-8, "Failed to set custom cert file");
}
}
if (VERIFY_PEER && verify_peer == VerifyPeer::On) {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback);
if (VERIFY_DEPTH != -1) {
SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH);
}
} else {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr);
}
// TODO(now): cipher list
string cipher_list;
if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) {
return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"');
}
auto ssl_handle = SSL_new(ssl_ctx);
if (ssl_handle == nullptr) {
return create_openssl_error(-13, "Failed to create an SSL handle");
}
auto ssl_handle_guard = ScopeExit() + [&]() {
do_ssl_shutdown(ssl_handle);
SSL_free(ssl_handle);
};
#if OPENSSL_VERSION_NUMBER >= 0x10002000L
X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle);
/* Enable automatic hostname checks */
// TODO: X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
X509_VERIFY_PARAM_set_hostflags(param, 0);
X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0);
#else
#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY
#endif
if (!SSL_set_fd(ssl_handle, fd.get_fd().get_native_fd())) {
return create_openssl_error(-14, "Failed to set fd");
}
#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
auto host_str = host.str();
SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin());
#endif
SSL_set_connect_state(ssl_handle);
ssl_ctx_guard.dismiss();
ssl_handle_guard.dismiss();
return SslFd(std::move(fd), ssl_handle, ssl_ctx);
#endif
}
Result<size_t> SslFd::process_ssl_error(int ret, int *mask) {
#if TD_WINDOWS
return Status::Error("TODO");
#else
auto openssl_errno = errno;
int error = SSL_get_error(ssl_handle_, ret);
switch (error) {
case SSL_ERROR_NONE:
LOG(ERROR) << "SSL_get_error returned no error";
return 0;
case SSL_ERROR_ZERO_RETURN:
LOG(DEBUG) << "SSL_ERROR_ZERO_RETURN";
fd_.get_fd().update_flags(Fd::Close);
write_mask_ |= Fd::Error;
*mask |= Fd::Error;
return 0;
case SSL_ERROR_WANT_READ:
LOG(DEBUG) << "SSL_ERROR_WANT_READ";
fd_.get_fd().clear_flags(Fd::Read);
*mask |= Fd::Read;
return 0;
case SSL_ERROR_WANT_WRITE:
LOG(DEBUG) << "SSL_ERROR_WANT_WRITE";
fd_.get_fd().clear_flags(Fd::Write);
*mask |= Fd::Write;
return 0;
case SSL_ERROR_WANT_CONNECT:
case SSL_ERROR_WANT_ACCEPT:
case SSL_ERROR_WANT_X509_LOOKUP:
LOG(DEBUG) << "SSL_ERROR: CONNECT ACCEPT LOOKUP";
fd_.get_fd().clear_flags(Fd::Write);
*mask |= Fd::Write;
return 0;
case SSL_ERROR_SYSCALL:
LOG(DEBUG) << "SSL_ERROR_SYSCALL";
if (ERR_peek_error() == 0) {
if (openssl_errno != 0) {
CHECK(openssl_errno != EAGAIN);
return Status::PosixError(openssl_errno, "SSL_ERROR_SYSCALL");
} else {
// Socket was closed from the other side, probably. Not an error
fd_.get_fd().update_flags(Fd::Close);
write_mask_ |= Fd::Error;
*mask |= Fd::Error;
return 0;
}
}
/* fall through */
default:
LOG(DEBUG) << "SSL_ERROR Default";
fd_.get_fd().update_flags(Fd::Close);
write_mask_ |= Fd::Error;
read_mask_ |= Fd::Error;
return create_openssl_error(1, "SSL error ");
}
#endif
}
Result<size_t> SslFd::write(Slice slice) {
openssl_clear_errors("Before SslFd::write");
auto size = SSL_write(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
if (size <= 0) {
return process_ssl_error(size, &write_mask_);
}
return size;
}
Result<size_t> SslFd::read(MutableSlice slice) {
openssl_clear_errors("Before SslFd::read");
auto size = SSL_read(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
if (size <= 0) {
return process_ssl_error(size, &read_mask_);
}
return size;
}
void SslFd::close() {
if (fd_.empty()) {
CHECK(!ssl_handle_ && !ssl_ctx_);
return;
}
CHECK(ssl_handle_ && ssl_ctx_);
do_ssl_shutdown(ssl_handle_);
SSL_free(ssl_handle_);
ssl_handle_ = nullptr;
SSL_CTX_free(ssl_ctx_);
ssl_ctx_ = nullptr;
fd_.close();
}
} // namespace td

View File

@ -1,109 +0,0 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/port/Fd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
#include <openssl/ssl.h> // TODO can we remove it from header and make target_link_libraries dependence PRIVATE?
namespace td {
class SslFd {
public:
enum class VerifyPeer { On, Off };
static Result<SslFd> init(SocketFd fd, CSlice host, CSlice cert_file = CSlice(),
VerifyPeer verify_peer = VerifyPeer::On) TD_WARN_UNUSED_RESULT;
SslFd(const SslFd &other) = delete;
SslFd &operator=(const SslFd &other) = delete;
SslFd(SslFd &&other)
: fd_(std::move(other.fd_))
, write_mask_(other.write_mask_)
, read_mask_(other.read_mask_)
, ssl_handle_(other.ssl_handle_)
, ssl_ctx_(other.ssl_ctx_) {
other.ssl_handle_ = nullptr;
other.ssl_ctx_ = nullptr;
}
SslFd &operator=(SslFd &&other) {
close();
fd_ = std::move(other.fd_);
write_mask_ = other.write_mask_;
read_mask_ = other.read_mask_;
ssl_handle_ = other.ssl_handle_;
ssl_ctx_ = other.ssl_ctx_;
other.ssl_handle_ = nullptr;
other.ssl_ctx_ = nullptr;
return *this;
}
const Fd &get_fd() const {
return fd_.get_fd();
}
Fd &get_fd() {
return fd_.get_fd();
}
Status get_pending_error() TD_WARN_UNUSED_RESULT {
return fd_.get_pending_error();
}
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT;
Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT;
void close();
int32 get_flags() const {
int32 res = 0;
int32 fd_flags = fd_.get_flags();
fd_flags &= ~Fd::Error;
if (fd_flags & Fd::Close) {
res |= Fd::Close;
}
write_mask_ &= ~fd_flags;
read_mask_ &= ~fd_flags;
if (write_mask_ == 0) {
res |= Fd::Write;
}
if (read_mask_ == 0) {
res |= Fd::Read;
}
return res;
}
bool empty() const {
return fd_.empty();
}
~SslFd() {
close();
}
private:
static constexpr bool VERIFY_PEER = true;
static constexpr int VERIFY_DEPTH = 10;
SocketFd fd_;
mutable int write_mask_ = 0;
mutable int read_mask_ = 0;
// TODO unique_ptr
SSL *ssl_handle_ = nullptr;
SSL_CTX *ssl_ctx_ = nullptr;
SslFd(SocketFd &&fd, SSL *ssl_handle_, SSL_CTX *ssl_ctx_);
Result<size_t> process_ssl_error(int ret, int *mask) TD_WARN_UNUSED_RESULT;
};
} // namespace td

473
tdnet/td/net/SslStream.cpp Normal file
View File

@ -0,0 +1,473 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/SslStream.h"
#include "td/utils/logging.h"
#include "td/utils/StackAllocator.h"
#include "td/utils/StringBuilder.h"
#include "td/utils/Time.h"
#include "td/utils/misc.h"
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <map>
#include <mutex>
namespace td {
namespace detail {
namespace {
#if OPENSSL_VERSION_NUMBER < 0x10100000L
void* BIO_get_data(BIO* b) {
return b->ptr;
}
void BIO_set_data(BIO* b, void* ptr) {
b->ptr = ptr;
}
void BIO_set_init(BIO* b, int init) {
b->init = init;
}
int BIO_get_new_index() {
return 0;
}
BIO_METHOD* BIO_meth_new(int type, const char* name) {
auto res = new BIO_METHOD();
memset(res, 0, sizeof(*res));
return res;
}
int BIO_meth_set_write(BIO_METHOD* biom, int (*bwrite)(BIO*, const char*, int)) {
biom->bwrite = bwrite;
return 1;
}
int BIO_meth_set_read(BIO_METHOD* biom, int (*bread)(BIO*, char*, int)) {
biom->bread = bread;
return 1;
}
int BIO_meth_set_ctrl(BIO_METHOD* biom, long (*ctrl)(BIO*, int, long, void*)) {
biom->ctrl = ctrl;
return 1;
}
int BIO_meth_set_create(BIO_METHOD* biom, int (*create)(BIO*)) {
biom->create = create;
return 1;
}
int BIO_meth_set_destroy(BIO_METHOD* biom, int (*destroy)(BIO*)) {
biom->destroy = destroy;
return 1;
}
#endif
int strm_create(BIO* b) {
BIO_set_init(b, 1);
return 1;
}
int strm_destroy(BIO* b) {
return 1;
}
int strm_read(BIO* b, char* buf, int len);
int strm_write(BIO* b, const char* buf, int len);
long strm_ctrl(BIO* b, int cmd, long num, void* ptr) {
switch (cmd) {
case BIO_CTRL_FLUSH:
return 1;
case BIO_CTRL_PUSH:
return 0;
case BIO_CTRL_POP:
return 0;
default:
LOG(FATAL) << b << " " << cmd << " " << num << " " << ptr;
}
return 1;
}
BIO_METHOD* BIO_s_sslstream() {
static BIO_METHOD* res = [] {
BIO_METHOD* res = BIO_meth_new(BIO_get_new_index(), "td::SslStream helper bio");
BIO_meth_set_write(res, strm_write);
BIO_meth_set_read(res, strm_read);
BIO_meth_set_create(res, strm_create);
BIO_meth_set_destroy(res, strm_destroy);
BIO_meth_set_ctrl(res, strm_ctrl);
return res;
}();
return res;
}
int verify_callback(int preverify_ok, X509_STORE_CTX* ctx) {
if (!preverify_ok) {
char buf[256];
X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256);
int err = X509_STORE_CTX_get_error(ctx);
auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err)
<< ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << buf;
double now = Time::now();
static std::mutex warning_mutex;
{
std::lock_guard<std::mutex> lock(warning_mutex);
static std::map<std::string, double> next_warning_time;
double& next = next_warning_time[warning];
if (next <= now) {
next = now + 300; // one warning per 5 minutes
LOG(WARNING) << warning;
}
}
}
return preverify_ok;
}
Status create_openssl_error(int code, Slice message) {
const int buf_size = 1 << 12;
auto buf = StackAllocator::alloc(buf_size);
StringBuilder sb(buf.as_slice());
sb << message;
while (unsigned long error_code = ERR_get_error()) {
sb << "{" << error_code << ", " << ERR_error_string(error_code, nullptr) << "}";
}
LOG_IF(ERROR, sb.is_error()) << "OPENSSL error buffer overflow";
return Status::Error(code, sb.as_cslice());
}
void openssl_clear_errors(Slice from) {
if (ERR_peek_error() != 0) {
LOG(ERROR) << from << ": " << create_openssl_error(0, "Unprocessed OPENSSL_ERROR");
}
errno = 0;
}
void do_ssl_shutdown(SSL* ssl_handle) {
if (!SSL_is_init_finished(ssl_handle)) {
return;
}
openssl_clear_errors("Before SSL_shutdown");
SSL_set_quiet_shutdown(ssl_handle, 1);
SSL_shutdown(ssl_handle);
openssl_clear_errors("After SSL_shutdown");
}
} // namespace
class SslStreamImpl {
public:
using VerifyPeer = SslStream::VerifyPeer;
~SslStreamImpl() {
if (!ssl_handle_) {
CHECK(!ssl_ctx_ && !bio_);
return;
}
CHECK(ssl_handle_ && ssl_ctx_ && bio_);
do_ssl_shutdown(ssl_handle_);
SSL_free(ssl_handle_);
ssl_handle_ = nullptr;
SSL_CTX_free(ssl_ctx_);
ssl_ctx_ = nullptr;
}
Status init(CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
static bool init_openssl = [] {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
return OPENSSL_init_ssl(0, nullptr) != 0;
#else
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
return OpenSSL_add_ssl_algorithms() != 0;
#endif
}();
CHECK(init_openssl);
openssl_clear_errors("Before SslFd::init");
auto ssl_method =
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
TLS_client_method();
#else
SSLv23_client_method();
#endif
if (ssl_method == nullptr) {
return create_openssl_error(-6, "Failed to create an SSL client method");
}
auto ssl_ctx = SSL_CTX_new(ssl_method);
if (ssl_ctx == nullptr) {
return create_openssl_error(-7, "Failed to create an SSL context");
}
auto ssl_ctx_guard = ScopeExit() + [&]() { SSL_CTX_free(ssl_ctx); };
long options = 0;
#ifdef SSL_OP_NO_SSLv2
options |= SSL_OP_NO_SSLv2;
#endif
#ifdef SSL_OP_NO_SSLv3
options |= SSL_OP_NO_SSLv3;
#endif
SSL_CTX_set_options(ssl_ctx, options);
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
if (cert_file.empty()) {
SSL_CTX_set_default_verify_paths(ssl_ctx);
} else {
if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) {
return create_openssl_error(-8, "Failed to set custom cert file");
}
}
if (verify_peer == VerifyPeer::On) {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback);
if (VERIFY_DEPTH != -1) {
SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH);
}
} else {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr);
}
// TODO(now): cipher list
string cipher_list;
if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) {
return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"');
}
auto ssl_handle = SSL_new(ssl_ctx);
if (ssl_handle == nullptr) {
return create_openssl_error(-13, "Failed to create an SSL handle");
}
auto ssl_handle_guard = ScopeExit() + [&]() {
do_ssl_shutdown(ssl_handle);
SSL_free(ssl_handle);
};
#if OPENSSL_VERSION_NUMBER >= 0x10002000L
X509_VERIFY_PARAM* param = SSL_get0_param(ssl_handle);
/* Enable automatic hostname checks */
// TODO: X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
X509_VERIFY_PARAM_set_hostflags(param, 0);
X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0);
#else
#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY
#endif
auto* bio = BIO_new(BIO_s_sslstream());
BIO_set_data(bio, this);
SSL_set_bio(ssl_handle, bio, bio);
#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
auto host_str = host.str();
SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin());
#endif
SSL_set_connect_state(ssl_handle);
ssl_ctx_guard.dismiss();
ssl_handle_guard.dismiss();
ssl_handle_ = ssl_handle;
ssl_ctx_ = ssl_ctx;
bio_ = bio;
return Status::OK();
}
ByteFlowInterface& read_byte_flow() {
return read_flow_;
}
ByteFlowInterface& write_byte_flow() {
return write_flow_;
}
size_t flow_read(MutableSlice slice) {
return read_flow_.read(slice);
}
size_t flow_write(Slice slice) {
return write_flow_.write(slice);
}
private:
static constexpr int VERIFY_DEPTH = 10;
SSL* ssl_handle_ = nullptr;
SSL_CTX* ssl_ctx_ = nullptr;
BIO* bio_ = nullptr;
friend class SslReadByteFlow;
friend class SslWriteByteFlow;
Result<size_t> write(Slice slice) {
openssl_clear_errors("Before SslFd::write");
auto size = SSL_write(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
if (size <= 0) {
return process_ssl_error(size);
}
return size;
}
Result<size_t> read(MutableSlice slice) {
openssl_clear_errors("Before SslFd::read");
auto size = SSL_read(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
if (size <= 0) {
return process_ssl_error(size);
}
return size;
}
class SslReadByteFlow : public ByteFlowBase {
public:
SslReadByteFlow(SslStreamImpl* stream) : stream_(stream) {
}
void loop() override {
bool was_append = false;
while (true) {
auto to_read = output_.prepare_append();
auto r_size = stream_->read(to_read);
if (r_size.is_error()) {
return finish(r_size.move_as_error());
}
auto size = r_size.move_as_ok();
if (size == 0) {
break;
}
output_.confirm_append(size);
was_append = true;
}
if (was_append) {
on_output_updated();
}
}
size_t read(MutableSlice data) {
return input_->advance(std::min(data.size(), input_->size()), data);
}
private:
SslStreamImpl* stream_;
};
class SslWriteByteFlow : public ByteFlowBase {
public:
SslWriteByteFlow(SslStreamImpl* stream) : stream_(stream) {
}
void loop() override {
while (!input_->empty()) {
auto to_write = input_->prepare_read();
auto r_size = stream_->write(to_write);
if (r_size.is_error()) {
return finish(r_size.move_as_error());
}
auto size = r_size.move_as_ok();
if (size == 0) {
break;
}
input_->confirm_read(size);
}
if (output_updated_) {
output_updated_ = false;
on_output_updated();
}
}
size_t write(Slice data) {
output_.append(data);
output_updated_ = true;
return data.size();
}
private:
SslStreamImpl* stream_;
bool output_updated_{false};
};
SslReadByteFlow read_flow_{this};
SslWriteByteFlow write_flow_{this};
Result<size_t> process_ssl_error(int ret) {
auto openssl_errno = errno;
int error = SSL_get_error(ssl_handle_, ret);
switch (error) {
case SSL_ERROR_NONE:
LOG(ERROR) << "SSL_get_error returned no error";
return 0;
case SSL_ERROR_ZERO_RETURN:
LOG(DEBUG) << "SSL_ERROR_ZERO_RETURN";
return 0;
case SSL_ERROR_WANT_READ:
LOG(DEBUG) << "SSL_ERROR_WANT_READ";
return 0;
case SSL_ERROR_WANT_WRITE:
LOG(DEBUG) << "SSL_ERROR_WANT_WRITE";
return 0;
case SSL_ERROR_WANT_CONNECT:
case SSL_ERROR_WANT_ACCEPT:
case SSL_ERROR_WANT_X509_LOOKUP:
LOG(DEBUG) << "SSL_ERROR: CONNECT ACCEPT LOOKUP";
return 0;
case SSL_ERROR_SYSCALL:
LOG(DEBUG) << "SSL_ERROR_SYSCALL";
if (ERR_peek_error() == 0) {
if (openssl_errno != 0) {
CHECK(openssl_errno != EAGAIN);
return Status::PosixError(openssl_errno, "SSL_ERROR_SYSCALL");
} else {
return 0;
}
}
/* fall through */
default:
LOG(DEBUG) << "SSL_ERROR Default";
return create_openssl_error(1, "SSL error ");
}
}
};
namespace {
int strm_read(BIO* b, char* buf, int len) {
auto* stream = reinterpret_cast<SslStreamImpl*>(BIO_get_data(b));
CHECK(stream);
BIO_clear_retry_flags(b);
int res = narrow_cast<int>(stream->flow_read(MutableSlice(buf, len)));
if (res == 0) {
BIO_set_retry_read(b);
return -1;
}
return res;
}
int strm_write(BIO* b, const char* buf, int len) {
auto* stream = reinterpret_cast<SslStreamImpl*>(BIO_get_data(b));
CHECK(stream);
BIO_clear_retry_flags(b);
return narrow_cast<int>(stream->flow_write(Slice(buf, len)));
}
} // namespace
} // namespace detail
SslStream::SslStream() = default;
SslStream::SslStream(SslStream&&) = default;
SslStream& SslStream::operator=(SslStream&&) = default;
SslStream::~SslStream() = default;
Result<SslStream> SslStream::create(CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
auto impl = std::make_unique<detail::SslStreamImpl>();
TRY_STATUS(impl->init(host, cert_file, verify_peer));
return SslStream(std::move(impl));
}
SslStream::SslStream(std::unique_ptr<detail::SslStreamImpl> impl) : impl_(std::move(impl)) {
}
ByteFlowInterface& SslStream::read_byte_flow() {
return impl_->read_byte_flow();
}
ByteFlowInterface& SslStream::write_byte_flow() {
return impl_->write_byte_flow();
}
size_t SslStream::flow_read(MutableSlice slice) {
return impl_->flow_read(slice);
}
size_t SslStream::flow_write(Slice slice) {
return impl_->flow_write(slice);
}
} // namespace td

49
tdnet/td/net/SslStream.h Normal file
View File

@ -0,0 +1,49 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/port/Fd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
#include "td/utils/ByteFlow.h"
#include "td/utils/BufferedFd.h"
namespace td {
namespace detail {
class SslStreamImpl;
}
class SslStream {
public:
SslStream();
SslStream(SslStream &&);
SslStream &operator=(SslStream &&);
~SslStream();
enum class VerifyPeer { On, Off };
static Result<SslStream> create(CSlice host, CSlice cert_file = CSlice(), VerifyPeer verify_peer = VerifyPeer::On);
ByteFlowInterface &read_byte_flow();
ByteFlowInterface &write_byte_flow();
size_t flow_read(MutableSlice slice);
size_t flow_write(Slice slice);
explicit operator bool() const {
return bool(impl_);
}
private:
std::unique_ptr<detail::SslStreamImpl> impl_;
SslStream(std::unique_ptr<detail::SslStreamImpl> impl);
};
} // namespace td

View File

@ -8,7 +8,7 @@
#include "td/net/HttpHeaderCreator.h"
#include "td/net/HttpOutboundConnection.h"
#include "td/net/SslFd.h"
#include "td/net/SslStream.h"
#include "td/utils/buffer.h"
#include "td/utils/HttpUrl.h"
@ -23,7 +23,7 @@
namespace td {
Wget::Wget(Promise<HttpQueryPtr> promise, string url, std::vector<std::pair<string, string>> headers, int32 timeout_in,
int32 ttl, bool prefer_ipv6, SslFd::VerifyPeer verify_peer)
int32 ttl, bool prefer_ipv6, SslStream::VerifyPeer verify_peer)
: promise_(std::move(promise))
, input_url_(std::move(url))
, headers_(std::move(headers))
@ -66,14 +66,15 @@ Status Wget::try_init() {
TRY_RESULT(fd, SocketFd::open(addr));
if (url.protocol_ == HttpUrl::Protocol::HTTP) {
connection_ =
create_actor<HttpOutboundConnection>("Connect", std::move(fd), std::numeric_limits<std::size_t>::max(), 0, 0,
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
connection_ = create_actor<HttpOutboundConnection>("Connect", std::move(fd), SslStream{},
std::numeric_limits<std::size_t>::max(), 0, 0,
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
} else {
TRY_RESULT(ssl_fd, SslFd::init(std::move(fd), url.host_, CSlice() /* certificate */, verify_peer_));
connection_ =
create_actor<HttpOutboundConnection>("Connect", std::move(ssl_fd), std::numeric_limits<std::size_t>::max(), 0,
0, ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
LOG(ERROR) << "HTTPS";
TRY_RESULT(ssl_stream, SslStream::create(url.host_, CSlice() /* certificate */, verify_peer_));
connection_ = create_actor<HttpOutboundConnection>("Connect", std::move(fd), std::move(ssl_stream),
std::numeric_limits<std::size_t>::max(), 0, 0,
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
}
send_closure(connection_, &HttpOutboundConnection::write_next, BufferSlice(header));

View File

@ -8,7 +8,7 @@
#include "td/net/HttpOutboundConnection.h"
#include "td/net/HttpQuery.h"
#include "td/net/SslFd.h"
#include "td/net/SslStream.h"
#include "td/actor/actor.h"
#include "td/actor/PromiseFuture.h"
@ -24,7 +24,7 @@ class Wget : public HttpOutboundConnection::Callback {
public:
explicit Wget(Promise<HttpQueryPtr> promise, string url, std::vector<std::pair<string, string>> headers = {},
int32 timeout_in = 10, int32 ttl = 3, bool prefer_ipv6 = false,
SslFd::VerifyPeer verify_peer = SslFd::VerifyPeer::On);
SslStream::VerifyPeer verify_peer = SslStream::VerifyPeer::On);
private:
Status try_init();
@ -45,7 +45,7 @@ class Wget : public HttpOutboundConnection::Callback {
int32 timeout_in_;
int32 ttl_;
bool prefer_ipv6_ = false;
SslFd::VerifyPeer verify_peer_;
SslStream::VerifyPeer verify_peer_;
};
} // namespace td

View File

@ -246,6 +246,10 @@ class ByteFlowSink : public ByteFlowInterface {
class ByteFlowMoveSink : public ByteFlowInterface {
public:
ByteFlowMoveSink() = default;
ByteFlowMoveSink(ChainBufferWriter *output) {
set_output(output);
}
void set_input(ChainBufferReader *input) final {
CHECK(!input_);
input_ = input;