// Copyright (c) 2009- Facebook // Distributed under the Thrift Software License // // See accompanying file LICENSE or visit the Thrift site at: // http://developers.facebook.com/thrift/ #ifndef THRIFT_TRANSPORT_TSSLSOCKET_H_ #define THRIFT_TRANSPORT_TSSLSOCKET_H_ 1 #include #include #include #include "thrift/lib/cpp/concurrency/Mutex.h" #include "thrift/lib/cpp/transport/TSocket.h" namespace apache { namespace thrift { namespace transport { class PasswordCollector; class SSLContext; class TSocketAddress; /** * OpenSSL implementation for SSL socket interface. */ class TSSLSocket: public TVirtualTransport { public: /** * Constructor. */ explicit TSSLSocket(const boost::shared_ptr& ctx); /** * Constructor, create an instance of TSSLSocket given an existing socket. * * @param socket An existing socket */ TSSLSocket(const boost::shared_ptr& ctx, int socket); /** * Constructor. * * @param host Remote host name * @param port Remote port number */ TSSLSocket(const boost::shared_ptr& ctx, const std::string& host, int port); /** * Constructor. */ TSSLSocket(const boost::shared_ptr& ctx, const TSocketAddress& address); /** * Destructor. */ ~TSSLSocket(); /** * TTransport interface. */ bool isOpen(); bool peek(); void open(); void close(); uint32_t read(uint8_t* buf, uint32_t len); void write(const uint8_t* buf, uint32_t len); void flush(); /** * Set whether to use client or server side SSL handshake protocol. * * @param flag Use server side handshake protocol if true. */ void server(bool flag) { server_ = flag; } /** * Determine whether the SSL socket is server or client mode. */ bool server() const { return server_; } protected: /** * Verify peer certificate after SSL handshake completes. */ virtual void verifyCertificate(); /** * Initiate SSL handshake if not already initiated. */ void checkHandshake(); bool server_; SSL* ssl_; boost::shared_ptr ctx_; }; /** * SSL socket factory. SSL sockets should be created via SSL factory. */ class TSSLSocketFactory { public: /** * Constructor/Destructor */ explicit TSSLSocketFactory(const boost::shared_ptr& context); virtual ~TSSLSocketFactory(); /** * Create an instance of TSSLSocket with a fresh new socket. */ virtual boost::shared_ptr createSocket(); /** * Create an instance of TSSLSocket with the given socket. * * @param socket An existing socket. */ virtual boost::shared_ptr createSocket(int socket); /** * Create an instance of TSSLSocket. * * @param host Remote host to be connected to * @param port Remote port to be connected to */ virtual boost::shared_ptr createSocket(const std::string& host, int port); /** * Set/Unset server mode. * * @param flag Server mode if true */ virtual void server(bool flag) { server_ = flag; } /** * Determine whether the socket is in server or client mode. * * @return true, if server mode, or, false, if client mode */ virtual bool server() const { return server_; } private: boost::shared_ptr ctx_; bool server_; }; /** * SSL exception. */ class TSSLException: public TTransportException { public: explicit TSSLException(const std::string& message): TTransportException(TTransportException::INTERNAL_ERROR, message) {} virtual const char* what() const throw() { if (message_.empty()) { return "TSSLException"; } else { return message_.c_str(); } } }; /** * Wrap OpenSSL SSL_CTX into a class. */ class SSLContext { public: enum SSLVersion { SSLv2, SSLv3, TLSv1 }; /** * Constructor. * * @param version The lowest or oldest SSL version to support. */ explicit SSLContext(SSLVersion version = TLSv1); virtual ~SSLContext(); /** * Set ciphers to be used in SSL handshake process. * * @param ciphers A list of ciphers */ virtual void ciphers(const std::string& enable); /** * Enable/Disable authentication. Peer name validation can only be done * if checkPeerCert is true. * * @param checkPeerCert If true, require peer to present valid certificate * @param checkPeerName If true, validate that the certificate common name * or alternate name(s) of peer matches the hostname * used to connect. * @param peerName If non-empty, validate that the certificate common * name of peer matches the given string (altername * name(s) are not used in this case). */ virtual void authenticate(bool checkPeerCert, bool checkPeerName, const std::string& peerName = std::string()); /** * Load server certificate. * * @param path Path to the certificate file * @param format Certificate file format */ virtual void loadCertificate(const char* path, const char* format = "PEM"); /** * Load private key. * * @param path Path to the private key file * @param format Private key file format */ virtual void loadPrivateKey(const char* path, const char* format = "PEM"); /** * Load trusted certificates from specified file. * * @param path Path to trusted certificate file */ virtual void loadTrustedCertificates(const char* path); /** * Load trusted certificates from specified X509 certificate store. * * @param store X509 certificate store. */ virtual void loadTrustedCertificates(X509_STORE* store); /** * Default randomize method. */ virtual void randomize(); /** * Override default OpenSSL password collector. * * @param collector Instance of user defined password collector */ virtual void passwordCollector(boost::shared_ptr collector); /** * Obtain password collector. * * @return User defined password collector */ virtual boost::shared_ptr passwordCollector() { return collector_; } /** * Create an SSL object from this context. */ SSL* createSSL() const; /** * Possibly validate the peer's certificate name, depending on how this * SSLContext was configured by authenticate(). * * @return True if the peer's name is acceptable, false otherwise */ bool validatePeerName(TSSLSocket* sock, SSL* ssl) const; /** * Set the options on the SSL_CTX object. */ void setOptions(long options); #ifdef OPENSSL_NPN_NEGOTIATED /** * Set the list of protocols that a TLS server should advertise for * Next Protocol Negotiation (NPN). * * @param protocols List of protocol names, or NULL to disable NPN. * Note: if non-null, this method makes a copy, so * the caller needn't keep the list in scope after * the call completes. */ void setAdvertisedNextProtocols(const std::list* protocols); #endif // OPENSSL_NPN_NEGOTIATED /** * Gets the underlying SSL_CTX for advanced usage */ SSL_CTX *getSSLCtx() const { return ctx_; } enum SSLLockType { LOCK_MUTEX, LOCK_SPINLOCK, LOCK_NONE }; /** * Set preferences for how to treat locks in OpenSSL. This must be * called before the instantiation of any SSLContext objects, otherwise * the defaults will be used. * * OpenSSL has a lock for each module rather than for each object or * data that needs locking. Some locks protect only refcounts, and * might be better as spinlocks rather than mutexes. Other locks * may be totally unnecessary if the objects being protected are not * shared between threads in the application. * * By default, all locks are initialized as mutexes. OpenSSL's lock usage * may change from version to version and you should know what you are doing * before disabling any locks entirely. * * Example: if you don't share SSL sessions between threads in your * application, you may be able to do this * * setSSLLockTypes({{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_NONE}}) */ static void setSSLLockTypes(std::map lockTypes); protected: SSL_CTX* ctx_; private: bool checkPeerName_; std::string peerFixedName_; boost::shared_ptr collector_; static concurrency::Mutex mutex_; static uint64_t count_; #ifdef OPENSSL_NPN_NEGOTIATED /** * Wire-format list of advertised protocols for use in NPN. */ unsigned char* advertisedNextProtocols_; unsigned advertisedNextProtocolsLength_; static int advertisedNextProtocolCallback(SSL* ssl, const unsigned char** out, unsigned int* outlen, void* data); #endif // OPENSSL_NPN_NEGOTIATED static int passwordCallback(char* password, int size, int, void* data); static void initializeOpenSSL(); static void cleanupOpenSSL(); /** * Helper to match a hostname versus a pattern. */ static bool matchName(const char* host, const char* pattern, int size); }; typedef boost::shared_ptr SSLContextPtr; /** * Override the default password collector. */ class PasswordCollector { public: virtual ~PasswordCollector() {} /** * Interface for customizing how to collect private key password. * * By default, OpenSSL prints a prompt on screen and request for password * while loading private key. To implement a custom password collector, * implement this interface and register it with TSSLSocketFactory. * * @param password Pass collected password back to OpenSSL * @param size Maximum length of password including NULL character */ virtual void getPassword(std::string& password, int size) = 0; }; }}} #endif