diff --git a/httplib.h b/httplib.h index 8b17b69..1a85adf 100644 --- a/httplib.h +++ b/httplib.h @@ -197,7 +197,7 @@ public: virtual std::string get_remote_addr() const = 0; template - void write_format(const char *fmt, const Args &... args); + int write_format(const char *fmt, const Args &... args); }; class SocketStream : public Stream { @@ -286,7 +286,7 @@ private: bool dispatch_request(Request &req, Response &res, Handlers &handlers); bool parse_request_line(const char *s, Request &req); - void write_response(Stream &strm, bool last_connection, const Request &req, + bool write_response(Stream &strm, bool last_connection, const Request &req, Response &res); virtual bool read_and_close_socket(socket_t sock); @@ -1228,18 +1228,29 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status, return ret; } -template inline void write_headers(Stream &strm, const T &info) { +template inline int write_headers(Stream &strm, const T &info) { + auto write_len = 0; for (const auto &x : info.headers) { - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; } - strm.write("\r\n"); + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } + write_len += len; + return write_len; } template -inline void write_content_chunked(Stream &strm, const T &x) { +inline int write_content_chunked(Stream &strm, const T &x) { auto chunked_response = !x.has_header("Content-Length"); uint64_t offset = 0; auto data_available = true; + auto write_len = 0; while (data_available) { auto chunk = x.content_producer(offset); offset += chunk.size(); @@ -1250,10 +1261,13 @@ inline void write_content_chunked(Stream &strm, const T &x) { chunk = from_i_to_hex(chunk.size()) + "\r\n" + chunk + "\r\n"; } - if (strm.write(chunk.c_str(), chunk.size()) < 0) { - break; // Stop on error + auto len = strm.write(chunk.c_str(), chunk.size()); + if (len < 0) { + return len; } + write_len += len; } + return write_len; } inline std::string encode_url(const std::string &s) { @@ -1560,7 +1574,7 @@ inline void Response::set_content(const std::string &s, // Rstream implementation template -inline void Stream::write_format(const char *fmt, const Args &... args) { +inline int Stream::write_format(const char *fmt, const Args &... args) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -1569,23 +1583,25 @@ inline void Stream::write_format(const char *fmt, const Args &... args) { #else auto n = snprintf(buf, bufsiz - 1, fmt, args...); #endif - if (n > 0) { - if (n >= bufsiz - 1) { - std::vector glowable_buf(bufsiz); + if (n <= 0) { + return n; + } - while (n >= static_cast(glowable_buf.size() - 1)) { - glowable_buf.resize(glowable_buf.size() * 2); + if (n >= bufsiz - 1) { + std::vector glowable_buf(bufsiz); + + while (n >= static_cast(glowable_buf.size() - 1)) { + glowable_buf.resize(glowable_buf.size() * 2); #if defined(_MSC_VER) && _MSC_VER < 1900 - n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), - glowable_buf.size() - 1, fmt, args...); + n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, args...); #else - n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); + n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); #endif - } - write(&glowable_buf[0], n); - } else { - write(buf, n); } + return write(&glowable_buf[0], n); + } else { + return write(buf, n); } } @@ -1745,15 +1761,17 @@ inline bool Server::parse_request_line(const char *s, Request &req) { return false; } -inline void Server::write_response(Stream &strm, bool last_connection, +inline bool Server::write_response(Stream &strm, bool last_connection, const Request &req, Response &res) { assert(res.status != -1); if (400 <= res.status && error_handler_) { error_handler_(req, res); } // Response line - strm.write_format("HTTP/1.1 %d %s\r\n", res.status, - detail::status_message(res.status)); + if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } // Headers if (last_connection || req.get_header_value("Connection") == "close") { @@ -1793,19 +1811,27 @@ inline void Server::write_response(Stream &strm, bool last_connection, res.set_header("Content-Length", length.c_str()); } - detail::write_headers(strm, res); + if (!detail::write_headers(strm, res)) { + return false; + } // Body if (req.method != "HEAD") { if (!res.body.empty()) { - strm.write(res.body.c_str(), res.body.size()); + if (!strm.write(res.body.c_str(), res.body.size())) { + return false; + } } else if (res.content_producer) { - detail::write_content_chunked(strm, res); + if (!detail::write_content_chunked(strm, res)) { + return false; + } } } // Log if (logger_) { logger_(req, res); } + + return true; } inline bool Server::handle_file_request(Request &req, Response &res) { @@ -1978,16 +2004,14 @@ Server::process_request(Stream &strm, bool last_connection, // Check if the request URI doesn't exceed the limit if (reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { res.status = 414; - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } // Request line and headers if (!parse_request_line(reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { res.status = 400; - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } if (req.get_header_value("Connection") == "close") { @@ -2001,8 +2025,7 @@ Server::process_request(Stream &strm, bool last_connection, if (!detail::read_content( strm, req, payload_max_length_, res.status, Progress(), [&](const char *buf, size_t n) { req.body.append(buf, n); })) { - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } const auto &content_type = req.get_header_value("Content-Type"); @@ -2014,8 +2037,7 @@ Server::process_request(Stream &strm, bool last_connection, if (!detail::parse_multipart_boundary(content_type, boundary) || !detail::parse_multipart_formdata(boundary, req.body, req.files)) { res.status = 400; - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } } } @@ -2029,8 +2051,7 @@ Server::process_request(Stream &strm, bool last_connection, res.status = 404; } - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } inline bool Server::is_valid() const { return true; }