From 1981e0ccad8a70524e2cbaa9ed7c2098c0207c27 Mon Sep 17 00:00:00 2001 From: yhirose Date: Thu, 20 Jun 2019 18:52:28 -0400 Subject: [PATCH] Add SSL object on Request --- httplib.h | 28 ++++++++++++++++++++-------- test/test.cc | 26 +++++++++++++++++++++----- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/httplib.h b/httplib.h index 224cf65..dfa4c2b 100644 --- a/httplib.h +++ b/httplib.h @@ -145,6 +145,10 @@ struct Request { Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif + bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; size_t get_header_value_count(const char *key) const; @@ -256,7 +260,8 @@ public: protected: bool process_request(Stream &strm, bool last_connection, - bool &connection_close); + bool &connection_close, + std::function setup_request = nullptr); size_t keep_alive_max_count_; size_t payload_max_length_; @@ -1828,8 +1833,10 @@ inline bool Server::dispatch_request(Request &req, Response &res, return false; } -inline bool Server::process_request(Stream &strm, bool last_connection, - bool &connection_close) { +inline bool +Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + std::function setup_request) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -1899,6 +1906,9 @@ inline bool Server::process_request(Stream &strm, bool last_connection, } } + // TODO: Add additional request info + if (setup_request) { setup_request(req); } + if (routing(req, res)) { if (res.status == -1) { res.status = 200; } } else { @@ -2293,7 +2303,7 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, auto last_connection = count == 1; auto connection_close = false; - ret = callback(strm, last_connection, connection_close); + ret = callback(ssl, strm, last_connection, connection_close); if (!ret || connection_close) { break; } count--; @@ -2301,7 +2311,7 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, } else { SSLSocketStream strm(sock, ssl); auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); + ret = callback(ssl, strm, true, dummy_connection_close); } } @@ -2406,8 +2416,10 @@ inline bool SSLServer::read_and_close_socket(socket_t sock) { return detail::read_and_close_socket_ssl( sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, - [this](Stream &strm, bool last_connection, bool &connection_close) { - return process_request(strm, last_connection, connection_close); + [this](SSL *ssl, Stream &strm, bool last_connection, + bool &connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request &req) { req.ssl = ssl; }); }); } @@ -2494,7 +2506,7 @@ inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req, SSL_set_tlsext_host_name(ssl, host_.c_str()); return true; }, - [&](Stream &strm, bool /*last_connection*/, + [&](SSL * /*ssl*/, Stream &strm, bool /*last_connection*/, bool &connection_close) { return process_request(strm, req, res, connection_close); }); diff --git a/test/test.cc b/test/test.cc index f6f5eb4..3077ce7 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1384,9 +1384,28 @@ TEST(SSLClientServerTest, ClientCertPresent) { CLIENT_CA_CERT_DIR); ASSERT_TRUE(svr.is_valid()); - svr.Get("/test", [&](const Request &, Response &res) { + svr.Get("/test", [&](const Request &req, Response &res) { res.set_content("test", "text/plain"); svr.stop(); + ASSERT_TRUE(true); + + auto peer_cert = SSL_get_peer_certificate(req.ssl); + ASSERT_TRUE(peer_cert != nullptr); + + auto subject_name = X509_get_subject_name(peer_cert); + ASSERT_TRUE(subject_name != nullptr); + + std::string common_name; + { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + common_name.assign(name, name_len); + } + + EXPECT_EQ("Common Name", common_name); + + X509_free(peer_cert); }); thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); }); @@ -1405,10 +1424,7 @@ TEST(SSLClientServerTest, ClientCertMissing) { CLIENT_CA_CERT_DIR); ASSERT_TRUE(svr.is_valid()); - svr.Get("/test", [&](const Request &, Response &res) { - res.set_content("test", "text/plain"); - svr.stop(); - }); + svr.Get("/test", [&](const Request &, Response &) { ASSERT_TRUE(false); }); thread t = thread([&]() { ASSERT_TRUE(svr.listen(HOST, PORT)); });