diff --git a/httplib.h b/httplib.h index 33addc5..3f850d0 100644 --- a/httplib.h +++ b/httplib.h @@ -194,6 +194,7 @@ using socket_t = int; #include #include #include +#include #include #include #include @@ -834,7 +835,8 @@ protected: bool process_request(Stream &strm, const Request &req, Response &res, bool last_connection, bool &connection_close); - std::atomic sock_; + std::set cli_socks_; + std::mutex cli_socks_mutex_; const std::string host_; const int port_; @@ -911,6 +913,7 @@ protected: private: socket_t create_client_socket() const; + bool create_and_connect_socket(socket_t &sock); bool read_response_line(Stream &strm, Response &res); bool write_request(Stream &strm, const Request &req, bool last_connection); bool redirect(const Request &req, Response &res); @@ -1397,7 +1400,9 @@ public: #endif private: +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool is_ssl_ = false; +#endif std::shared_ptr cli_; }; @@ -4309,7 +4314,7 @@ inline Client::Client(const std::string &host, int port) inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) - : sock_(INVALID_SOCKET), host_(host), port_(port), + : /*cli_sock_(INVALID_SOCKET),*/ host_(host), port_(port), host_and_port_(host_ + ":" + std::to_string(port_)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} @@ -4328,6 +4333,20 @@ inline socket_t Client::create_client_socket() const { connection_timeout_usec_, interface_); } +inline bool Client::create_and_connect_socket(socket_t &sock) { + sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + Response res; + bool error; + if (!connect(sock, res, error)) { return error; } + } +#endif + return true; +} + inline bool Client::read_response_line(Stream &strm, Response &res) { std::array buf; @@ -4347,54 +4366,58 @@ inline bool Client::read_response_line(Stream &strm, Response &res) { } inline bool Client::send(const Request &req, Response &res) { - sock_ = create_client_socket(); - if (sock_ == INVALID_SOCKET) { return false; } + socket_t sock = INVALID_SOCKET; + if (!create_and_connect_socket(sock)) { return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl() && !proxy_host_.empty()) { - bool error; - if (!connect(sock_, res, error)) { return error; } + { + std::lock_guard guard(cli_socks_mutex_); + cli_socks_.insert(sock); } -#endif - return process_and_close_socket( - sock_, 1, - [&](Stream &strm, bool last_connection, bool &connection_close) { + auto ret = process_and_close_socket( + sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { return handle_request(strm, req, res, last_connection, connection_close); }); + + { + std::lock_guard guard(cli_socks_mutex_); + cli_socks_.erase(sock); + } + + return ret; } inline bool Client::send(const std::vector &requests, std::vector &responses) { size_t i = 0; while (i < requests.size()) { - sock_ = create_client_socket(); - if (sock_ == INVALID_SOCKET) { return false; } + socket_t sock = INVALID_SOCKET; + if (!create_and_connect_socket(sock)) { return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl() && !proxy_host_.empty()) { - Response res; - bool error; - if (!connect(sock_, res, error)) { return false; } + { + std::lock_guard guard(cli_socks_mutex_); + cli_socks_.insert(sock); } -#endif - if (!process_and_close_socket(sock_, requests.size() - i, - [&](Stream &strm, bool last_connection, - bool &connection_close) -> bool { - auto &req = requests[i++]; - auto res = Response(); - auto ret = handle_request(strm, req, res, - last_connection, - connection_close); - if (ret) { - responses.emplace_back(std::move(res)); - } - return ret; - })) { - return false; + auto ret = process_and_close_socket( + sock, requests.size() - i, + [&](Stream &strm, bool last_connection, + bool &connection_close) -> bool { + auto &req = requests[i++]; + auto res = Response(); + auto ret = + handle_request(strm, req, res, last_connection, connection_close); + if (ret) { responses.emplace_back(std::move(res)); } + return ret; + }); + + { + std::lock_guard guard(cli_socks_mutex_); + cli_socks_.erase(sock); } + + if (!ret) { return false; } } return true; @@ -5062,11 +5085,12 @@ inline std::shared_ptr Client::Options(const char *path, } inline void Client::stop() { - if (sock_ != INVALID_SOCKET) { - std::atomic sock(sock_.exchange(INVALID_SOCKET)); + std::lock_guard guard(cli_socks_mutex_); + for (auto &sock : cli_socks_) { detail::shutdown_socket(sock); detail::close_socket(sock); } + cli_socks_.clear(); } inline void Client::set_timeout_sec(time_t timeout_sec) { diff --git a/test/test.cc b/test/test.cc index 22a540a..506ef06 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1766,14 +1766,19 @@ TEST_F(ServerTest, GetStreamedEndless) { } TEST_F(ServerTest, ClientStop) { - thread t = thread([&]() { - auto res = cli_.Get("/streamed-cancel", - [&](const char *, uint64_t) { return true; }); - ASSERT_TRUE(res == nullptr); - }); + std::vector threads; + for (auto i = 0; i < 10; i++) { + threads.emplace_back(thread([&]() { + auto res = cli_.Get("/streamed-cancel", + [&](const char *, uint64_t) { return true; }); + ASSERT_TRUE(res == nullptr); + })); + } std::this_thread::sleep_for(std::chrono::seconds(1)); cli_.stop(); - t.join(); + for (auto& t: threads) { + t.join(); + } } TEST_F(ServerTest, GetWithRange1) {