diff --git a/httplib.h b/httplib.h index b3ad3ad..43fd2bb 100644 --- a/httplib.h +++ b/httplib.h @@ -376,7 +376,7 @@ private: class SSLServer : public Server { public: - SSLServer(const char *cert_path, const char *private_key_path); + SSLServer(const char *cert_path, const char *private_key_path, const char *client_CA_cert_path, const char *trusted_cert_path); virtual ~SSLServer(); @@ -387,11 +387,14 @@ private: SSL_CTX *ctx_; std::mutex ctx_mutex_; + const char *client_CA_cert_path_; + const char *trusted_cert_path_; }; class SSLClient : public Client { public: - SSLClient(const char *host, int port = 443, time_t timeout_sec = 300); + SSLClient(const char *host, int port = 443, time_t timeout_sec = 300, + const char *client_cert_path = nullptr, const char *client_key_path = nullptr); virtual ~SSLClient(); @@ -2234,7 +2237,9 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, // TODO: OpenSSL 1.0.2 occasionally crashes... // The upcoming 1.1.0 is going to be thread safe. SSL_CTX *ctx, std::mutex &ctx_mutex, - U SSL_connect_or_accept, V setup, T callback) { + U SSL_connect_or_accept, V setup, T callback, + const char* client_CA_cert_path = nullptr, + const char* trusted_cert_path = nullptr) { SSL *ssl = nullptr; { std::lock_guard guard(ctx_mutex); @@ -2260,9 +2265,24 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, return false; } + if(client_CA_cert_path){ + STACK_OF(X509_NAME)* list; + //list of client CAs to request from client + list = SSL_load_client_CA_file(client_CA_cert_path); + SSL_set_client_CA_list(ssl, list); + //certificate chain to verify received client certificate against + //please run c_rehash in the cert folder first + SSL_CTX_load_verify_locations(ctx,client_CA_cert_path,trusted_cert_path); + } + bool ret = false; if (SSL_connect_or_accept(ssl) == 1) { + /* + auto client_cert = SSL_get_peer_certificate(ssl); + if(client_cert) + printf("Connected client: %s\n", client_cert->name); + */ if (keep_alive_max_count > 0) { auto count = keep_alive_max_count; while (count > 0 && @@ -2338,7 +2358,11 @@ inline std::string SSLSocketStream::get_remote_addr() const { // SSL HTTP server implementation inline SSLServer::SSLServer(const char *cert_path, - const char *private_key_path) { + const char *private_key_path, + const char *client_CA_cert_path = nullptr, + const char *trusted_cert_path = nullptr) + : client_CA_cert_path_(client_CA_cert_path), + trusted_cert_path_(trusted_cert_path){ ctx_ = SSL_CTX_new(SSLv23_server_method()); if (ctx_) { @@ -2356,6 +2380,11 @@ inline SSLServer::SSLServer(const char *cert_path, 1) { SSL_CTX_free(ctx_); ctx_ = nullptr; + } else if(client_CA_cert_path_) { + SSL_CTX_set_verify(ctx_, + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, //SSL_VERIFY_CLIENT_ONCE, + nullptr + ); } } } @@ -2372,11 +2401,14 @@ inline bool SSLServer::read_and_close_socket(socket_t sock) { [](SSL * /*ssl*/) { return true; }, [this](Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, last_connection, connection_close); - }); + }, + client_CA_cert_path_, + trusted_cert_path_); } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec) +inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec, + const char *client_cert_path, const char *client_key_path) : Client(host, port, timeout_sec) { ctx_ = SSL_CTX_new(SSLv23_client_method()); @@ -2384,6 +2416,13 @@ inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec) [&](const char *b, const char *e) { host_components_.emplace_back(std::string(b, e)); }); + if(client_cert_path && client_key_path) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path, SSL_FILETYPE_PEM) != 1 + ||SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } } inline SSLClient::~SSLClient() { diff --git a/test/Makefile b/test/Makefile index 0a72252..8a1b572 100644 --- a/test/Makefile +++ b/test/Makefile @@ -15,6 +15,11 @@ test : test.cc ../httplib.h Makefile cert.pem cert.pem: openssl genrsa 2048 > key.pem openssl req -new -batch -config test.conf -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem + openssl genrsa 2048 > rootCA.key.pem + openssl req -x509 -new -batch -config test.rootCA.conf -key rootCA.key.pem -days 1024 > rootCA.cert.pem + openssl genrsa 2048 > client.key.pem + openssl req -new -batch -config test.conf -key client.key.pem | openssl x509 -days 370 -req -CA rootCA.cert.pem -CAkey rootCA.key.pem -CAcreateserial > client.cert.pem + #c_rehash . clean: - rm -f test *.pem + rm -f test *.pem *.0 *.1 *.srl diff --git a/test/test.cc b/test/test.cc index afa1363..3637d69 100644 --- a/test/test.cc +++ b/test/test.cc @@ -5,6 +5,10 @@ #define SERVER_CERT_FILE "./cert.pem" #define SERVER_PRIVATE_KEY_FILE "./key.pem" #define CA_CERT_FILE "./ca-bundle.crt" +#define CLIENT_CA_CERT_FILE "./rootCA.cert.pem" +#define CLIENT_CERT_FILE "./client.cert.pem" +#define CLIENT_PRIVATE_KEY_FILE "./client.key.pem" +#define TRUST_CERT_DIR "." #ifdef _WIN32 #include @@ -1374,6 +1378,70 @@ TEST(SSLClientTest, WildcardHostNameMatch) { ASSERT_TRUE(res != nullptr); ASSERT_EQ(200, res->status); } + +TEST(SSLClientServerTest, ClientCertPresent) { + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE, TRUST_CERT_DIR); + ASSERT_TRUE(svr.is_valid()); + + svr.Get("/test", [&](const Request &, Response &res){ + res.set_content("test", "text/plain"); + svr.stop(); + }); + + thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); + + httplib::SSLClient cli(HOST, PORT, 30, CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE); + auto res = cli.Get("/test"); + ASSERT_TRUE(res != nullptr); + ASSERT_EQ(200, res->status); + + t.join(); +} + +TEST(SSLClientServerTest, ClientCertMissing) { + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE, TRUST_CERT_DIR); + ASSERT_TRUE(svr.is_valid()); + + svr.Get("/test", [&](const Request &, Response &res){ + res.set_content("test", "text/plain"); + svr.stop(); + }); + + thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); + + httplib::SSLClient cli(HOST, PORT, 30); + auto res = cli.Get("/test"); + ASSERT_TRUE(res == nullptr); + + svr.stop(); + + t.join(); +} + +TEST(SSLClientServerTest, TrustDirOptional) { + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, CLIENT_CA_CERT_FILE); + ASSERT_TRUE(svr.is_valid()); + + svr.Get("/test", [&](const Request &, Response &res){ + res.set_content("test", "text/plain"); + svr.stop(); + }); + + thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); + + httplib::SSLClient cli(HOST, PORT, 30, CLIENT_CERT_FILE, CLIENT_PRIVATE_KEY_FILE); + auto res = cli.Get("/test"); + ASSERT_TRUE(res != nullptr); + ASSERT_EQ(200, res->status); + + t.join(); +} + +/* Cannot test this case as there is no external access to SSL object to check SSL_get_peer_certificate() == NULL +TEST(SSLClientServerTest, ClientCAPathRequired) { + SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE, nullptr, TRUST_CERT_DIR); +} +*/ #endif #ifdef _WIN32 diff --git a/test/test.rootCA.conf b/test/test.rootCA.conf new file mode 100644 index 0000000..9d7037d --- /dev/null +++ b/test/test.rootCA.conf @@ -0,0 +1,18 @@ +[req] +default_bits = 2048 +distinguished_name = req_distinguished_name +attributes = req_attributes +prompt = no +output_password = mypass + +[req_distinguished_name] +C = US +ST = Test State or Province +L = Test Locality +O = Organization Name +OU = Organizational Unit Name +CN = Root CA Name +emailAddress = test@email.address + +[req_attributes] +challengePassword = 1234