diff --git a/httplib.h b/httplib.h index 723c562..303d63f 100644 --- a/httplib.h +++ b/httplib.h @@ -695,6 +695,8 @@ public: bool send(const std::vector &requests, std::vector &responses); + void stop(); + void set_timeout_sec(time_t timeout_sec); void set_read_timeout(time_t sec, time_t usec); @@ -727,6 +729,8 @@ protected: bool process_request(Stream &strm, const Request &req, Response &res, bool last_connection, bool &connection_close); + std::atomic sock_; + const std::string host_; const int port_; const std::string host_and_port_; @@ -3714,7 +3718,7 @@ inline bool Server::process_and_close_socket(socket_t sock) { inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) - : host_(host), port_(port), + : sock_(INVALID_SOCKET), host_(host), port_(port), host_and_port_(host_ + ":" + std::to_string(port_)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} @@ -3750,18 +3754,18 @@ inline bool Client::read_response_line(Stream &strm, Response &res) { } inline bool Client::send(const Request &req, Response &res) { - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } + sock_ = create_client_socket(); + if (sock_ == INVALID_SOCKET) { return false; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (is_ssl() && !proxy_host_.empty()) { bool error; - if (!connect(sock, res, error)) { return error; } + if (!connect(sock_, res, error)) { return error; } } #endif return process_and_close_socket( - sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { + sock_, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { return handle_request(strm, req, res, last_connection, connection_close); }); @@ -3771,18 +3775,18 @@ inline bool Client::send(const std::vector &requests, std::vector &responses) { size_t i = 0; while (i < requests.size()) { - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } + sock_ = create_client_socket(); + if (sock_ == INVALID_SOCKET) { return false; } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (is_ssl() && !proxy_host_.empty()) { Response res; bool error; - if (!connect(sock, res, error)) { return false; } + if (!connect(sock_, res, error)) { return false; } } #endif - if (!process_and_close_socket(sock, requests.size() - i, + if (!process_and_close_socket(sock_, requests.size() - i, [&](Stream &strm, bool last_connection, bool &connection_close) -> bool { auto &req = requests[i++]; @@ -4446,6 +4450,14 @@ inline std::shared_ptr Client::Options(const char *path, return send(req, *res) ? res : nullptr; } +inline void Client::stop() { + if (sock_ != INVALID_SOCKET) { + std::atomic sock(sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + inline void Client::set_timeout_sec(time_t timeout_sec) { timeout_sec_ = timeout_sec; } diff --git a/test/test.cc b/test/test.cc index b5f5cf5..9945153 100644 --- a/test/test.cc +++ b/test/test.cc @@ -803,7 +803,8 @@ protected: auto remote_addr = req.headers.find("REMOTE_ADDR")->second; EXPECT_TRUE(req.has_header("REMOTE_PORT")); EXPECT_EQ(req.remote_addr, req.get_header_value("REMOTE_ADDR")); - EXPECT_EQ(req.remote_port, std::stoi(req.get_header_value("REMOTE_PORT"))); + EXPECT_EQ(req.remote_port, + std::stoi(req.get_header_value("REMOTE_PORT"))); res.set_content(remote_addr.c_str(), "text/plain"); }) .Get("/endwith%", @@ -979,12 +980,12 @@ protected: res.set_content("empty-no-content-type", "text/plain"); }) .Put("/empty-no-content-type", - [&](const Request &req, Response &res) { - EXPECT_EQ(req.body, ""); - EXPECT_FALSE(req.has_header("Content-Type")); - EXPECT_EQ("0", req.get_header_value("Content-Length")); - res.set_content("empty-no-content-type", "text/plain"); - }) + [&](const Request &req, Response &res) { + EXPECT_EQ(req.body, ""); + EXPECT_FALSE(req.has_header("Content-Type")); + EXPECT_EQ("0", req.get_header_value("Content-Length")); + res.set_content("empty-no-content-type", "text/plain"); + }) .Put("/put", [&](const Request &req, Response &res) { EXPECT_EQ(req.body, "PUT"); @@ -1746,6 +1747,18 @@ TEST_F(ServerTest, GetStreamedEndless) { ASSERT_TRUE(res == nullptr); } +TEST_F(ServerTest, ClientStop) { + thread t = thread([&]() { + auto res = + cli_.Get("/streamed-cancel", + [&](const char *, uint64_t) { return true; }); + ASSERT_TRUE(res == nullptr); + }); + std::this_thread::sleep_for(std::chrono::seconds(1)); + cli_.stop(); + t.join(); +} + TEST_F(ServerTest, GetWithRange1) { auto res = cli_.Get("/with-range", {{make_range_header({{3, 5}})}}); ASSERT_TRUE(res != nullptr); @@ -2323,40 +2336,40 @@ TEST(ServerRequestParsingTest, ReadHeadersRegexComplexity2) { TEST(ServerRequestParsingTest, InvalidFirstChunkLengthInRequest) { std::string out; - test_raw_request( - "PUT /put_hi HTTP/1.1\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "nothex\r\n", &out); + test_raw_request("PUT /put_hi HTTP/1.1\r\n" + "Content-Type: text/plain\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "nothex\r\n", + &out); EXPECT_EQ("HTTP/1.1 400 Bad Request", out.substr(0, 24)); } TEST(ServerRequestParsingTest, InvalidSecondChunkLengthInRequest) { std::string out; - test_raw_request( - "PUT /put_hi HTTP/1.1\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "3\r\n" - "xyz\r\n" - "NaN\r\n", &out); + test_raw_request("PUT /put_hi HTTP/1.1\r\n" + "Content-Type: text/plain\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "3\r\n" + "xyz\r\n" + "NaN\r\n", + &out); EXPECT_EQ("HTTP/1.1 400 Bad Request", out.substr(0, 24)); } TEST(ServerRequestParsingTest, ChunkLengthTooHighInRequest) { std::string out; - test_raw_request( - "PUT /put_hi HTTP/1.1\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - // Length is too large for 64 bits. - "1ffffffffffffffff\r\n" - "xyz\r\n", &out); + test_raw_request("PUT /put_hi HTTP/1.1\r\n" + "Content-Type: text/plain\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + // Length is too large for 64 bits. + "1ffffffffffffffff\r\n" + "xyz\r\n", + &out); EXPECT_EQ("HTTP/1.1 400 Bad Request", out.substr(0, 24)); }