Ssl refactoring
GitOrigin-RevId: f5916787608227b6914c10520dfe7a7039522ef9
This commit is contained in:
parent
7fc96ddff5
commit
ab2b189722
@ -27,8 +27,8 @@ class HttpClient : public HttpOutboundConnection::Callback {
|
||||
addr.init_ipv4_port("127.0.0.1", 8082).ensure();
|
||||
auto fd = SocketFd::open(addr);
|
||||
CHECK(fd.is_ok()) << fd.error();
|
||||
connection_ =
|
||||
create_actor<HttpOutboundConnection>("Connect", fd.move_as_ok(), std::numeric_limits<size_t>::max(), 0, 0,
|
||||
connection_ = create_actor<HttpOutboundConnection>("Connect", fd.move_as_ok(), SslStream{},
|
||||
std::numeric_limits<size_t>::max(), 0, 0,
|
||||
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
|
||||
yield();
|
||||
cnt_ = 100000;
|
||||
|
@ -18,7 +18,7 @@
|
||||
#include <string>
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
SET_VERBOSITY_LEVEL(VERBOSITY_NAME(INFO));
|
||||
SET_VERBOSITY_LEVEL(VERBOSITY_NAME(DEBUG));
|
||||
td::VERBOSITY_NAME(fd) = VERBOSITY_NAME(INFO);
|
||||
|
||||
std::string url = (argc > 1 ? argv[1] : "https://telegram.org");
|
||||
|
@ -20,7 +20,7 @@
|
||||
|
||||
#if !TD_EMSCRIPTEN //FIXME
|
||||
#include "td/net/HttpQuery.h"
|
||||
#include "td/net/SslFd.h"
|
||||
#include "td/net/SslStream.h"
|
||||
#include "td/net/Wget.h"
|
||||
#endif
|
||||
|
||||
@ -126,7 +126,7 @@ static ActorOwn<> get_simple_config_impl(Promise<SimpleConfig> promise, int32 sc
|
||||
}());
|
||||
}),
|
||||
std::move(url), std::vector<std::pair<string, string>>({{"Host", std::move(host)}}), timeout, ttl, prefer_ipv6,
|
||||
SslFd::VerifyPeer::Off));
|
||||
SslStream::VerifyPeer::Off));
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -186,7 +186,7 @@ ActorOwn<> get_simple_config_google_dns(Promise<SimpleConfig> promise, const Con
|
||||
}),
|
||||
PSTRING() << "https://www.google.com/resolve?name=" << url_encode(name) << "&type=16",
|
||||
std::vector<std::pair<string, string>>({{"Host", "dns.google.com"}}), timeout, ttl, prefer_ipv6,
|
||||
SslFd::VerifyPeer::Off));
|
||||
SslStream::VerifyPeer::Off));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ set(TDNET_SOURCE
|
||||
td/net/HttpQuery.cpp
|
||||
td/net/HttpReader.cpp
|
||||
td/net/Socks5.cpp
|
||||
td/net/SslFd.cpp
|
||||
td/net/SslStream.cpp
|
||||
td/net/TcpListener.cpp
|
||||
td/net/TransparentProxy.cpp
|
||||
td/net/Wget.cpp
|
||||
@ -36,7 +36,7 @@ set(TDNET_SOURCE
|
||||
td/net/HttpReader.h
|
||||
td/net/NetStats.h
|
||||
td/net/Socks5.h
|
||||
td/net/SslFd.h
|
||||
td/net/SslStream.h
|
||||
td/net/TcpListener.h
|
||||
td/net/TransparentProxy.h
|
||||
td/net/Wget.h
|
||||
|
@ -14,14 +14,23 @@
|
||||
namespace td {
|
||||
namespace detail {
|
||||
|
||||
HttpConnectionBase::HttpConnectionBase(State state, FdProxy fd, size_t max_post_size, size_t max_files,
|
||||
int32 idle_timeout)
|
||||
HttpConnectionBase::HttpConnectionBase(State state, SocketFd fd, SslStream ssl_stream, size_t max_post_size,
|
||||
size_t max_files, int32 idle_timeout)
|
||||
: state_(state)
|
||||
, stream_connection_(std::move(fd))
|
||||
, fd_(std::move(fd))
|
||||
, ssl_stream_(std::move(ssl_stream))
|
||||
, max_post_size_(max_post_size)
|
||||
, max_files_(max_files)
|
||||
, idle_timeout_(idle_timeout) {
|
||||
CHECK(state_ != State::Close);
|
||||
|
||||
if (ssl_stream_) {
|
||||
read_source_ >> ssl_stream_.read_byte_flow() >> read_sink_;
|
||||
write_source_ >> ssl_stream_.write_byte_flow() >> write_sink_;
|
||||
} else {
|
||||
read_source_ >> read_sink_;
|
||||
write_source_ >> write_sink_;
|
||||
}
|
||||
}
|
||||
|
||||
void HttpConnectionBase::live_event() {
|
||||
@ -31,9 +40,9 @@ void HttpConnectionBase::live_event() {
|
||||
}
|
||||
|
||||
void HttpConnectionBase::start_up() {
|
||||
stream_connection_.get_fd().set_observer(this);
|
||||
subscribe(stream_connection_.get_fd());
|
||||
reader_.init(&stream_connection_.input_buffer(), max_post_size_, max_files_);
|
||||
fd_.get_fd().set_observer(this);
|
||||
subscribe(fd_.get_fd());
|
||||
reader_.init(read_sink_.get_output(), max_post_size_, max_files_);
|
||||
if (state_ == State::Read) {
|
||||
current_query_ = make_unique<HttpQuery>();
|
||||
}
|
||||
@ -41,13 +50,13 @@ void HttpConnectionBase::start_up() {
|
||||
yield();
|
||||
}
|
||||
void HttpConnectionBase::tear_down() {
|
||||
unsubscribe_before_close(stream_connection_.get_fd());
|
||||
stream_connection_.close();
|
||||
unsubscribe_before_close(fd_.get_fd());
|
||||
fd_.close();
|
||||
}
|
||||
|
||||
void HttpConnectionBase::write_next(BufferSlice buffer) {
|
||||
CHECK(state_ == State::Write);
|
||||
stream_connection_.output_buffer().append(std::move(buffer));
|
||||
write_buffer_.append(std::move(buffer));
|
||||
loop();
|
||||
}
|
||||
|
||||
@ -69,7 +78,7 @@ void HttpConnectionBase::write_error(Status error) {
|
||||
void HttpConnectionBase::timeout_expired() {
|
||||
LOG(INFO) << "Idle timeout expired";
|
||||
|
||||
if (stream_connection_.need_flush_write()) {
|
||||
if (fd_.need_flush_write()) {
|
||||
on_error(Status::Error("Write timeout expired"));
|
||||
} else if (state_ == State::Read) {
|
||||
on_error(Status::Error("Read timeout expired"));
|
||||
@ -78,9 +87,9 @@ void HttpConnectionBase::timeout_expired() {
|
||||
stop();
|
||||
}
|
||||
void HttpConnectionBase::loop() {
|
||||
if (can_read(stream_connection_)) {
|
||||
if (can_read(fd_)) {
|
||||
LOG(DEBUG) << "Can read from the connection";
|
||||
auto r = stream_connection_.flush_read();
|
||||
auto r = fd_.flush_read();
|
||||
if (r.is_error()) {
|
||||
if (!begins_with(r.error().message(), "SSL error {336134278")) { // if error is not yet outputed
|
||||
LOG(INFO) << "flush_read error: " << r.error();
|
||||
@ -89,6 +98,7 @@ void HttpConnectionBase::loop() {
|
||||
return stop();
|
||||
}
|
||||
}
|
||||
read_source_.wakeup();
|
||||
|
||||
// TODO: read_next even when state_ == State::Write
|
||||
|
||||
@ -102,7 +112,7 @@ void HttpConnectionBase::loop() {
|
||||
HttpHeaderCreator hc;
|
||||
hc.init_status_line(res.error().code());
|
||||
hc.set_content_size(0);
|
||||
stream_connection_.output_buffer().append(hc.finish().ok());
|
||||
write_buffer_.append(hc.finish().ok());
|
||||
close_after_write_ = true;
|
||||
on_error(Status::Error(res.error().public_message()));
|
||||
} else if (res.ok() == 0) {
|
||||
@ -115,34 +125,45 @@ void HttpConnectionBase::loop() {
|
||||
}
|
||||
}
|
||||
|
||||
if (can_write(stream_connection_)) {
|
||||
write_source_.wakeup();
|
||||
|
||||
if (can_write(fd_)) {
|
||||
LOG(DEBUG) << "Can write to the connection";
|
||||
auto r = stream_connection_.flush_write();
|
||||
auto r = fd_.flush_write();
|
||||
if (r.is_error()) {
|
||||
LOG(INFO) << "flush_write error: " << r.error();
|
||||
on_error(Status::Error(r.error().public_message()));
|
||||
}
|
||||
if (close_after_write_ && !stream_connection_.need_flush_write()) {
|
||||
if (close_after_write_ && !fd_.need_flush_write()) {
|
||||
return stop();
|
||||
}
|
||||
}
|
||||
|
||||
if (stream_connection_.get_fd().has_pending_error()) {
|
||||
auto pending_error = stream_connection_.get_pending_error();
|
||||
Status pending_error;
|
||||
if (fd_.get_fd().has_pending_error()) {
|
||||
pending_error = fd_.get_pending_error();
|
||||
}
|
||||
if (pending_error.is_ok() && write_sink_.status().is_error()) {
|
||||
pending_error = std::move(write_sink_.status());
|
||||
}
|
||||
if (pending_error.is_ok() && read_sink_.status().is_error()) {
|
||||
pending_error = std::move(read_sink_.status());
|
||||
}
|
||||
if (pending_error.is_error()) {
|
||||
LOG(INFO) << pending_error;
|
||||
if (!close_after_write_) {
|
||||
on_error(Status::Error(pending_error.public_message()));
|
||||
}
|
||||
state_ = State::Close;
|
||||
}
|
||||
if (can_close(stream_connection_)) {
|
||||
|
||||
if (can_close(fd_)) {
|
||||
LOG(DEBUG) << "Can close the connection";
|
||||
state_ = State::Close;
|
||||
}
|
||||
if (state_ == State::Close) {
|
||||
LOG_IF(INFO, stream_connection_.need_flush_write()) << "Close nonempty connection";
|
||||
LOG_IF(INFO, want_read &&
|
||||
(stream_connection_.input_buffer().size() > 0 || current_query_->type_ != HttpQuery::Type::EMPTY))
|
||||
LOG_IF(INFO, fd_.need_flush_write()) << "Close nonempty connection";
|
||||
LOG_IF(INFO, want_read && (fd_.input_buffer().size() > 0 || current_query_->type_ != HttpQuery::Type::EMPTY))
|
||||
<< "Close connection while reading request/response";
|
||||
return stop();
|
||||
}
|
||||
|
@ -10,120 +10,16 @@
|
||||
|
||||
#include "td/net/HttpQuery.h"
|
||||
#include "td/net/HttpReader.h"
|
||||
#include "td/net/SslStream.h"
|
||||
|
||||
#include "td/utils/buffer.h"
|
||||
#include "td/utils/BufferedFd.h"
|
||||
#include "td/utils/port/Fd.h"
|
||||
#include "td/utils/port/SocketFd.h"
|
||||
#include "td/utils/Slice.h"
|
||||
#include "td/utils/Status.h"
|
||||
|
||||
namespace td {
|
||||
|
||||
class FdInterface {
|
||||
public:
|
||||
FdInterface() = default;
|
||||
FdInterface(const FdInterface &) = delete;
|
||||
FdInterface &operator=(const FdInterface &) = delete;
|
||||
FdInterface(FdInterface &&) = default;
|
||||
FdInterface &operator=(FdInterface &&) = default;
|
||||
virtual ~FdInterface() = default;
|
||||
virtual const Fd &get_fd() const = 0;
|
||||
virtual Fd &get_fd() = 0;
|
||||
virtual int32 get_flags() const = 0;
|
||||
virtual Status get_pending_error() TD_WARN_UNUSED_RESULT = 0;
|
||||
|
||||
virtual Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT = 0;
|
||||
virtual Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT = 0;
|
||||
|
||||
virtual void close() = 0;
|
||||
virtual bool empty() const = 0;
|
||||
};
|
||||
|
||||
template <class FdT>
|
||||
class FdToInterface : public FdInterface {
|
||||
public:
|
||||
FdToInterface() = default;
|
||||
explicit FdToInterface(FdT fd) : fd_(std::move(fd)) {
|
||||
}
|
||||
|
||||
const Fd &get_fd() const final {
|
||||
return fd_.get_fd();
|
||||
}
|
||||
Fd &get_fd() final {
|
||||
return fd_.get_fd();
|
||||
}
|
||||
int32 get_flags() const final {
|
||||
return fd_.get_flags();
|
||||
}
|
||||
Status get_pending_error() final TD_WARN_UNUSED_RESULT {
|
||||
return fd_.get_pending_error();
|
||||
}
|
||||
|
||||
Result<size_t> write(Slice slice) final TD_WARN_UNUSED_RESULT {
|
||||
return fd_.write(slice);
|
||||
}
|
||||
Result<size_t> read(MutableSlice slice) final TD_WARN_UNUSED_RESULT {
|
||||
return fd_.read(slice);
|
||||
}
|
||||
|
||||
void close() final {
|
||||
fd_.close();
|
||||
}
|
||||
bool empty() const final {
|
||||
return fd_.empty();
|
||||
}
|
||||
|
||||
private:
|
||||
FdT fd_;
|
||||
};
|
||||
|
||||
template <class FdT>
|
||||
std::unique_ptr<FdInterface> make_fd_interface(FdT fd) {
|
||||
return make_unique<FdToInterface<FdT>>(std::move(fd));
|
||||
}
|
||||
|
||||
class FdProxy {
|
||||
public:
|
||||
FdProxy() = default;
|
||||
explicit FdProxy(std::unique_ptr<FdInterface> fd) : fd_(std::move(fd)) {
|
||||
}
|
||||
|
||||
const Fd &get_fd() const {
|
||||
return fd_->get_fd();
|
||||
}
|
||||
Fd &get_fd() {
|
||||
return fd_->get_fd();
|
||||
}
|
||||
int32 get_flags() const {
|
||||
return fd_->get_flags();
|
||||
}
|
||||
Status get_pending_error() TD_WARN_UNUSED_RESULT {
|
||||
return fd_->get_pending_error();
|
||||
}
|
||||
|
||||
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT {
|
||||
return fd_->write(slice);
|
||||
}
|
||||
Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT {
|
||||
return fd_->read(slice);
|
||||
}
|
||||
|
||||
void close() {
|
||||
fd_->close();
|
||||
}
|
||||
bool empty() const {
|
||||
return fd_->empty();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<FdInterface> fd_;
|
||||
};
|
||||
|
||||
template <class FdT>
|
||||
FdProxy make_fd_proxy(FdT fd) {
|
||||
return FdProxy(make_fd_interface(std::move(fd)));
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
class HttpConnectionBase : public Actor {
|
||||
public:
|
||||
@ -133,16 +29,23 @@ class HttpConnectionBase : public Actor {
|
||||
|
||||
protected:
|
||||
enum class State { Read, Write, Close };
|
||||
template <class FdT>
|
||||
HttpConnectionBase(State state, FdT fd, size_t max_post_size, size_t max_files, int32 idle_timeout)
|
||||
: HttpConnectionBase(state, make_fd_proxy(std::move(fd)), max_post_size, max_files, idle_timeout) {
|
||||
}
|
||||
HttpConnectionBase(State state, FdProxy fd, size_t max_post_size, size_t max_files, int32 idle_timeout);
|
||||
HttpConnectionBase(State state, SocketFd fd, SslStream ssl_stream, size_t max_post_size, size_t max_files,
|
||||
int32 idle_timeout);
|
||||
|
||||
private:
|
||||
using StreamConnection = BufferedFd<FdProxy>;
|
||||
State state_;
|
||||
StreamConnection stream_connection_;
|
||||
|
||||
BufferedFd<SocketFd> fd_;
|
||||
SslStream ssl_stream_;
|
||||
|
||||
ByteFlowSource read_source_{&fd_.input_buffer()};
|
||||
ByteFlowSink read_sink_;
|
||||
|
||||
ChainBufferWriter write_buffer_;
|
||||
ChainBufferReader write_buffer_reader_ = write_buffer_.extract_reader();
|
||||
ByteFlowSource write_source_{&write_buffer_reader_};
|
||||
ByteFlowMoveSink write_sink_{&fd_.output_buffer()};
|
||||
|
||||
size_t max_post_size_;
|
||||
size_t max_files_;
|
||||
int32 idle_timeout_;
|
||||
|
@ -12,7 +12,7 @@ namespace td {
|
||||
// HttpInboundConnection implementation
|
||||
HttpInboundConnection::HttpInboundConnection(SocketFd fd, size_t max_post_size, size_t max_files, int32 idle_timeout,
|
||||
ActorShared<Callback> callback)
|
||||
: HttpConnectionBase(State::Read, std::move(fd), max_post_size, max_files, idle_timeout)
|
||||
: HttpConnectionBase(State::Read, std::move(fd), SslStream(), max_post_size, max_files, idle_timeout)
|
||||
, callback_(std::move(callback)) {
|
||||
}
|
||||
|
||||
|
@ -22,10 +22,10 @@ class HttpOutboundConnection final : public detail::HttpConnectionBase {
|
||||
virtual void handle(HttpQueryPtr query) = 0;
|
||||
virtual void on_connection_error(Status error) = 0; // TODO rename to on_error
|
||||
};
|
||||
template <class FdT>
|
||||
HttpOutboundConnection(FdT fd, size_t max_post_size, size_t max_files, int32 idle_timeout,
|
||||
HttpOutboundConnection(SocketFd fd, SslStream ssl_stream, size_t max_post_size, size_t max_files, int32 idle_timeout,
|
||||
ActorShared<Callback> callback)
|
||||
: HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), max_post_size, max_files, idle_timeout)
|
||||
: HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), std::move(ssl_stream), max_post_size,
|
||||
max_files, idle_timeout)
|
||||
, callback_(std::move(callback)) {
|
||||
}
|
||||
// Inherited interface
|
||||
|
@ -1,279 +0,0 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#include "td/net/SslFd.h"
|
||||
|
||||
#include "td/utils/logging.h"
|
||||
#include "td/utils/StackAllocator.h"
|
||||
#include "td/utils/StringBuilder.h"
|
||||
#include "td/utils/Time.h"
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/x509v3.h>
|
||||
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
|
||||
namespace td {
|
||||
|
||||
#if !TD_WINDOWS
|
||||
static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
|
||||
if (!preverify_ok) {
|
||||
char buf[256];
|
||||
X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256);
|
||||
|
||||
int err = X509_STORE_CTX_get_error(ctx);
|
||||
auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err)
|
||||
<< ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << buf;
|
||||
double now = Time::now();
|
||||
|
||||
static std::mutex warning_mutex;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(warning_mutex);
|
||||
static std::map<std::string, double> next_warning_time;
|
||||
double &next = next_warning_time[warning];
|
||||
if (next <= now) {
|
||||
next = now + 300; // one warning per 5 minutes
|
||||
LOG(WARNING) << warning;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return preverify_ok;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
Status create_openssl_error(int code, Slice message) {
|
||||
const int buf_size = 1 << 12;
|
||||
auto buf = StackAllocator::alloc(buf_size);
|
||||
StringBuilder sb(buf.as_slice());
|
||||
|
||||
sb << message;
|
||||
while (unsigned long error_code = ERR_get_error()) {
|
||||
sb << "{" << error_code << ", " << ERR_error_string(error_code, nullptr) << "}";
|
||||
}
|
||||
LOG_IF(ERROR, sb.is_error()) << "OPENSSL error buffer overflow";
|
||||
return Status::Error(code, sb.as_cslice());
|
||||
}
|
||||
|
||||
void openssl_clear_errors(Slice from) {
|
||||
if (ERR_peek_error() != 0) {
|
||||
LOG(ERROR) << from << ": " << create_openssl_error(0, "Unprocessed OPENSSL_ERROR");
|
||||
}
|
||||
errno = 0;
|
||||
}
|
||||
|
||||
void do_ssl_shutdown(SSL *ssl_handle) {
|
||||
if (!SSL_is_init_finished(ssl_handle)) {
|
||||
return;
|
||||
}
|
||||
openssl_clear_errors("Before SSL_shutdown");
|
||||
SSL_set_quiet_shutdown(ssl_handle, 1);
|
||||
SSL_shutdown(ssl_handle);
|
||||
openssl_clear_errors("After SSL_shutdown");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SslFd::SslFd(SocketFd &&fd, SSL *ssl_handle_, SSL_CTX *ssl_ctx_)
|
||||
: fd_(std::move(fd)), ssl_handle_(ssl_handle_), ssl_ctx_(ssl_ctx_) {
|
||||
}
|
||||
|
||||
Result<SslFd> SslFd::init(SocketFd fd, CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
|
||||
#if TD_WINDOWS
|
||||
return Status::Error("TODO");
|
||||
#else
|
||||
static bool init_openssl = [] {
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
|
||||
return OPENSSL_init_ssl(0, nullptr) != 0;
|
||||
#else
|
||||
OpenSSL_add_all_algorithms();
|
||||
SSL_load_error_strings();
|
||||
return OpenSSL_add_ssl_algorithms() != 0;
|
||||
#endif
|
||||
}();
|
||||
CHECK(init_openssl);
|
||||
|
||||
openssl_clear_errors("Before SslFd::init");
|
||||
CHECK(!fd.empty());
|
||||
|
||||
auto ssl_method =
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
|
||||
TLS_client_method();
|
||||
#else
|
||||
SSLv23_client_method();
|
||||
#endif
|
||||
if (ssl_method == nullptr) {
|
||||
return create_openssl_error(-6, "Failed to create an SSL client method");
|
||||
}
|
||||
|
||||
auto ssl_ctx = SSL_CTX_new(ssl_method);
|
||||
if (ssl_ctx == nullptr) {
|
||||
return create_openssl_error(-7, "Failed to create an SSL context");
|
||||
}
|
||||
auto ssl_ctx_guard = ScopeExit() + [&]() { SSL_CTX_free(ssl_ctx); };
|
||||
long options = 0;
|
||||
#ifdef SSL_OP_NO_SSLv2
|
||||
options |= SSL_OP_NO_SSLv2;
|
||||
#endif
|
||||
#ifdef SSL_OP_NO_SSLv3
|
||||
options |= SSL_OP_NO_SSLv3;
|
||||
#endif
|
||||
SSL_CTX_set_options(ssl_ctx, options);
|
||||
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||
|
||||
if (cert_file.empty()) {
|
||||
SSL_CTX_set_default_verify_paths(ssl_ctx);
|
||||
} else {
|
||||
if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) {
|
||||
return create_openssl_error(-8, "Failed to set custom cert file");
|
||||
}
|
||||
}
|
||||
if (VERIFY_PEER && verify_peer == VerifyPeer::On) {
|
||||
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback);
|
||||
|
||||
if (VERIFY_DEPTH != -1) {
|
||||
SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH);
|
||||
}
|
||||
} else {
|
||||
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr);
|
||||
}
|
||||
|
||||
// TODO(now): cipher list
|
||||
string cipher_list;
|
||||
if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) {
|
||||
return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"');
|
||||
}
|
||||
|
||||
auto ssl_handle = SSL_new(ssl_ctx);
|
||||
if (ssl_handle == nullptr) {
|
||||
return create_openssl_error(-13, "Failed to create an SSL handle");
|
||||
}
|
||||
auto ssl_handle_guard = ScopeExit() + [&]() {
|
||||
do_ssl_shutdown(ssl_handle);
|
||||
SSL_free(ssl_handle);
|
||||
};
|
||||
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x10002000L
|
||||
X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle);
|
||||
/* Enable automatic hostname checks */
|
||||
// TODO: X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
|
||||
X509_VERIFY_PARAM_set_hostflags(param, 0);
|
||||
X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0);
|
||||
#else
|
||||
#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY
|
||||
#endif
|
||||
|
||||
if (!SSL_set_fd(ssl_handle, fd.get_fd().get_native_fd())) {
|
||||
return create_openssl_error(-14, "Failed to set fd");
|
||||
}
|
||||
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
|
||||
auto host_str = host.str();
|
||||
SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin());
|
||||
#endif
|
||||
SSL_set_connect_state(ssl_handle);
|
||||
|
||||
ssl_ctx_guard.dismiss();
|
||||
ssl_handle_guard.dismiss();
|
||||
return SslFd(std::move(fd), ssl_handle, ssl_ctx);
|
||||
#endif
|
||||
}
|
||||
|
||||
Result<size_t> SslFd::process_ssl_error(int ret, int *mask) {
|
||||
#if TD_WINDOWS
|
||||
return Status::Error("TODO");
|
||||
#else
|
||||
auto openssl_errno = errno;
|
||||
int error = SSL_get_error(ssl_handle_, ret);
|
||||
switch (error) {
|
||||
case SSL_ERROR_NONE:
|
||||
LOG(ERROR) << "SSL_get_error returned no error";
|
||||
return 0;
|
||||
case SSL_ERROR_ZERO_RETURN:
|
||||
LOG(DEBUG) << "SSL_ERROR_ZERO_RETURN";
|
||||
fd_.get_fd().update_flags(Fd::Close);
|
||||
write_mask_ |= Fd::Error;
|
||||
*mask |= Fd::Error;
|
||||
return 0;
|
||||
case SSL_ERROR_WANT_READ:
|
||||
LOG(DEBUG) << "SSL_ERROR_WANT_READ";
|
||||
fd_.get_fd().clear_flags(Fd::Read);
|
||||
*mask |= Fd::Read;
|
||||
return 0;
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
LOG(DEBUG) << "SSL_ERROR_WANT_WRITE";
|
||||
fd_.get_fd().clear_flags(Fd::Write);
|
||||
*mask |= Fd::Write;
|
||||
return 0;
|
||||
case SSL_ERROR_WANT_CONNECT:
|
||||
case SSL_ERROR_WANT_ACCEPT:
|
||||
case SSL_ERROR_WANT_X509_LOOKUP:
|
||||
LOG(DEBUG) << "SSL_ERROR: CONNECT ACCEPT LOOKUP";
|
||||
fd_.get_fd().clear_flags(Fd::Write);
|
||||
*mask |= Fd::Write;
|
||||
return 0;
|
||||
case SSL_ERROR_SYSCALL:
|
||||
LOG(DEBUG) << "SSL_ERROR_SYSCALL";
|
||||
if (ERR_peek_error() == 0) {
|
||||
if (openssl_errno != 0) {
|
||||
CHECK(openssl_errno != EAGAIN);
|
||||
return Status::PosixError(openssl_errno, "SSL_ERROR_SYSCALL");
|
||||
} else {
|
||||
// Socket was closed from the other side, probably. Not an error
|
||||
fd_.get_fd().update_flags(Fd::Close);
|
||||
write_mask_ |= Fd::Error;
|
||||
*mask |= Fd::Error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
/* fall through */
|
||||
default:
|
||||
LOG(DEBUG) << "SSL_ERROR Default";
|
||||
fd_.get_fd().update_flags(Fd::Close);
|
||||
write_mask_ |= Fd::Error;
|
||||
read_mask_ |= Fd::Error;
|
||||
return create_openssl_error(1, "SSL error ");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Result<size_t> SslFd::write(Slice slice) {
|
||||
openssl_clear_errors("Before SslFd::write");
|
||||
auto size = SSL_write(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
|
||||
if (size <= 0) {
|
||||
return process_ssl_error(size, &write_mask_);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
Result<size_t> SslFd::read(MutableSlice slice) {
|
||||
openssl_clear_errors("Before SslFd::read");
|
||||
auto size = SSL_read(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
|
||||
if (size <= 0) {
|
||||
return process_ssl_error(size, &read_mask_);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
void SslFd::close() {
|
||||
if (fd_.empty()) {
|
||||
CHECK(!ssl_handle_ && !ssl_ctx_);
|
||||
return;
|
||||
}
|
||||
CHECK(ssl_handle_ && ssl_ctx_);
|
||||
do_ssl_shutdown(ssl_handle_);
|
||||
SSL_free(ssl_handle_);
|
||||
ssl_handle_ = nullptr;
|
||||
SSL_CTX_free(ssl_ctx_);
|
||||
ssl_ctx_ = nullptr;
|
||||
fd_.close();
|
||||
}
|
||||
|
||||
} // namespace td
|
@ -1,109 +0,0 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "td/utils/port/Fd.h"
|
||||
#include "td/utils/port/SocketFd.h"
|
||||
#include "td/utils/Slice.h"
|
||||
#include "td/utils/Status.h"
|
||||
|
||||
#include <openssl/ssl.h> // TODO can we remove it from header and make target_link_libraries dependence PRIVATE?
|
||||
|
||||
namespace td {
|
||||
|
||||
class SslFd {
|
||||
public:
|
||||
enum class VerifyPeer { On, Off };
|
||||
static Result<SslFd> init(SocketFd fd, CSlice host, CSlice cert_file = CSlice(),
|
||||
VerifyPeer verify_peer = VerifyPeer::On) TD_WARN_UNUSED_RESULT;
|
||||
|
||||
SslFd(const SslFd &other) = delete;
|
||||
SslFd &operator=(const SslFd &other) = delete;
|
||||
SslFd(SslFd &&other)
|
||||
: fd_(std::move(other.fd_))
|
||||
, write_mask_(other.write_mask_)
|
||||
, read_mask_(other.read_mask_)
|
||||
, ssl_handle_(other.ssl_handle_)
|
||||
, ssl_ctx_(other.ssl_ctx_) {
|
||||
other.ssl_handle_ = nullptr;
|
||||
other.ssl_ctx_ = nullptr;
|
||||
}
|
||||
SslFd &operator=(SslFd &&other) {
|
||||
close();
|
||||
|
||||
fd_ = std::move(other.fd_);
|
||||
write_mask_ = other.write_mask_;
|
||||
read_mask_ = other.read_mask_;
|
||||
ssl_handle_ = other.ssl_handle_;
|
||||
ssl_ctx_ = other.ssl_ctx_;
|
||||
|
||||
other.ssl_handle_ = nullptr;
|
||||
other.ssl_ctx_ = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const Fd &get_fd() const {
|
||||
return fd_.get_fd();
|
||||
}
|
||||
|
||||
Fd &get_fd() {
|
||||
return fd_.get_fd();
|
||||
}
|
||||
|
||||
Status get_pending_error() TD_WARN_UNUSED_RESULT {
|
||||
return fd_.get_pending_error();
|
||||
}
|
||||
|
||||
Result<size_t> write(Slice slice) TD_WARN_UNUSED_RESULT;
|
||||
Result<size_t> read(MutableSlice slice) TD_WARN_UNUSED_RESULT;
|
||||
|
||||
void close();
|
||||
|
||||
int32 get_flags() const {
|
||||
int32 res = 0;
|
||||
int32 fd_flags = fd_.get_flags();
|
||||
fd_flags &= ~Fd::Error;
|
||||
if (fd_flags & Fd::Close) {
|
||||
res |= Fd::Close;
|
||||
}
|
||||
write_mask_ &= ~fd_flags;
|
||||
read_mask_ &= ~fd_flags;
|
||||
if (write_mask_ == 0) {
|
||||
res |= Fd::Write;
|
||||
}
|
||||
if (read_mask_ == 0) {
|
||||
res |= Fd::Read;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return fd_.empty();
|
||||
}
|
||||
|
||||
~SslFd() {
|
||||
close();
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr bool VERIFY_PEER = true;
|
||||
static constexpr int VERIFY_DEPTH = 10;
|
||||
|
||||
SocketFd fd_;
|
||||
mutable int write_mask_ = 0;
|
||||
mutable int read_mask_ = 0;
|
||||
|
||||
// TODO unique_ptr
|
||||
SSL *ssl_handle_ = nullptr;
|
||||
SSL_CTX *ssl_ctx_ = nullptr;
|
||||
|
||||
SslFd(SocketFd &&fd, SSL *ssl_handle_, SSL_CTX *ssl_ctx_);
|
||||
|
||||
Result<size_t> process_ssl_error(int ret, int *mask) TD_WARN_UNUSED_RESULT;
|
||||
};
|
||||
|
||||
} // namespace td
|
473
tdnet/td/net/SslStream.cpp
Normal file
473
tdnet/td/net/SslStream.cpp
Normal file
@ -0,0 +1,473 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#include "td/net/SslStream.h"
|
||||
|
||||
#include "td/utils/logging.h"
|
||||
#include "td/utils/StackAllocator.h"
|
||||
#include "td/utils/StringBuilder.h"
|
||||
#include "td/utils/Time.h"
|
||||
#include "td/utils/misc.h"
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/x509v3.h>
|
||||
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
|
||||
namespace td {
|
||||
|
||||
namespace detail {
|
||||
namespace {
|
||||
#if OPENSSL_VERSION_NUMBER < 0x10100000L
|
||||
void* BIO_get_data(BIO* b) {
|
||||
return b->ptr;
|
||||
}
|
||||
void BIO_set_data(BIO* b, void* ptr) {
|
||||
b->ptr = ptr;
|
||||
}
|
||||
void BIO_set_init(BIO* b, int init) {
|
||||
b->init = init;
|
||||
}
|
||||
|
||||
int BIO_get_new_index() {
|
||||
return 0;
|
||||
}
|
||||
BIO_METHOD* BIO_meth_new(int type, const char* name) {
|
||||
auto res = new BIO_METHOD();
|
||||
memset(res, 0, sizeof(*res));
|
||||
return res;
|
||||
}
|
||||
|
||||
int BIO_meth_set_write(BIO_METHOD* biom, int (*bwrite)(BIO*, const char*, int)) {
|
||||
biom->bwrite = bwrite;
|
||||
return 1;
|
||||
}
|
||||
int BIO_meth_set_read(BIO_METHOD* biom, int (*bread)(BIO*, char*, int)) {
|
||||
biom->bread = bread;
|
||||
return 1;
|
||||
}
|
||||
int BIO_meth_set_ctrl(BIO_METHOD* biom, long (*ctrl)(BIO*, int, long, void*)) {
|
||||
biom->ctrl = ctrl;
|
||||
return 1;
|
||||
}
|
||||
int BIO_meth_set_create(BIO_METHOD* biom, int (*create)(BIO*)) {
|
||||
biom->create = create;
|
||||
return 1;
|
||||
}
|
||||
int BIO_meth_set_destroy(BIO_METHOD* biom, int (*destroy)(BIO*)) {
|
||||
biom->destroy = destroy;
|
||||
return 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
int strm_create(BIO* b) {
|
||||
BIO_set_init(b, 1);
|
||||
return 1;
|
||||
}
|
||||
int strm_destroy(BIO* b) {
|
||||
return 1;
|
||||
}
|
||||
int strm_read(BIO* b, char* buf, int len);
|
||||
int strm_write(BIO* b, const char* buf, int len);
|
||||
|
||||
long strm_ctrl(BIO* b, int cmd, long num, void* ptr) {
|
||||
switch (cmd) {
|
||||
case BIO_CTRL_FLUSH:
|
||||
return 1;
|
||||
case BIO_CTRL_PUSH:
|
||||
return 0;
|
||||
case BIO_CTRL_POP:
|
||||
return 0;
|
||||
default:
|
||||
LOG(FATAL) << b << " " << cmd << " " << num << " " << ptr;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
BIO_METHOD* BIO_s_sslstream() {
|
||||
static BIO_METHOD* res = [] {
|
||||
BIO_METHOD* res = BIO_meth_new(BIO_get_new_index(), "td::SslStream helper bio");
|
||||
BIO_meth_set_write(res, strm_write);
|
||||
BIO_meth_set_read(res, strm_read);
|
||||
BIO_meth_set_create(res, strm_create);
|
||||
BIO_meth_set_destroy(res, strm_destroy);
|
||||
BIO_meth_set_ctrl(res, strm_ctrl);
|
||||
return res;
|
||||
}();
|
||||
return res;
|
||||
}
|
||||
int verify_callback(int preverify_ok, X509_STORE_CTX* ctx) {
|
||||
if (!preverify_ok) {
|
||||
char buf[256];
|
||||
X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256);
|
||||
|
||||
int err = X509_STORE_CTX_get_error(ctx);
|
||||
auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err)
|
||||
<< ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << buf;
|
||||
double now = Time::now();
|
||||
|
||||
static std::mutex warning_mutex;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(warning_mutex);
|
||||
static std::map<std::string, double> next_warning_time;
|
||||
double& next = next_warning_time[warning];
|
||||
if (next <= now) {
|
||||
next = now + 300; // one warning per 5 minutes
|
||||
LOG(WARNING) << warning;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return preverify_ok;
|
||||
}
|
||||
|
||||
Status create_openssl_error(int code, Slice message) {
|
||||
const int buf_size = 1 << 12;
|
||||
auto buf = StackAllocator::alloc(buf_size);
|
||||
StringBuilder sb(buf.as_slice());
|
||||
|
||||
sb << message;
|
||||
while (unsigned long error_code = ERR_get_error()) {
|
||||
sb << "{" << error_code << ", " << ERR_error_string(error_code, nullptr) << "}";
|
||||
}
|
||||
LOG_IF(ERROR, sb.is_error()) << "OPENSSL error buffer overflow";
|
||||
return Status::Error(code, sb.as_cslice());
|
||||
}
|
||||
|
||||
void openssl_clear_errors(Slice from) {
|
||||
if (ERR_peek_error() != 0) {
|
||||
LOG(ERROR) << from << ": " << create_openssl_error(0, "Unprocessed OPENSSL_ERROR");
|
||||
}
|
||||
errno = 0;
|
||||
}
|
||||
|
||||
void do_ssl_shutdown(SSL* ssl_handle) {
|
||||
if (!SSL_is_init_finished(ssl_handle)) {
|
||||
return;
|
||||
}
|
||||
openssl_clear_errors("Before SSL_shutdown");
|
||||
SSL_set_quiet_shutdown(ssl_handle, 1);
|
||||
SSL_shutdown(ssl_handle);
|
||||
openssl_clear_errors("After SSL_shutdown");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class SslStreamImpl {
|
||||
public:
|
||||
using VerifyPeer = SslStream::VerifyPeer;
|
||||
~SslStreamImpl() {
|
||||
if (!ssl_handle_) {
|
||||
CHECK(!ssl_ctx_ && !bio_);
|
||||
return;
|
||||
}
|
||||
CHECK(ssl_handle_ && ssl_ctx_ && bio_);
|
||||
do_ssl_shutdown(ssl_handle_);
|
||||
SSL_free(ssl_handle_);
|
||||
ssl_handle_ = nullptr;
|
||||
SSL_CTX_free(ssl_ctx_);
|
||||
ssl_ctx_ = nullptr;
|
||||
}
|
||||
Status init(CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
|
||||
static bool init_openssl = [] {
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
|
||||
return OPENSSL_init_ssl(0, nullptr) != 0;
|
||||
#else
|
||||
OpenSSL_add_all_algorithms();
|
||||
SSL_load_error_strings();
|
||||
return OpenSSL_add_ssl_algorithms() != 0;
|
||||
#endif
|
||||
}();
|
||||
CHECK(init_openssl);
|
||||
|
||||
openssl_clear_errors("Before SslFd::init");
|
||||
|
||||
auto ssl_method =
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
|
||||
TLS_client_method();
|
||||
#else
|
||||
SSLv23_client_method();
|
||||
#endif
|
||||
if (ssl_method == nullptr) {
|
||||
return create_openssl_error(-6, "Failed to create an SSL client method");
|
||||
}
|
||||
|
||||
auto ssl_ctx = SSL_CTX_new(ssl_method);
|
||||
if (ssl_ctx == nullptr) {
|
||||
return create_openssl_error(-7, "Failed to create an SSL context");
|
||||
}
|
||||
auto ssl_ctx_guard = ScopeExit() + [&]() { SSL_CTX_free(ssl_ctx); };
|
||||
long options = 0;
|
||||
#ifdef SSL_OP_NO_SSLv2
|
||||
options |= SSL_OP_NO_SSLv2;
|
||||
#endif
|
||||
#ifdef SSL_OP_NO_SSLv3
|
||||
options |= SSL_OP_NO_SSLv3;
|
||||
#endif
|
||||
SSL_CTX_set_options(ssl_ctx, options);
|
||||
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||
|
||||
if (cert_file.empty()) {
|
||||
SSL_CTX_set_default_verify_paths(ssl_ctx);
|
||||
} else {
|
||||
if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) {
|
||||
return create_openssl_error(-8, "Failed to set custom cert file");
|
||||
}
|
||||
}
|
||||
|
||||
if (verify_peer == VerifyPeer::On) {
|
||||
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback);
|
||||
|
||||
if (VERIFY_DEPTH != -1) {
|
||||
SSL_CTX_set_verify_depth(ssl_ctx, VERIFY_DEPTH);
|
||||
}
|
||||
} else {
|
||||
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr);
|
||||
}
|
||||
|
||||
// TODO(now): cipher list
|
||||
string cipher_list;
|
||||
if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) {
|
||||
return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"');
|
||||
}
|
||||
|
||||
auto ssl_handle = SSL_new(ssl_ctx);
|
||||
if (ssl_handle == nullptr) {
|
||||
return create_openssl_error(-13, "Failed to create an SSL handle");
|
||||
}
|
||||
auto ssl_handle_guard = ScopeExit() + [&]() {
|
||||
do_ssl_shutdown(ssl_handle);
|
||||
SSL_free(ssl_handle);
|
||||
};
|
||||
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x10002000L
|
||||
X509_VERIFY_PARAM* param = SSL_get0_param(ssl_handle);
|
||||
/* Enable automatic hostname checks */
|
||||
// TODO: X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS
|
||||
X509_VERIFY_PARAM_set_hostflags(param, 0);
|
||||
X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0);
|
||||
#else
|
||||
#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY
|
||||
#endif
|
||||
|
||||
auto* bio = BIO_new(BIO_s_sslstream());
|
||||
BIO_set_data(bio, this);
|
||||
SSL_set_bio(ssl_handle, bio, bio);
|
||||
|
||||
#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
|
||||
auto host_str = host.str();
|
||||
SSL_set_tlsext_host_name(ssl_handle, MutableCSlice(host_str).begin());
|
||||
#endif
|
||||
SSL_set_connect_state(ssl_handle);
|
||||
|
||||
ssl_ctx_guard.dismiss();
|
||||
ssl_handle_guard.dismiss();
|
||||
|
||||
ssl_handle_ = ssl_handle;
|
||||
ssl_ctx_ = ssl_ctx;
|
||||
bio_ = bio;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ByteFlowInterface& read_byte_flow() {
|
||||
return read_flow_;
|
||||
}
|
||||
ByteFlowInterface& write_byte_flow() {
|
||||
return write_flow_;
|
||||
}
|
||||
size_t flow_read(MutableSlice slice) {
|
||||
return read_flow_.read(slice);
|
||||
}
|
||||
size_t flow_write(Slice slice) {
|
||||
return write_flow_.write(slice);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr int VERIFY_DEPTH = 10;
|
||||
|
||||
SSL* ssl_handle_ = nullptr;
|
||||
SSL_CTX* ssl_ctx_ = nullptr;
|
||||
BIO* bio_ = nullptr;
|
||||
|
||||
friend class SslReadByteFlow;
|
||||
friend class SslWriteByteFlow;
|
||||
|
||||
Result<size_t> write(Slice slice) {
|
||||
openssl_clear_errors("Before SslFd::write");
|
||||
auto size = SSL_write(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
|
||||
if (size <= 0) {
|
||||
return process_ssl_error(size);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
Result<size_t> read(MutableSlice slice) {
|
||||
openssl_clear_errors("Before SslFd::read");
|
||||
auto size = SSL_read(ssl_handle_, slice.data(), static_cast<int>(slice.size()));
|
||||
if (size <= 0) {
|
||||
return process_ssl_error(size);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
class SslReadByteFlow : public ByteFlowBase {
|
||||
public:
|
||||
SslReadByteFlow(SslStreamImpl* stream) : stream_(stream) {
|
||||
}
|
||||
void loop() override {
|
||||
bool was_append = false;
|
||||
while (true) {
|
||||
auto to_read = output_.prepare_append();
|
||||
auto r_size = stream_->read(to_read);
|
||||
if (r_size.is_error()) {
|
||||
return finish(r_size.move_as_error());
|
||||
}
|
||||
auto size = r_size.move_as_ok();
|
||||
if (size == 0) {
|
||||
break;
|
||||
}
|
||||
output_.confirm_append(size);
|
||||
was_append = true;
|
||||
}
|
||||
if (was_append) {
|
||||
on_output_updated();
|
||||
}
|
||||
}
|
||||
|
||||
size_t read(MutableSlice data) {
|
||||
return input_->advance(std::min(data.size(), input_->size()), data);
|
||||
}
|
||||
|
||||
private:
|
||||
SslStreamImpl* stream_;
|
||||
};
|
||||
|
||||
class SslWriteByteFlow : public ByteFlowBase {
|
||||
public:
|
||||
SslWriteByteFlow(SslStreamImpl* stream) : stream_(stream) {
|
||||
}
|
||||
void loop() override {
|
||||
while (!input_->empty()) {
|
||||
auto to_write = input_->prepare_read();
|
||||
auto r_size = stream_->write(to_write);
|
||||
if (r_size.is_error()) {
|
||||
return finish(r_size.move_as_error());
|
||||
}
|
||||
auto size = r_size.move_as_ok();
|
||||
if (size == 0) {
|
||||
break;
|
||||
}
|
||||
input_->confirm_read(size);
|
||||
}
|
||||
if (output_updated_) {
|
||||
output_updated_ = false;
|
||||
on_output_updated();
|
||||
}
|
||||
}
|
||||
|
||||
size_t write(Slice data) {
|
||||
output_.append(data);
|
||||
output_updated_ = true;
|
||||
return data.size();
|
||||
}
|
||||
|
||||
private:
|
||||
SslStreamImpl* stream_;
|
||||
bool output_updated_{false};
|
||||
};
|
||||
|
||||
SslReadByteFlow read_flow_{this};
|
||||
SslWriteByteFlow write_flow_{this};
|
||||
|
||||
Result<size_t> process_ssl_error(int ret) {
|
||||
auto openssl_errno = errno;
|
||||
int error = SSL_get_error(ssl_handle_, ret);
|
||||
switch (error) {
|
||||
case SSL_ERROR_NONE:
|
||||
LOG(ERROR) << "SSL_get_error returned no error";
|
||||
return 0;
|
||||
case SSL_ERROR_ZERO_RETURN:
|
||||
LOG(DEBUG) << "SSL_ERROR_ZERO_RETURN";
|
||||
return 0;
|
||||
case SSL_ERROR_WANT_READ:
|
||||
LOG(DEBUG) << "SSL_ERROR_WANT_READ";
|
||||
return 0;
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
LOG(DEBUG) << "SSL_ERROR_WANT_WRITE";
|
||||
return 0;
|
||||
case SSL_ERROR_WANT_CONNECT:
|
||||
case SSL_ERROR_WANT_ACCEPT:
|
||||
case SSL_ERROR_WANT_X509_LOOKUP:
|
||||
LOG(DEBUG) << "SSL_ERROR: CONNECT ACCEPT LOOKUP";
|
||||
return 0;
|
||||
case SSL_ERROR_SYSCALL:
|
||||
LOG(DEBUG) << "SSL_ERROR_SYSCALL";
|
||||
if (ERR_peek_error() == 0) {
|
||||
if (openssl_errno != 0) {
|
||||
CHECK(openssl_errno != EAGAIN);
|
||||
return Status::PosixError(openssl_errno, "SSL_ERROR_SYSCALL");
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
/* fall through */
|
||||
default:
|
||||
LOG(DEBUG) << "SSL_ERROR Default";
|
||||
return create_openssl_error(1, "SSL error ");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
int strm_read(BIO* b, char* buf, int len) {
|
||||
auto* stream = reinterpret_cast<SslStreamImpl*>(BIO_get_data(b));
|
||||
CHECK(stream);
|
||||
BIO_clear_retry_flags(b);
|
||||
int res = narrow_cast<int>(stream->flow_read(MutableSlice(buf, len)));
|
||||
if (res == 0) {
|
||||
BIO_set_retry_read(b);
|
||||
return -1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
int strm_write(BIO* b, const char* buf, int len) {
|
||||
auto* stream = reinterpret_cast<SslStreamImpl*>(BIO_get_data(b));
|
||||
CHECK(stream);
|
||||
BIO_clear_retry_flags(b);
|
||||
return narrow_cast<int>(stream->flow_write(Slice(buf, len)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
} // namespace detail
|
||||
|
||||
SslStream::SslStream() = default;
|
||||
SslStream::SslStream(SslStream&&) = default;
|
||||
SslStream& SslStream::operator=(SslStream&&) = default;
|
||||
SslStream::~SslStream() = default;
|
||||
|
||||
Result<SslStream> SslStream::create(CSlice host, CSlice cert_file, VerifyPeer verify_peer) {
|
||||
auto impl = std::make_unique<detail::SslStreamImpl>();
|
||||
TRY_STATUS(impl->init(host, cert_file, verify_peer));
|
||||
return SslStream(std::move(impl));
|
||||
}
|
||||
SslStream::SslStream(std::unique_ptr<detail::SslStreamImpl> impl) : impl_(std::move(impl)) {
|
||||
}
|
||||
ByteFlowInterface& SslStream::read_byte_flow() {
|
||||
return impl_->read_byte_flow();
|
||||
}
|
||||
ByteFlowInterface& SslStream::write_byte_flow() {
|
||||
return impl_->write_byte_flow();
|
||||
}
|
||||
size_t SslStream::flow_read(MutableSlice slice) {
|
||||
return impl_->flow_read(slice);
|
||||
}
|
||||
size_t SslStream::flow_write(Slice slice) {
|
||||
return impl_->flow_write(slice);
|
||||
}
|
||||
} // namespace td
|
49
tdnet/td/net/SslStream.h
Normal file
49
tdnet/td/net/SslStream.h
Normal file
@ -0,0 +1,49 @@
|
||||
//
|
||||
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2018
|
||||
//
|
||||
// Distributed under the Boost Software License, Version 1.0. (See accompanying
|
||||
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "td/utils/port/Fd.h"
|
||||
#include "td/utils/port/SocketFd.h"
|
||||
#include "td/utils/Slice.h"
|
||||
#include "td/utils/Status.h"
|
||||
#include "td/utils/ByteFlow.h"
|
||||
#include "td/utils/BufferedFd.h"
|
||||
|
||||
namespace td {
|
||||
|
||||
namespace detail {
|
||||
class SslStreamImpl;
|
||||
}
|
||||
|
||||
class SslStream {
|
||||
public:
|
||||
SslStream();
|
||||
SslStream(SslStream &&);
|
||||
SslStream &operator=(SslStream &&);
|
||||
~SslStream();
|
||||
|
||||
enum class VerifyPeer { On, Off };
|
||||
|
||||
static Result<SslStream> create(CSlice host, CSlice cert_file = CSlice(), VerifyPeer verify_peer = VerifyPeer::On);
|
||||
|
||||
ByteFlowInterface &read_byte_flow();
|
||||
ByteFlowInterface &write_byte_flow();
|
||||
|
||||
size_t flow_read(MutableSlice slice);
|
||||
size_t flow_write(Slice slice);
|
||||
|
||||
explicit operator bool() const {
|
||||
return bool(impl_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<detail::SslStreamImpl> impl_;
|
||||
|
||||
SslStream(std::unique_ptr<detail::SslStreamImpl> impl);
|
||||
};
|
||||
|
||||
} // namespace td
|
@ -8,7 +8,7 @@
|
||||
|
||||
#include "td/net/HttpHeaderCreator.h"
|
||||
#include "td/net/HttpOutboundConnection.h"
|
||||
#include "td/net/SslFd.h"
|
||||
#include "td/net/SslStream.h"
|
||||
|
||||
#include "td/utils/buffer.h"
|
||||
#include "td/utils/HttpUrl.h"
|
||||
@ -23,7 +23,7 @@
|
||||
namespace td {
|
||||
|
||||
Wget::Wget(Promise<HttpQueryPtr> promise, string url, std::vector<std::pair<string, string>> headers, int32 timeout_in,
|
||||
int32 ttl, bool prefer_ipv6, SslFd::VerifyPeer verify_peer)
|
||||
int32 ttl, bool prefer_ipv6, SslStream::VerifyPeer verify_peer)
|
||||
: promise_(std::move(promise))
|
||||
, input_url_(std::move(url))
|
||||
, headers_(std::move(headers))
|
||||
@ -66,14 +66,15 @@ Status Wget::try_init() {
|
||||
|
||||
TRY_RESULT(fd, SocketFd::open(addr));
|
||||
if (url.protocol_ == HttpUrl::Protocol::HTTP) {
|
||||
connection_ =
|
||||
create_actor<HttpOutboundConnection>("Connect", std::move(fd), std::numeric_limits<std::size_t>::max(), 0, 0,
|
||||
connection_ = create_actor<HttpOutboundConnection>("Connect", std::move(fd), SslStream{},
|
||||
std::numeric_limits<std::size_t>::max(), 0, 0,
|
||||
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
|
||||
} else {
|
||||
TRY_RESULT(ssl_fd, SslFd::init(std::move(fd), url.host_, CSlice() /* certificate */, verify_peer_));
|
||||
connection_ =
|
||||
create_actor<HttpOutboundConnection>("Connect", std::move(ssl_fd), std::numeric_limits<std::size_t>::max(), 0,
|
||||
0, ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
|
||||
LOG(ERROR) << "HTTPS";
|
||||
TRY_RESULT(ssl_stream, SslStream::create(url.host_, CSlice() /* certificate */, verify_peer_));
|
||||
connection_ = create_actor<HttpOutboundConnection>("Connect", std::move(fd), std::move(ssl_stream),
|
||||
std::numeric_limits<std::size_t>::max(), 0, 0,
|
||||
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
|
||||
}
|
||||
|
||||
send_closure(connection_, &HttpOutboundConnection::write_next, BufferSlice(header));
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
#include "td/net/HttpOutboundConnection.h"
|
||||
#include "td/net/HttpQuery.h"
|
||||
#include "td/net/SslFd.h"
|
||||
#include "td/net/SslStream.h"
|
||||
|
||||
#include "td/actor/actor.h"
|
||||
#include "td/actor/PromiseFuture.h"
|
||||
@ -24,7 +24,7 @@ class Wget : public HttpOutboundConnection::Callback {
|
||||
public:
|
||||
explicit Wget(Promise<HttpQueryPtr> promise, string url, std::vector<std::pair<string, string>> headers = {},
|
||||
int32 timeout_in = 10, int32 ttl = 3, bool prefer_ipv6 = false,
|
||||
SslFd::VerifyPeer verify_peer = SslFd::VerifyPeer::On);
|
||||
SslStream::VerifyPeer verify_peer = SslStream::VerifyPeer::On);
|
||||
|
||||
private:
|
||||
Status try_init();
|
||||
@ -45,7 +45,7 @@ class Wget : public HttpOutboundConnection::Callback {
|
||||
int32 timeout_in_;
|
||||
int32 ttl_;
|
||||
bool prefer_ipv6_ = false;
|
||||
SslFd::VerifyPeer verify_peer_;
|
||||
SslStream::VerifyPeer verify_peer_;
|
||||
};
|
||||
|
||||
} // namespace td
|
||||
|
@ -246,6 +246,10 @@ class ByteFlowSink : public ByteFlowInterface {
|
||||
|
||||
class ByteFlowMoveSink : public ByteFlowInterface {
|
||||
public:
|
||||
ByteFlowMoveSink() = default;
|
||||
ByteFlowMoveSink(ChainBufferWriter *output) {
|
||||
set_output(output);
|
||||
}
|
||||
void set_input(ChainBufferReader *input) final {
|
||||
CHECK(!input_);
|
||||
input_ = input;
|
||||
|
Loading…
Reference in New Issue
Block a user