From 6f663028e9293f2fd2f30b0b3c6c152ddecf0c76 Mon Sep 17 00:00:00 2001 From: yhirose Date: Wed, 17 Jul 2019 21:33:47 -0400 Subject: [PATCH] Fix #139. Content receiver support --- README.md | 9 +++ httplib.h | 162 +++++++++++++++++++++++++++++++++++++++++---------- test/test.cc | 65 ++++++++++++++++++++- 3 files changed, 204 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 4a213a2..3e47ab1 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,15 @@ int main(void) } ``` +### GET with Content Receiver + +```c++ + std::string body; + auto res = cli.Get("/large-data", [&](const char *data, size_t len) { + body.append(data, len); + }); +``` + ### POST ```c++ diff --git a/httplib.h b/httplib.h index 5361184..11b1018 100644 --- a/httplib.h +++ b/httplib.h @@ -124,6 +124,9 @@ std::pair make_range_header(uint64_t value, typedef std::multimap Params; typedef std::smatch Match; + +typedef std::function ContentProducer; +typedef std::function ContentReceiver; typedef std::function Progress; struct MultipartFile { @@ -145,8 +148,6 @@ struct Request { MultipartFiles files; Match matches; - Progress progress; - #ifdef CPPHTTPLIB_OPENSSL_SUPPORT const SSL *ssl; #endif @@ -169,7 +170,10 @@ struct Response { int status; Headers headers; std::string body; - std::function streamcb; + + ContentProducer content_producer; + ContentReceiver content_receiver; + Progress progress; bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; @@ -315,6 +319,13 @@ public: std::shared_ptr Get(const char *path, const Headers &headers, Progress progress = nullptr); + std::shared_ptr Get(const char *path, + ContentReceiver content_receiver, + Progress progress = nullptr); + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress = nullptr); + std::shared_ptr Head(const char *path); std::shared_ptr Head(const char *path, const Headers &headers); @@ -942,6 +953,63 @@ inline bool compress(std::string &content) { return true; } +class decompressor { +public: + decompressor() { + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 16 specifies + // that the stream to decompress will be formatted with a gzip wrapper. + is_valid_ = inflateInit2(&strm, 16 + 15) == Z_OK; + } + + ~decompressor() { inflateEnd(&strm); } + + bool is_valid() const { return is_valid_; } + + template + bool decompress(const char *data, size_t data_len, T callback) { + int ret = Z_OK; + std::string decompressed; + + // strm.avail_in = content.size(); + // strm.next_in = (Bytef *)content.data(); + strm.avail_in = data_len; + strm.next_in = (Bytef *)data; + + const auto bufsiz = 16384; + char buff[bufsiz]; + do { + strm.avail_out = bufsiz; + strm.next_out = (Bytef *)buff; + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + decompressed.append(buff, bufsiz - strm.avail_out); + } while (strm.avail_out == 0); + + if (ret == Z_STREAM_END) { + callback(decompressed.data(), decompressed.size()); + return true; + } + + return false; + } + +private: + bool is_valid_; + z_stream strm; +}; + inline bool decompress(std::string &content) { z_stream strm; strm.zalloc = Z_NULL; @@ -1112,26 +1180,40 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) { "chunked"); } -template +template bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status, - Progress progress) { + Progress progress, U callback) { -#ifndef CPPHTTPLIB_ZLIB_SUPPORT + ContentReceiver out = [&](const char *buf, size_t n) { callback(buf, n); }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + detail::decompressor decompressor; + + if (!decompressor.is_valid()) { + status = 500; + return false; + } + + if (x.get_header_value("Content-Encoding") == "gzip") { + out = [&](const char *buf, size_t n) { + decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { callback(buf, n); }); + }; + } +#else if (x.get_header_value("Content-Encoding") == "gzip") { status = 415; return false; } #endif - auto callback = [&](const char *buf, size_t n) { x.body.append(buf, n); }; - auto ret = true; auto exceed_payload_max_length = false; if (is_chunked_transfer_encoding(x.headers)) { - ret = read_content_chunked(strm, callback); + ret = read_content_chunked(strm, out); } else if (!has_header(x.headers, "Content-Length")) { - ret = read_content_without_length(strm, callback); + ret = read_content_without_length(strm, out); } else { auto len = get_header_value_uint64(x.headers, "Content-Length", 0); if (len > 0) { @@ -1143,23 +1225,12 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status, skip_content_with_length(strm, len); ret = false; } else { - // NOTE: We can remove it if it doesn't give us enough better - // performance. - x.body.reserve(len); - ret = read_content_with_length(strm, len, progress, callback); + ret = read_content_with_length(strm, len, progress, out); } } } - if (ret) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (x.get_header_value("Content-Encoding") == "gzip") { - ret = detail::decompress(x.body); - } -#endif - } else { - status = exceed_payload_max_length ? 413 : 400; - } + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } return ret; } @@ -1177,7 +1248,7 @@ inline void write_content_chunked(Stream &strm, const T &x) { uint64_t offset = 0; auto data_available = true; while (data_available) { - auto chunk = x.streamcb(offset); + auto chunk = x.content_producer(offset); offset += chunk.size(); data_available = !chunk.empty(); @@ -1696,7 +1767,7 @@ inline void Server::write_response(Stream &strm, bool last_connection, if (res.body.empty()) { if (!res.has_header("Content-Length")) { - if (res.streamcb) { + if (res.content_producer) { // Streamed response res.set_header("Transfer-Encoding", "chunked"); } else { @@ -1729,7 +1800,7 @@ inline void Server::write_response(Stream &strm, bool last_connection, if (req.method != "HEAD") { if (!res.body.empty()) { strm.write(res.body.c_str(), res.body.size()); - } else if (res.streamcb) { + } else if (res.content_producer) { detail::write_content_chunked(strm, res); } } @@ -1928,8 +1999,9 @@ Server::process_request(Stream &strm, bool last_connection, // Body if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - if (!detail::read_content(strm, req, payload_max_length_, res.status, - Progress())) { + if (!detail::read_content( + strm, req, payload_max_length_, res.status, Progress(), + [&](const char *buf, size_t n) { req.body.append(buf, n); })) { write_response(strm, last_connection, req, res); return true; } @@ -2107,9 +2179,17 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, // Body if (req.method != "HEAD") { + ContentReceiver out = [&](const char *buf, size_t n) { + res.body.append(buf, n); + }; + + if (res.content_receiver) { + out = [&](const char *buf, size_t n) { res.content_receiver(buf, n); }; + } + int dummy_status; if (!detail::read_content(strm, res, std::numeric_limits::max(), - dummy_status, req.progress)) { + dummy_status, res.progress, out)) { return false; } } @@ -2139,9 +2219,31 @@ Client::Get(const char *path, const Headers &headers, Progress progress) { req.method = "GET"; req.path = path; req.headers = headers; - req.progress = progress; auto res = std::make_shared(); + res->progress = progress; + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), content_receiver, progress); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared(); + res->content_receiver = content_receiver; + res->progress = progress; return send(req, *res) ? res : nullptr; } diff --git a/test/test.cc b/test/test.cc index 3077ce7..b278bca 100644 --- a/test/test.cc +++ b/test/test.cc @@ -142,6 +142,31 @@ TEST(ChunkedEncodingTest, FromHTTPWatch) { EXPECT_EQ(out, res->body); } +TEST(ChunkedEncodingTest, WithContentReceiver) { + auto host = "www.httpwatch.com"; + auto sec = 2; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + auto port = 443; + httplib::SSLClient cli(host, port, sec); +#else + auto port = 80; + httplib::Client cli(host, port, sec); +#endif + + std::string body; + auto res = + cli.Get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137", + [&](const char *data, size_t len) { body.append(data, len); }); + ASSERT_TRUE(res != nullptr); + + std::string out; + httplib::detail::read_file("./image.jpg", out); + + EXPECT_EQ(200, res->status); + EXPECT_EQ(out, body); +} + TEST(RangeTest, FromHTTPBin) { auto host = "httpbin.org"; auto sec = 5; @@ -380,7 +405,7 @@ protected: }) .Get("/streamedchunked", [&](const Request & /*req*/, Response &res) { - res.streamcb = [](uint64_t offset) { + res.content_producer = [](uint64_t offset) { if (offset < 3) return "a"; if (offset < 6) return "b"; return ""; @@ -389,7 +414,7 @@ protected: .Get("/streamed", [&](const Request & /*req*/, Response &res) { res.set_header("Content-Length", "6"); - res.streamcb = [](uint64_t offset) { + res.content_producer = [](uint64_t offset) { if (offset < 3) return "a"; if (offset < 6) return "b"; return ""; @@ -1146,6 +1171,24 @@ TEST_F(ServerTest, Gzip) { EXPECT_EQ(200, res->status); } +TEST_F(ServerTest, GzipWithContentReceiver) { + Headers headers; + headers.emplace("Accept-Encoding", "gzip, deflate"); + std::string body; + auto res = cli_.Get("/gzip", headers, [&](const char *data, size_t len) { + body.append(data, len); + }); + + ASSERT_TRUE(res != nullptr); + EXPECT_EQ("gzip", res->get_header_value("Content-Encoding")); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_EQ("33", res->get_header_value("Content-Length")); + EXPECT_EQ("123456789012345678901234567890123456789012345678901234567890123456" + "7890123456789012345678901234567890", + body); + EXPECT_EQ(200, res->status); +} + TEST_F(ServerTest, NoGzip) { Headers headers; headers.emplace("Accept-Encoding", "gzip, deflate"); @@ -1161,6 +1204,24 @@ TEST_F(ServerTest, NoGzip) { EXPECT_EQ(200, res->status); } +TEST_F(ServerTest, NoGzipWithContentReceiver) { + Headers headers; + headers.emplace("Accept-Encoding", "gzip, deflate"); + std::string body; + auto res = cli_.Get("/nogzip", headers, [&](const char *data, size_t len) { + body.append(data, len); + }); + + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(false, res->has_header("Content-Encoding")); + EXPECT_EQ("application/octet-stream", res->get_header_value("Content-Type")); + EXPECT_EQ("100", res->get_header_value("Content-Length")); + EXPECT_EQ("123456789012345678901234567890123456789012345678901234567890123456" + "7890123456789012345678901234567890", + body); + EXPECT_EQ(200, res->status); +} + TEST_F(ServerTest, MultipartFormDataGzip) { Request req; req.method = "POST";