From 1e8235932942763cbbe6f2bb48d2d8e7106d112b Mon Sep 17 00:00:00 2001 From: yhirose Date: Sat, 31 Aug 2019 09:06:24 -0400 Subject: [PATCH] Keep-alive connection support on client (Fix #36) --- httplib.h | 190 ++++++++++++++++++++++++++++++++++++--------------- test/test.cc | 36 ++++++++++ 2 files changed, 171 insertions(+), 55 deletions(-) diff --git a/httplib.h b/httplib.h index 38f3501..d864c8b 100644 --- a/httplib.h +++ b/httplib.h @@ -171,6 +171,9 @@ struct Request { Ranges ranges; Match matches; + ContentReceiver content_receiver; + Progress progress; + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT const SSL *ssl; #endif @@ -195,10 +198,6 @@ struct Response { Headers headers; std::string body; - ContentReceiver content_receiver; - - Progress progress; - bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; size_t get_header_value_count(const char *key) const; @@ -456,7 +455,7 @@ private: Response &res, const std::string &boundary, const std::string &content_type); - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_; std::atomic svr_sock_; @@ -533,6 +532,10 @@ public: bool send(Request &req, Response &res); + bool send(std::vector &requests, std::vector& responses); + + void set_keep_alive_max_count(size_t count); + protected: bool process_request(Stream &strm, Request &req, Response &res, bool &connection_close); @@ -541,17 +544,48 @@ protected: const int port_; time_t timeout_sec_; const std::string host_and_port_; + size_t keep_alive_max_count_; private: socket_t create_client_socket() const; bool read_response_line(Stream &strm, Response &res); void write_request(Stream &strm, Request &req); - virtual bool read_and_close_socket(socket_t sock, Request &req, - Response &res); + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); + virtual bool is_ssl() const; }; +inline void Get(std::vector &requests, const char *path, const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector &requests, const char *path) { + Get(requests, path, Headers()); +} + +inline void Post(std::vector &requests, const char *path, const Headers &headers, const std::string &body, const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector &requests, const char *path, const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); +} + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { public: @@ -580,7 +614,7 @@ public: virtual bool is_valid() const; private: - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); SSL_CTX *ctx_; std::mutex ctx_mutex_; @@ -603,8 +637,11 @@ public: long get_openssl_verify_result() const; private: - virtual bool read_and_close_socket(socket_t sock, Request &req, - Response &res); + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); virtual bool is_ssl() const; bool verify_host(X509 *server_cert) const; @@ -928,15 +965,18 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { } template -inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count, - T callback) { +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, T callback) { + assert(keep_alive_max_count > 0); + bool ret = false; - if (keep_alive_max_count > 0) { + if (keep_alive_max_count > 1) { auto count = keep_alive_max_count; while (count > 0 && - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { SocketStream strm(sock); auto last_connection = count == 1; auto connection_close = false; @@ -2315,9 +2355,7 @@ inline bool Server::handle_file_request(Request &req, Response &res) { auto type = detail::find_content_type(path); if (type) { res.set_header("Content-Type", type); } res.status = 200; - if (file_request_handler_) { - file_request_handler_(req, res); - } + if (file_request_handler_) { file_request_handler_(req, res); } return true; } } @@ -2398,7 +2436,7 @@ inline bool Server::listen_internal() { break; } - task_queue->enqueue([=]() { read_and_close_socket(sock); }); + task_queue->enqueue([=]() { process_and_close_socket(sock); }); } task_queue->shutdown(); @@ -2528,9 +2566,9 @@ Server::process_request(Stream &strm, bool last_connection, inline bool Server::is_valid() const { return true; } -inline bool Server::read_and_close_socket(socket_t sock) { - return detail::read_and_close_socket( - sock, keep_alive_max_count_, +inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, [this](Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, last_connection, connection_close, nullptr); @@ -2540,7 +2578,8 @@ inline bool Server::read_and_close_socket(socket_t sock) { // HTTP client implementation inline Client::Client(const char *host, int port, time_t timeout_sec) : host_(host), port_(port), timeout_sec_(timeout_sec), - host_and_port_(host_ + ":" + std::to_string(port_)) {} + host_and_port_(host_ + ":" + std::to_string(port_)), + keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT) {} inline Client::~Client() {} @@ -2590,7 +2629,37 @@ inline bool Client::send(Request &req, Response &res) { auto sock = create_client_socket(); if (sock == INVALID_SOCKET) { return false; } - return read_and_close_socket(sock, req, res); + return process_and_close_socket( + sock, 1, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + return process_request(strm, req, res, connection_close); + }); +} + +inline bool Client::send(std::vector &requests, std::vector& responses) { + size_t i = 0; + while (i < requests.size()) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + + if (!process_and_close_socket( + sock, requests.size() - i, + [&](Stream &strm, bool last_connection, bool &connection_close) { + auto &req = requests[i]; + auto res = Response(); + i++; + + if (req.path.empty()) { return false; } + if (last_connection) { req.set_header("Connection", "close"); } + auto ret = process_request(strm, req, res, connection_close); + if (ret) { responses.emplace_back(std::move(res)); } + return ret; + })) { + return false; + } + } + + return true; } inline void Client::write_request(Stream &strm, Request &req) { @@ -2677,10 +2746,10 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, return true; }; - if (res.content_receiver) { + if (req.content_receiver) { auto offset = std::make_shared(); auto length = get_header_value_uint64(res.headers, "Content-Length", 0); - auto receiver = res.content_receiver; + auto receiver = req.content_receiver; out = [offset, length, receiver](const char *buf, size_t n) { auto ret = receiver(buf, n, *offset, length); (*offset) += n; @@ -2690,7 +2759,7 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, int dummy_status; if (!detail::read_content(strm, res, std::numeric_limits::max(), - dummy_status, res.progress, out)) { + dummy_status, req.progress, out)) { return false; } } @@ -2698,13 +2767,13 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, return true; } -inline bool Client::read_and_close_socket(socket_t sock, Request &req, - Response &res) { - return detail::read_and_close_socket( - sock, 0, - [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { - return process_request(strm, req, res, connection_close); - }); +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, callback); } inline bool Client::is_ssl() const { return false; } @@ -2720,10 +2789,9 @@ Client::Get(const char *path, const Headers &headers, Progress progress) { req.method = "GET"; req.path = path; req.headers = headers; + req.progress = progress; auto res = std::make_shared(); - res->progress = progress; - return send(req, *res) ? res : nullptr; } @@ -2741,11 +2809,10 @@ inline std::shared_ptr Client::Get(const char *path, req.method = "GET"; req.path = path; req.headers = headers; + req.content_receiver = content_receiver; + req.progress = progress; auto res = std::make_shared(); - res->content_receiver = content_receiver; - res->progress = progress; - return send(req, *res) ? res : nullptr; } @@ -2930,6 +2997,10 @@ inline std::shared_ptr Client::Options(const char *path, return send(req, *res) ? res : nullptr; } +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + /* * SSL Implementation */ @@ -2937,10 +3008,13 @@ inline std::shared_ptr Client::Options(const char *path, namespace detail { template -inline bool -read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, - SSL_CTX *ctx, std::mutex &ctx_mutex, - U SSL_connect_or_accept, V setup, T callback) { +inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup, + T callback) { + assert(keep_alive_max_count > 0); + SSL *ssl = nullptr; { std::lock_guard guard(ctx_mutex); @@ -2969,11 +3043,12 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, bool ret = false; if (SSL_connect_or_accept(ssl) == 1) { - if (keep_alive_max_count > 0) { + if (keep_alive_max_count > 1) { auto count = keep_alive_max_count; while (count > 0 && - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { SSLSocketStream strm(sock, ssl); auto last_connection = count == 1; auto connection_close = false; @@ -3123,9 +3198,9 @@ inline SSLServer::~SSLServer() { inline bool SSLServer::is_valid() const { return ctx_; } -inline bool SSLServer::read_and_close_socket(socket_t sock) { - return detail::read_and_close_socket_ssl( - sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, +inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, [this](SSL *ssl, Stream &strm, bool last_connection, bool &connection_close) { @@ -3176,12 +3251,17 @@ inline long SSLClient::get_openssl_verify_result() const { return verify_result_; } -inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req, - Response &res) { +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); return is_valid() && - detail::read_and_close_socket_ssl( - sock, 0, ctx_, ctx_mutex_, + detail::process_and_close_socket_ssl( + true, sock, request_count, ctx_, ctx_mutex_, [&](SSL *ssl) { if (ca_cert_file_path_.empty()) { SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); @@ -3217,9 +3297,9 @@ inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req, SSL_set_tlsext_host_name(ssl, host_.c_str()); return true; }, - [&](SSL * /*ssl*/, Stream &strm, bool /*last_connection*/, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, bool &connection_close) { - return process_request(strm, req, res, connection_close); + return callback(strm, last_connection, connection_close); }); } diff --git a/test/test.cc b/test/test.cc index 7413b22..12558e7 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1280,6 +1280,42 @@ TEST_F(ServerTest, NoMultipleHeaders) { EXPECT_EQ(200, res->status); } +TEST_F(ServerTest, KeepAlive) { + cli_.set_keep_alive_max_count(4); + + std::vector requests; + Get(requests, "/hi"); + Get(requests, "/hi"); + Get(requests, "/hi"); + Get(requests, "/not-exist"); + Post(requests, "/empty", "", "text/plain"); + + std::vector responses; + auto ret = cli_.send(requests, responses); + + ASSERT_TRUE(ret == true); + ASSERT_TRUE(requests.size() == responses.size()); + + for (int i = 0; i < 3; i++) { + auto& res = responses[i]; + EXPECT_EQ(200, res.status); + EXPECT_EQ("text/plain", res.get_header_value("Content-Type")); + EXPECT_EQ("Hello World!", res.body); + } + + { + auto& res = responses[3]; + EXPECT_EQ(404, res.status); + } + + { + auto& res = responses[4]; + EXPECT_EQ(200, res.status); + EXPECT_EQ("text/plain", res.get_header_value("Content-Type")); + EXPECT_EQ("empty", res.body); + } +} + #ifdef CPPHTTPLIB_ZLIB_SUPPORT TEST_F(ServerTest, Gzip) { Headers headers;