diff --git a/httplib.h b/httplib.h index 65de818..4607342 100644 --- a/httplib.h +++ b/httplib.h @@ -52,6 +52,7 @@ typedef int socket_t; #include #include #include +#include #include #include #include @@ -204,6 +205,8 @@ public: bool is_running() const; void stop(); + bool is_handling_requests() const; + protected: bool process_request(Stream& strm, bool last_connection); @@ -231,6 +234,10 @@ private: Handlers post_handlers_; Handler error_handler_; Logger logger_; + + // TODO: Use thread pool... + std::mutex running_threads_mutex_; + int running_threads_; }; class Client { @@ -1407,6 +1414,7 @@ inline std::string SocketStream::get_remote_addr() { inline Server::Server(HttpVersion http_version) : http_version_(http_version) , svr_sock_(-1) + , running_threads_(0) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); @@ -1476,6 +1484,11 @@ inline void Server::stop() svr_sock_ = -1; } +inline bool Server::is_handling_requests() const +{ + return running_threads_ > 0; +} + inline bool Server::parse_request_line(const char* s, Request& req) { static std::regex re("(GET|HEAD|POST) ([^?]+)(?:\\?(.+?))? (HTTP/1\\.[01])\r\n"); @@ -1632,10 +1645,29 @@ inline bool Server::listen_internal() // TODO: Use thread pool... std::thread([=]() { + { + std::lock_guard guard(running_threads_mutex_); + running_threads_++; + } + read_and_close_socket(sock); + + { + std::lock_guard guard(running_threads_mutex_); + running_threads_--; + } }).detach(); } + // TODO: Use thread pool... + for (;;) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::lock_guard guard(running_threads_mutex_); + if (!running_threads_) { + break; + } + } + return ret; } diff --git a/test/test.cc b/test/test.cc index de2ebf2..2e91ad7 100644 --- a/test/test.cc +++ b/test/test.cc @@ -262,6 +262,10 @@ protected: svr_.get("/hi", [&](const Request& /*req*/, Response& res) { res.set_content("Hello World!", "text/plain"); }) + .get("/slow", [&](const Request& /*req*/, Response& res) { + msleep(3000); + res.set_content("slow", "text/plain"); + }) .get("/remote_addr", [&](const Request& req, Response& res) { auto remote_addr = req.headers.find("REMOTE_ADDR")->second; res.set_content(remote_addr.c_str(), "text/plain"); @@ -358,6 +362,7 @@ protected: virtual void TearDown() { svr_.stop(); t_.join(); + EXPECT_EQ(false, svr_.is_handling_requests()); } map persons_; @@ -664,6 +669,14 @@ TEST_F(ServerTest, GetMethodRemoteAddr) EXPECT_TRUE(res->body == "::1" || res->body == "127.0.0.1"); } +TEST_F(ServerTest, SlowRequest) +{ + std::thread([=]() { auto res = cli_.get("/slow"); }).detach(); + std::thread([=]() { auto res = cli_.get("/slow"); }).detach(); + std::thread([=]() { auto res = cli_.get("/slow"); }).detach(); + msleep(1000); +} + #ifdef CPPHTTPLIB_ZLIB_SUPPORT TEST_F(ServerTest, Gzip) {