From 8f32271e8cd7fc294721e2166ab334d5b2c0c9b3 Mon Sep 17 00:00:00 2001 From: Ingo Bauersachs Date: Tue, 6 Dec 2022 14:23:09 +0100 Subject: [PATCH] Support LOCAL_ADDR and LOCAL_PORT header in client Request (#1450) Having the local address/port is useful if the server is bound to all interfaces, e.g. to serve different content for developers on localhost only. --- httplib.h | 44 ++++++++++++++++++++++++++++++----- test/fuzzing/server_fuzzer.cc | 7 ++++-- test/test.cc | 20 ++++++++++++++++ 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/httplib.h b/httplib.h index c32e2de..c7bf810 100644 --- a/httplib.h +++ b/httplib.h @@ -413,6 +413,8 @@ struct Request { std::string remote_addr; int remote_port = -1; + std::string local_addr; + int local_port = -1; // for server std::string version; @@ -514,6 +516,7 @@ public: virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; virtual socket_t socket() const = 0; template @@ -1778,6 +1781,7 @@ public: ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; const std::string &get_buffer() const; @@ -2446,6 +2450,7 @@ public: ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; private: @@ -2475,6 +2480,7 @@ public: ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; private: @@ -2843,9 +2849,9 @@ inline socket_t create_client_socket( return sock; } -inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr, - socklen_t addr_len, std::string &ip, - int &port) { +inline bool get_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, + int &port) { if (addr.ss_family == AF_INET) { port = ntohs(reinterpret_cast(&addr)->sin_port); } else if (addr.ss_family == AF_INET6) { @@ -2866,6 +2872,15 @@ inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr, return true; } +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), + &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); @@ -2890,7 +2905,7 @@ inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { return; } #endif - get_remote_ip_and_port(addr, addr_len, ip, port); + get_ip_and_port(addr, addr_len, ip, port); } } @@ -4517,8 +4532,8 @@ inline void hosted_at(const std::string &hostname, *reinterpret_cast(rp->ai_addr); std::string ip; int dummy = -1; - if (detail::get_remote_ip_and_port(addr, sizeof(struct sockaddr_storage), - ip, dummy)) { + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), + ip, dummy)) { addrs.push_back(ip); } } @@ -4808,6 +4823,11 @@ inline void SocketStream::get_remote_ip_and_port(std::string &ip, return detail::get_remote_ip_and_port(sock_, ip, port); } +inline void SocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + inline socket_t SocketStream::socket() const { return sock_; } // Buffer stream implementation @@ -4833,6 +4853,9 @@ inline ssize_t BufferStream::write(const char *ptr, size_t size) { inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, int & /*port*/) const {} +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + inline socket_t BufferStream::socket() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } @@ -5812,6 +5835,10 @@ Server::process_request(Stream &strm, bool close_connection, req.set_header("REMOTE_ADDR", req.remote_addr); req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + strm.get_local_ip_and_port(req.local_addr, req.local_port); + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); if (!detail::parse_range_header(range_header_value, req.ranges)) { @@ -7409,6 +7436,11 @@ inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, detail::get_remote_ip_and_port(sock_, ip, port); } +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + inline socket_t SSLSocketStream::socket() const { return sock_; } static SSLInit sslinit_; diff --git a/test/fuzzing/server_fuzzer.cc b/test/fuzzing/server_fuzzer.cc index 9fb4d4b..41f710f 100644 --- a/test/fuzzing/server_fuzzer.cc +++ b/test/fuzzing/server_fuzzer.cc @@ -22,8 +22,6 @@ public: ssize_t write(const std::string &s) { return write(s.data(), s.size()); } - std::string get_remote_addr() const { return ""; } - bool is_readable() const override { return true; } bool is_writable() const override { return true; } @@ -33,6 +31,11 @@ public: port = 8080; } + void get_local_ip_and_port(std::string &ip, int &port) const override { + ip = "127.0.0.1"; + port = 8080; + } + socket_t socket() const override { return 0; } private: diff --git a/test/test.cc b/test/test.cc index 9e66561..10bec8d 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1521,6 +1521,17 @@ protected: std::stoi(req.get_header_value("REMOTE_PORT"))); res.set_content(remote_addr.c_str(), "text/plain"); }) + .Get("/local_addr", + [&](const Request &req, Response &res) { + EXPECT_TRUE(req.has_header("LOCAL_PORT")); + EXPECT_TRUE(req.has_header("LOCAL_ADDR")); + auto local_addr = req.get_header_value("LOCAL_ADDR"); + auto local_port = req.get_header_value("LOCAL_PORT"); + EXPECT_EQ(req.local_addr, local_addr); + EXPECT_EQ(req.local_port, std::stoi(local_port)); + res.set_content(local_addr.append(":").append(local_port), + "text/plain"); + }) .Get("/endwith%", [&](const Request & /*req*/, Response &res) { res.set_content("Hello World!", "text/plain"); @@ -2810,6 +2821,15 @@ TEST_F(ServerTest, GetMethodRemoteAddr) { EXPECT_TRUE(res->body == "::1" || res->body == "127.0.0.1"); } +TEST_F(ServerTest, GetMethodLocalAddr) { + auto res = cli_.Get("/local_addr"); + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_TRUE(res->body == std::string("::1:").append(to_string(PORT)) || + res->body == std::string("127.0.0.1:").append(to_string(PORT))); +} + TEST_F(ServerTest, HTTPResponseSplitting) { auto res = cli_.Get("/http_response_splitting"); ASSERT_TRUE(res);