From 24bdb736f06589bfa61baa4eb834364559e80c2f Mon Sep 17 00:00:00 2001 From: yhirose Date: Tue, 9 Jun 2020 19:58:01 -0400 Subject: [PATCH] Fix #506 --- httplib.h | 59 ++++++++++++++++++++++++++++++++++++++-------------- test/test.cc | 5 +++-- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/httplib.h b/httplib.h index 9d561f8..33addc5 100644 --- a/httplib.h +++ b/httplib.h @@ -515,6 +515,26 @@ private: using Logger = std::function; +using SocketOptions = std::function; + +inline void default_socket_options(socket_t sock) { + int yes = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&yes), sizeof(yes)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); +#endif +#endif +} + class Server { public: using Handler = std::function; @@ -549,9 +569,10 @@ public: void set_file_request_handler(Handler handler); void set_error_handler(Handler handler); + void set_expect_100_continue_handler(Expect100ContinueHandler handler); void set_logger(Logger logger); - void set_expect_100_continue_handler(Expect100ContinueHandler handler); + void set_socket_options(SocketOptions socket_options); void set_keep_alive_max_count(size_t count); void set_read_timeout(time_t sec, time_t usec = 0); @@ -590,8 +611,8 @@ private: using HandlersForContentReader = std::vector>; - socket_t create_server_socket(const char *host, int port, - int socket_flags) const; + socket_t create_server_socket(const char *host, int port, int socket_flags, + SocketOptions socket_options) const; int bind_internal(const char *host, int port, int socket_flags); bool listen_internal(); @@ -639,6 +660,7 @@ private: Handler error_handler_; Logger logger_; Expect100ContinueHandler expect_100_continue_handler_; + SocketOptions socket_options_ = default_socket_options; }; class Client { @@ -1873,9 +1895,10 @@ inline int shutdown_socket(socket_t sock) { #endif } -template -socket_t create_socket(const char *host, int port, Fn fn, - int socket_flags = 0) { +template +socket_t create_socket(const char *host, int port, int socket_flags, + SocketOptions socket_options, + BindOrConnect bind_or_connect) { // Get address info struct addrinfo hints; struct addrinfo *result; @@ -1923,6 +1946,8 @@ socket_t create_socket(const char *host, int port, Fn fn, if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } #endif + if (socket_options) { socket_options(sock); } + // Make 'reuse address' option available int yes = 1; setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), @@ -1940,7 +1965,7 @@ socket_t create_socket(const char *host, int port, Fn fn, } // bind or connect - if (fn(sock, *rp)) { + if (bind_or_connect(sock, *rp)) { freeaddrinfo(result); return sock; } @@ -2017,10 +2042,12 @@ inline std::string if2ip(const std::string &ifn) { #endif inline socket_t create_client_socket(const char *host, int port, + SocketOptions socket_options, time_t timeout_sec, time_t timeout_usec, const std::string &intf) { return create_socket( - host, port, [&](socket_t sock, struct addrinfo &ai) -> bool { + host, port, 0, socket_options, + [&](socket_t sock, struct addrinfo &ai) -> bool { if (!intf.empty()) { #ifndef _WIN32 auto ip = if2ip(intf); @@ -3984,10 +4011,11 @@ inline bool Server::handle_file_request(Request &req, Response &res, return false; } -inline socket_t Server::create_server_socket(const char *host, int port, - int socket_flags) const { +inline socket_t +Server::create_server_socket(const char *host, int port, int socket_flags, + SocketOptions socket_options) const { return detail::create_socket( - host, port, + host, port, socket_flags, socket_options, [](socket_t sock, struct addrinfo &ai) -> bool { if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { return false; @@ -3996,14 +4024,13 @@ inline socket_t Server::create_server_socket(const char *host, int port, return false; } return true; - }, - socket_flags); + }); } inline int Server::bind_internal(const char *host, int port, int socket_flags) { if (!is_valid()) { return -1; } - svr_sock_ = create_server_socket(host, port, socket_flags); + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); if (svr_sock_ == INVALID_SOCKET) { return -1; } if (port == 0) { @@ -4293,10 +4320,10 @@ inline bool Client::is_valid() const { return true; } inline socket_t Client::create_client_socket() const { if (!proxy_host_.empty()) { return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, - connection_timeout_sec_, + nullptr, connection_timeout_sec_, connection_timeout_usec_, interface_); } - return detail::create_client_socket(host_.c_str(), port_, + return detail::create_client_socket(host_.c_str(), port_, nullptr, connection_timeout_sec_, connection_timeout_usec_, interface_); } diff --git a/test/test.cc b/test/test.cc index 2a2ce2a..22a540a 100644 --- a/test/test.cc +++ b/test/test.cc @@ -2294,8 +2294,9 @@ TEST_F(ServerTest, MultipartFormDataGzip) { // Sends a raw request to a server listening at HOST:PORT. static bool send_request(time_t read_timeout_sec, const std::string &req, std::string *resp = nullptr) { - auto client_sock = detail::create_client_socket(HOST, PORT, /*timeout_sec=*/5, 0, - std::string()); + auto client_sock = detail::create_client_socket( + HOST, PORT, nullptr, + /*timeout_sec=*/5, 0, std::string()); if (client_sock == INVALID_SOCKET) { return false; }