diff --git a/httplib.h b/httplib.h index 7fe65d0..b6f2a68 100644 --- a/httplib.h +++ b/httplib.h @@ -1025,7 +1025,7 @@ protected: private: socket_t create_client_socket(Error &error) const; - bool read_response_line(Stream &strm, Response &res); + bool read_response_line(Stream &strm, const Request &req, Response &res); bool write_request(Stream &strm, const Request &req, bool close_connection, Error &error); bool redirect(const Request &req, Response &res, Error &error); @@ -4947,17 +4947,20 @@ inline void ClientImpl::lock_socket_and_shutdown_and_close() { close_socket(socket_); } -inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, + Response &res) { std::array buf; detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); if (!line_reader.getline()) { return false; } - const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); + const static std::regex re("(HTTP/1\\.[01]) (\\d{3}) (.*?)\r\n"); std::cmatch m; - if (!std::regex_match(line_reader.ptr(), m, re)) { return true; } + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } res.version = std::string(m[1]); res.status = std::stoi(std::string(m[2])); res.reason = std::string(m[3]); @@ -5404,7 +5407,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, if (!write_request(strm, req, close_connection, error)) { return false; } // Receive response and headers - if (!read_response_line(strm, res) || + if (!read_response_line(strm, req, res) || !detail::read_headers(strm, res.headers)) { error = Error::Read; return false; @@ -5448,9 +5451,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, if (!detail::read_content(strm, res, (std::numeric_limits::max)(), dummy_status, std::move(progress), std::move(out), decompress_)) { - if (error != Error::Canceled) { - error = Error::Read; - } + if (error != Error::Canceled) { error = Error::Read; } return false; } } diff --git a/test/test.cc b/test/test.cc index cf9d86f..c638ba3 100644 --- a/test/test.cc +++ b/test/test.cc @@ -930,6 +930,31 @@ TEST(ErrorHandlerTest, ContentLength) { ASSERT_FALSE(svr.is_running()); } +TEST(InvalidFormatTest, StatusCode) { + Server svr; + + svr.Get("/hi", [](const Request & /*req*/, Response &res) { + res.set_content("Hello World!\n", "text/plain"); + res.status = 9999; // Status should be a three-digit code... + }); + + auto thread = std::thread([&]() { svr.listen(HOST, PORT); }); + + // Give GET time to get a few messages. + std::this_thread::sleep_for(std::chrono::seconds(1)); + + { + Client cli(HOST, PORT); + + auto res = cli.Get("/hi"); + ASSERT_FALSE(res); + } + + svr.stop(); + thread.join(); + ASSERT_FALSE(svr.is_running()); +} + class ServerTest : public ::testing::Test { protected: ServerTest()