diff --git a/httplib.h b/httplib.h index d496d06..d060c92 100644 --- a/httplib.h +++ b/httplib.h @@ -1897,24 +1897,20 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider, }; data_sink.done = [&](void) { written_length = -1; }; - content_provider(offset, end_offset - offset, - // [&](const char *d, size_t l) { - // offset += l; - // written_length = strm.write(d, l); - // }, - // [&](void) { written_length = -1; } - data_sink); + content_provider(offset, end_offset - offset, data_sink); if (written_length < 0) { return written_length; } } return static_cast(offset - begin_offset); } +template inline ssize_t write_content_chunked(Stream &strm, - ContentProvider content_provider) { + ContentProvider content_provider, + T is_shutting_down) { size_t offset = 0; auto data_available = true; ssize_t total_written_length = 0; - while (data_available) { + while (data_available && !is_shutting_down()) { ssize_t written_length = 0; DataSink data_sink; @@ -1931,21 +1927,7 @@ inline ssize_t write_content_chunked(Stream &strm, written_length = strm.write("0\r\n\r\n"); }; - content_provider( - offset, 0, - // [&](const char *d, size_t l) { - // data_available = l > 0; - // offset += l; - // - // // Emit chunked response header and footer for each chunk - // auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + - // "\r\n"; written_length = strm.write(chunk); - // }, - // [&](void) { - // data_available = false; - // written_length = strm.write("0\r\n\r\n"); - // } - data_sink); + content_provider(offset, 0, data_sink); if (written_length < 0) { return written_length; } total_written_length += written_length; @@ -3088,7 +3070,8 @@ Server::write_content_with_provider(Stream &strm, const Request &req, } } } else { - if (detail::write_content_chunked(strm, res.content_provider) < 0) { + auto is_shutting_down = [this]() { return this->svr_sock_ == INVALID_SOCKET; }; + if (detail::write_content_chunked(strm, res.content_provider, is_shutting_down) < 0) { return false; } } diff --git a/test/test.cc b/test/test.cc index a82c31d..94dc6e2 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1967,6 +1967,46 @@ TEST(ServerRequestParsingTest, ReadHeadersRegexComplexity) { EXPECT_TRUE(listen_thread_ok); } +TEST(ServerStopTest, StopServerWithChunkedTransmission) { + Server svr; + + svr.Get("/events", [](const Request &req, Response &res) { + res.set_header("Content-Type", "text/event-stream"); + res.set_header("Cache-Control", "no-cache"); + res.set_chunked_content_provider([](size_t offset, const DataSink &sink) { + char buffer[27]; + int size = sprintf(buffer, "data:%ld\n\n", offset); + sink.write(buffer, size); + std::this_thread::sleep_for(std::chrono::seconds(1)); + }); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + Client client(HOST, PORT); + const Headers headers = {{"Accept", "text/event-stream"}, + {"Connection", "Keep-Alive"}}; + + auto get_thread = std::thread([&client, &headers]() { + std::shared_ptr res = + client.Get("/events", headers, + [](const char *data, size_t len) -> bool { return true; }); + }); + + // Give GET time to get a few messages. + std::this_thread::sleep_for(std::chrono::seconds(2)); + + svr.stop(); + + listen_thread.join(); + get_thread.join(); + + ASSERT_FALSE(svr.is_running()); +} + class ServerTestWithAI_PASSIVE : public ::testing::Test { protected: ServerTestWithAI_PASSIVE()