// Copyright (c) 2009- Facebook
// Distributed under the Thrift Software License
//
// See accompanying file LICENSE or visit the Thrift site at:
// http://developers.facebook.com/thrift/
|
||
|
#ifndef THRIFT_TRANSPORT_TSSLSOCKET_H_
|
||
|
#define THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
|
||
|
|
||
|
#include <stdint.h>

#include <list>
#include <map>
#include <string>

#include <boost/shared_ptr.hpp>
#include <openssl/ssl.h>

#include "thrift/lib/cpp/concurrency/Mutex.h"
#include "thrift/lib/cpp/transport/TSocket.h"
|
||
|
|
||
|
namespace apache { namespace thrift { namespace transport {
|
||
|
|
||
|
// Forward declarations of types that are referenced by pointer, reference or
// smart pointer before their full definition is needed.
class PasswordCollector;
class SSLContext;
class TSocketAddress;
|
||
|
|
||
|
/**
|
||
|
* OpenSSL implementation for SSL socket interface.
|
||
|
*/
|
||
|
class TSSLSocket: public TVirtualTransport<TSSLSocket, TSocket> {
|
||
|
public:
|
||
|
/**
|
||
|
* Constructor.
|
||
|
*/
|
||
|
explicit TSSLSocket(const boost::shared_ptr<SSLContext>& ctx);
|
||
|
/**
|
||
|
* Constructor, create an instance of TSSLSocket given an existing socket.
|
||
|
*
|
||
|
* @param socket An existing socket
|
||
|
*/
|
||
|
TSSLSocket(const boost::shared_ptr<SSLContext>& ctx, int socket);
|
||
|
/**
|
||
|
* Constructor.
|
||
|
*
|
||
|
* @param host Remote host name
|
||
|
* @param port Remote port number
|
||
|
*/
|
||
|
TSSLSocket(const boost::shared_ptr<SSLContext>& ctx,
|
||
|
const std::string& host,
|
||
|
int port);
|
||
|
/**
|
||
|
* Constructor.
|
||
|
*/
|
||
|
TSSLSocket(const boost::shared_ptr<SSLContext>& ctx,
|
||
|
const TSocketAddress& address);
|
||
|
/**
|
||
|
* Destructor.
|
||
|
*/
|
||
|
~TSSLSocket();
|
||
|
|
||
|
/**
|
||
|
* TTransport interface.
|
||
|
*/
|
||
|
bool isOpen();
|
||
|
bool peek();
|
||
|
void open();
|
||
|
void close();
|
||
|
uint32_t read(uint8_t* buf, uint32_t len);
|
||
|
void write(const uint8_t* buf, uint32_t len);
|
||
|
void flush();
|
||
|
|
||
|
/**
|
||
|
* Set whether to use client or server side SSL handshake protocol.
|
||
|
*
|
||
|
* @param flag Use server side handshake protocol if true.
|
||
|
*/
|
||
|
void server(bool flag) { server_ = flag; }
|
||
|
/**
|
||
|
* Determine whether the SSL socket is server or client mode.
|
||
|
*/
|
||
|
bool server() const { return server_; }
|
||
|
|
||
|
protected:
|
||
|
/**
|
||
|
* Verify peer certificate after SSL handshake completes.
|
||
|
*/
|
||
|
virtual void verifyCertificate();
|
||
|
|
||
|
/**
|
||
|
* Initiate SSL handshake if not already initiated.
|
||
|
*/
|
||
|
void checkHandshake();
|
||
|
|
||
|
bool server_;
|
||
|
SSL* ssl_;
|
||
|
boost::shared_ptr<SSLContext> ctx_;
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* SSL socket factory. SSL sockets should be created via SSL factory.
|
||
|
*/
|
||
|
class TSSLSocketFactory {
|
||
|
public:
|
||
|
/**
|
||
|
* Constructor/Destructor
|
||
|
*/
|
||
|
explicit TSSLSocketFactory(const boost::shared_ptr<SSLContext>& context);
|
||
|
virtual ~TSSLSocketFactory();
|
||
|
|
||
|
/**
|
||
|
* Create an instance of TSSLSocket with a fresh new socket.
|
||
|
*/
|
||
|
virtual boost::shared_ptr<TSSLSocket> createSocket();
|
||
|
/**
|
||
|
* Create an instance of TSSLSocket with the given socket.
|
||
|
*
|
||
|
* @param socket An existing socket.
|
||
|
*/
|
||
|
virtual boost::shared_ptr<TSSLSocket> createSocket(int socket);
|
||
|
/**
|
||
|
* Create an instance of TSSLSocket.
|
||
|
*
|
||
|
* @param host Remote host to be connected to
|
||
|
* @param port Remote port to be connected to
|
||
|
*/
|
||
|
virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host,
|
||
|
int port);
|
||
|
/**
|
||
|
* Set/Unset server mode.
|
||
|
*
|
||
|
* @param flag Server mode if true
|
||
|
*/
|
||
|
virtual void server(bool flag) { server_ = flag; }
|
||
|
/**
|
||
|
* Determine whether the socket is in server or client mode.
|
||
|
*
|
||
|
* @return true, if server mode, or, false, if client mode
|
||
|
*/
|
||
|
virtual bool server() const { return server_; }
|
||
|
|
||
|
private:
|
||
|
boost::shared_ptr<SSLContext> ctx_;
|
||
|
bool server_;
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* SSL exception.
|
||
|
*/
|
||
|
class TSSLException: public TTransportException {
|
||
|
public:
|
||
|
explicit TSSLException(const std::string& message):
|
||
|
TTransportException(TTransportException::INTERNAL_ERROR, message) {}
|
||
|
|
||
|
virtual const char* what() const throw() {
|
||
|
if (message_.empty()) {
|
||
|
return "TSSLException";
|
||
|
} else {
|
||
|
return message_.c_str();
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* Wrap OpenSSL SSL_CTX into a class.
|
||
|
*/
|
||
|
class SSLContext {
|
||
|
public:
|
||
|
|
||
|
enum SSLVersion {
|
||
|
SSLv2,
|
||
|
SSLv3,
|
||
|
TLSv1
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* Constructor.
|
||
|
*
|
||
|
* @param version The lowest or oldest SSL version to support.
|
||
|
*/
|
||
|
explicit SSLContext(SSLVersion version = TLSv1);
|
||
|
virtual ~SSLContext();
|
||
|
|
||
|
/**
|
||
|
* Set ciphers to be used in SSL handshake process.
|
||
|
*
|
||
|
* @param ciphers A list of ciphers
|
||
|
*/
|
||
|
virtual void ciphers(const std::string& enable);
|
||
|
/**
|
||
|
* Enable/Disable authentication. Peer name validation can only be done
|
||
|
* if checkPeerCert is true.
|
||
|
*
|
||
|
* @param checkPeerCert If true, require peer to present valid certificate
|
||
|
* @param checkPeerName If true, validate that the certificate common name
|
||
|
* or alternate name(s) of peer matches the hostname
|
||
|
* used to connect.
|
||
|
* @param peerName If non-empty, validate that the certificate common
|
||
|
* name of peer matches the given string (altername
|
||
|
* name(s) are not used in this case).
|
||
|
*/
|
||
|
virtual void authenticate(bool checkPeerCert, bool checkPeerName,
|
||
|
const std::string& peerName = std::string());
|
||
|
/**
|
||
|
* Load server certificate.
|
||
|
*
|
||
|
* @param path Path to the certificate file
|
||
|
* @param format Certificate file format
|
||
|
*/
|
||
|
virtual void loadCertificate(const char* path, const char* format = "PEM");
|
||
|
/**
|
||
|
* Load private key.
|
||
|
*
|
||
|
* @param path Path to the private key file
|
||
|
* @param format Private key file format
|
||
|
*/
|
||
|
virtual void loadPrivateKey(const char* path, const char* format = "PEM");
|
||
|
/**
|
||
|
* Load trusted certificates from specified file.
|
||
|
*
|
||
|
* @param path Path to trusted certificate file
|
||
|
*/
|
||
|
virtual void loadTrustedCertificates(const char* path);
|
||
|
/**
|
||
|
* Load trusted certificates from specified X509 certificate store.
|
||
|
*
|
||
|
* @param store X509 certificate store.
|
||
|
*/
|
||
|
virtual void loadTrustedCertificates(X509_STORE* store);
|
||
|
/**
|
||
|
* Default randomize method.
|
||
|
*/
|
||
|
virtual void randomize();
|
||
|
/**
|
||
|
* Override default OpenSSL password collector.
|
||
|
*
|
||
|
* @param collector Instance of user defined password collector
|
||
|
*/
|
||
|
virtual void passwordCollector(boost::shared_ptr<PasswordCollector> collector);
|
||
|
/**
|
||
|
* Obtain password collector.
|
||
|
*
|
||
|
* @return User defined password collector
|
||
|
*/
|
||
|
virtual boost::shared_ptr<PasswordCollector> passwordCollector() {
|
||
|
return collector_;
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* Create an SSL object from this context.
|
||
|
*/
|
||
|
SSL* createSSL() const;
|
||
|
|
||
|
/**
|
||
|
* Possibly validate the peer's certificate name, depending on how this
|
||
|
* SSLContext was configured by authenticate().
|
||
|
*
|
||
|
* @return True if the peer's name is acceptable, false otherwise
|
||
|
*/
|
||
|
bool validatePeerName(TSSLSocket* sock, SSL* ssl) const;
|
||
|
|
||
|
/**
|
||
|
* Set the options on the SSL_CTX object.
|
||
|
*/
|
||
|
void setOptions(long options);
|
||
|
|
||
|
#ifdef OPENSSL_NPN_NEGOTIATED
|
||
|
/**
|
||
|
* Set the list of protocols that a TLS server should advertise for
|
||
|
* Next Protocol Negotiation (NPN).
|
||
|
*
|
||
|
* @param protocols List of protocol names, or NULL to disable NPN.
|
||
|
* Note: if non-null, this method makes a copy, so
|
||
|
* the caller needn't keep the list in scope after
|
||
|
* the call completes.
|
||
|
*/
|
||
|
void setAdvertisedNextProtocols(const std::list<std::string>* protocols);
|
||
|
#endif // OPENSSL_NPN_NEGOTIATED
|
||
|
|
||
|
/**
|
||
|
* Gets the underlying SSL_CTX for advanced usage
|
||
|
*/
|
||
|
SSL_CTX *getSSLCtx() const {
|
||
|
return ctx_;
|
||
|
}
|
||
|
|
||
|
enum SSLLockType {
|
||
|
LOCK_MUTEX,
|
||
|
LOCK_SPINLOCK,
|
||
|
LOCK_NONE
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* Set preferences for how to treat locks in OpenSSL. This must be
|
||
|
* called before the instantiation of any SSLContext objects, otherwise
|
||
|
* the defaults will be used.
|
||
|
*
|
||
|
* OpenSSL has a lock for each module rather than for each object or
|
||
|
* data that needs locking. Some locks protect only refcounts, and
|
||
|
* might be better as spinlocks rather than mutexes. Other locks
|
||
|
* may be totally unnecessary if the objects being protected are not
|
||
|
* shared between threads in the application.
|
||
|
*
|
||
|
* By default, all locks are initialized as mutexes. OpenSSL's lock usage
|
||
|
* may change from version to version and you should know what you are doing
|
||
|
* before disabling any locks entirely.
|
||
|
*
|
||
|
* Example: if you don't share SSL sessions between threads in your
|
||
|
* application, you may be able to do this
|
||
|
*
|
||
|
* setSSLLockTypes({{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_NONE}})
|
||
|
*/
|
||
|
static void setSSLLockTypes(std::map<int, SSLLockType> lockTypes);
|
||
|
|
||
|
protected:
|
||
|
SSL_CTX* ctx_;
|
||
|
|
||
|
private:
|
||
|
bool checkPeerName_;
|
||
|
std::string peerFixedName_;
|
||
|
boost::shared_ptr<PasswordCollector> collector_;
|
||
|
|
||
|
static concurrency::Mutex mutex_;
|
||
|
static uint64_t count_;
|
||
|
|
||
|
#ifdef OPENSSL_NPN_NEGOTIATED
|
||
|
/**
|
||
|
* Wire-format list of advertised protocols for use in NPN.
|
||
|
*/
|
||
|
unsigned char* advertisedNextProtocols_;
|
||
|
unsigned advertisedNextProtocolsLength_;
|
||
|
|
||
|
static int advertisedNextProtocolCallback(SSL* ssl,
|
||
|
const unsigned char** out, unsigned int* outlen, void* data);
|
||
|
#endif // OPENSSL_NPN_NEGOTIATED
|
||
|
|
||
|
static int passwordCallback(char* password, int size, int, void* data);
|
||
|
|
||
|
static void initializeOpenSSL();
|
||
|
static void cleanupOpenSSL();
|
||
|
|
||
|
/**
|
||
|
* Helper to match a hostname versus a pattern.
|
||
|
*/
|
||
|
static bool matchName(const char* host, const char* pattern, int size);
|
||
|
};
|
||
|
|
||
|
// Shorthand for a boost::shared_ptr to an SSLContext.
typedef boost::shared_ptr<SSLContext> SSLContextPtr;
|
||
|
|
||
|
/**
 * Override the default password collector.
 */
class PasswordCollector {
 public:
  virtual ~PasswordCollector() {}

  /**
   * Interface for customizing how to collect private key password.
   *
   * By default, OpenSSL prints a prompt on screen and requests the password
   * while loading a private key. To implement a custom password collector,
   * implement this interface and register it with
   * SSLContext::passwordCollector().
   *
   * @param password  Pass collected password back to OpenSSL
   * @param size      Maximum length of password including the terminating
   *                  NUL character
   */
  virtual void getPassword(std::string& password, int size) = 0;
};
|
||
|
|
||
|
}}}
|
||
|
|
||
|
#endif
|