diff --git a/httplib.h b/httplib.h index 50d64d4..953b482 100644 --- a/httplib.h +++ b/httplib.h @@ -32,6 +32,14 @@ #define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 #endif +#ifndef CPPHTTPLIB_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_WRITE_TIMEOUT_USECOND 0 +#endif + #ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 #endif @@ -509,6 +517,7 @@ public: void set_keep_alive_max_count(size_t count); void set_read_timeout(time_t sec, time_t usec); + void set_write_timeout(time_t sec, time_t usec); void set_payload_max_length(size_t length); bool bind_to_port(const char *host, int port, int socket_flags = 0); @@ -530,6 +539,8 @@ protected: size_t keep_alive_max_count_; time_t read_timeout_sec_; time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; size_t payload_max_length_; private: @@ -731,6 +742,8 @@ public: void set_read_timeout(time_t sec, time_t usec); + void set_write_timeout(time_t sec, time_t usec); + void set_keep_alive_max_count(size_t count); void set_basic_auth(const char *username, const char *password); @@ -772,6 +785,8 @@ protected: time_t timeout_sec_ = 300; time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; @@ -806,6 +821,8 @@ protected: timeout_sec_ = rhs.timeout_sec_; read_timeout_sec_ = rhs.read_timeout_sec_; read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; keep_alive_max_count_ = rhs.keep_alive_max_count_; basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; @@ -1347,8 +1364,8 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { class SocketStream : public Stream { public: - SocketStream(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec); + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec); ~SocketStream() override; bool is_readable() const override; @@ -1361,13 +1378,16 @@ private: socket_t sock_; time_t read_timeout_sec_; time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { public: SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec); + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); ~SSLSocketStream() override; bool is_readable() const override; @@ -1381,6 +1401,8 @@ private: SSL *ssl_; time_t read_timeout_sec_; time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; }; #endif @@ -1405,7 +1427,8 @@ private: template inline bool process_socket(bool is_client_request, socket_t sock, size_t keep_alive_max_count, time_t read_timeout_sec, - time_t read_timeout_usec, T callback) { + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { assert(keep_alive_max_count > 0); auto ret = false; @@ -1416,7 +1439,8 @@ inline bool process_socket(bool is_client_request, socket_t sock, (is_client_request || select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { - SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); auto last_connection = count == 1; auto connection_close = false; @@ -1426,7 +1450,8 @@ inline bool process_socket(bool is_client_request, socket_t sock, count--; } } else { // keep_alive_max_count is 0 or 1 - SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); auto dummy_connection_close = false; ret = callback(strm, true, dummy_connection_close); } @@ -1435,12 +1460,14 @@ inline bool process_socket(bool is_client_request, socket_t sock, } template -inline bool process_and_close_socket(bool is_client_request, socket_t sock, - size_t keep_alive_max_count, - time_t read_timeout_sec, - time_t read_timeout_usec, T callback) { +inline bool +process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { auto ret = process_socket(is_client_request, sock, keep_alive_max_count, - read_timeout_sec, read_timeout_usec, callback); + read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, callback); close_socket(sock); return ret; } @@ -3024,9 +3051,13 @@ namespace detail { // Socket stream implementation inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec) + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) : sock_(sock), read_timeout_sec_(read_timeout_sec), - read_timeout_usec_(read_timeout_usec) {} + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) {} inline SocketStream::~SocketStream() {} @@ -3035,7 +3066,7 @@ inline bool SocketStream::is_readable() const { } inline bool SocketStream::is_writable() const { - return select_write(sock_, 0, 0) > 0; + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; } inline ssize_t SocketStream::read(char *ptr, size_t size) { @@ -3101,6 +3132,8 @@ inline Server::Server() : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), + write_timeout_sec_(CPPHTTPLIB_WRITE_TIMEOUT_SECOND), + write_timeout_usec_(CPPHTTPLIB_WRITE_TIMEOUT_USECOND), payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), svr_sock_(INVALID_SOCKET) { #ifndef _WIN32 @@ -3223,6 +3256,11 @@ inline void Server::set_read_timeout(time_t sec, time_t usec) { read_timeout_usec_ = usec; } +inline void Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + inline void Server::set_payload_max_length(size_t length) { payload_max_length_ = length; } @@ -3828,6 +3866,7 @@ inline bool Server::is_valid() const { return true; } inline bool Server::process_and_close_socket(socket_t sock) { return detail::process_and_close_socket( false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [this](Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, last_connection, connection_close, nullptr); @@ -3993,6 +4032,7 @@ inline bool Client::connect(socket_t sock, Response &res, bool &error) { if (!detail::process_socket( true, sock, 1, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { Request req2; req2.method = "CONNECT"; @@ -4012,6 +4052,7 @@ inline bool Client::connect(socket_t sock, Response &res, bool &error) { Response res3; if (!detail::process_socket( true, sock, 1, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { Request req3; @@ -4299,9 +4340,9 @@ inline bool Client::process_and_close_socket( bool &connection_close)> callback) { request_count = (std::min)(request_count, keep_alive_max_count_); - return detail::process_and_close_socket(true, sock, request_count, - read_timeout_sec_, read_timeout_usec_, - callback); + return detail::process_and_close_socket( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, callback); } inline bool Client::is_ssl() const { return false; } @@ -4613,6 +4654,11 @@ inline void Client::set_read_timeout(time_t sec, time_t usec) { read_timeout_usec_ = usec; } +inline void Client::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + inline void Client::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; } @@ -4666,8 +4712,9 @@ namespace detail { template inline bool process_and_close_socket_ssl( bool is_client_request, socket_t sock, size_t keep_alive_max_count, - time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx, - std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup, T callback) { assert(keep_alive_max_count > 0); SSL *ssl = nullptr; @@ -4704,7 +4751,8 @@ inline bool process_and_close_socket_ssl( (is_client_request || select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); auto last_connection = count == 1; auto connection_close = false; @@ -4714,7 +4762,8 @@ inline bool process_and_close_socket_ssl( count--; } } else { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); auto dummy_connection_close = false; ret = callback(ssl, strm, true, dummy_connection_close); } @@ -4787,9 +4836,13 @@ private: // SSL socket stream implementation inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec) + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), - read_timeout_usec_(read_timeout_usec) {} + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) {} inline SSLSocketStream::~SSLSocketStream() {} @@ -4798,7 +4851,8 @@ inline bool SSLSocketStream::is_readable() const { } inline bool SSLSocketStream::is_writable() const { - return detail::select_write(sock_, 0, 0) > 0; + return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > + 0; } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { @@ -4898,7 +4952,8 @@ inline bool SSLServer::is_valid() const { return ctx_; } inline bool SSLServer::process_and_close_socket(socket_t sock) { return detail::process_and_close_socket_ssl( false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, - ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, + write_timeout_sec_, write_timeout_usec_, ctx_, ctx_mutex_, SSL_accept, + [](SSL * /*ssl*/) { return true; }, [this](SSL *ssl, Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, last_connection, connection_close, @@ -4989,7 +5044,7 @@ inline bool SSLClient::process_and_close_socket( return is_valid() && detail::process_and_close_socket_ssl( true, sock, request_count, read_timeout_sec_, read_timeout_usec_, - ctx_, ctx_mutex_, + write_timeout_sec_, write_timeout_usec_, ctx_, ctx_mutex_, [&](SSL *ssl) { if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) { SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); diff --git a/test/test.cc b/test/test.cc index e0464a3..a323f80 100644 --- a/test/test.cc +++ b/test/test.cc @@ -807,6 +807,11 @@ protected: std::this_thread::sleep_for(std::chrono::seconds(2)); res.set_content("slow", "text/plain"); }) + .Post("/slowpost", + [&](const Request & /*req*/, Response &res) { + std::this_thread::sleep_for(std::chrono::seconds(2)); + res.set_content("slow", "text/plain"); + }) .Get("/remote_addr", [&](const Request &req, Response &res) { auto remote_addr = req.headers.find("REMOTE_ADDR")->second; @@ -1885,6 +1890,28 @@ TEST_F(ServerTest, SlowRequest) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } +TEST_F(ServerTest, SlowPost) { + char buffer[64 * 1024]; + memset(buffer, 0x42, sizeof(buffer)); + auto res = cli_.Post( + "/slowpost", 64*1024*1024, + [&] (size_t /*offset*/, size_t /*length*/, DataSink & sink) { + sink.write(buffer, sizeof(buffer)); return true; + }, + "text/plain"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + cli_.set_write_timeout(0, 0); + res = cli_.Post( + "/slowpost", 64 * 1024 * 1024, + [&](size_t /*offset*/, size_t /*length*/, DataSink &sink) { + sink.write(buffer, sizeof(buffer)); + return true; + }, + "text/plain"); + ASSERT_FALSE(res != nullptr); +} + TEST_F(ServerTest, Put) { auto res = cli_.Put("/put", "PUT", "text/plain"); ASSERT_TRUE(res != nullptr); @@ -2253,7 +2280,7 @@ static bool send_request(time_t read_timeout_sec, const std::string &req, if (client_sock == INVALID_SOCKET) { return false; } return detail::process_and_close_socket( - true, client_sock, 1, read_timeout_sec, 0, + true, client_sock, 1, read_timeout_sec, 0, 0, 0, [&](Stream &strm, bool /*last_connection*/, bool & /*connection_close*/) -> bool { if (req.size() !=