diff --git a/httplib.h b/httplib.h index 0aa1c55..16310e0 100644 --- a/httplib.h +++ b/httplib.h @@ -4722,7 +4722,8 @@ inline bool SocketStream::is_readable() const { } inline bool SocketStream::is_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); } inline ssize_t SocketStream::read(char *ptr, size_t size) { @@ -7345,8 +7346,8 @@ inline bool SSLSocketStream::is_readable() const { } inline bool SSLSocketStream::is_writable() const { - return detail::select_write(sock_, write_timeout_sec_, write_timeout_usec_) > - 0; + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { diff --git a/test/test.cc b/test/test.cc index aa7eeca..0f61ed3 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1643,7 +1643,8 @@ protected: res.set_content_provider( size_t(-1), "text/plain", [](size_t /*offset*/, size_t /*length*/, DataSink &sink) { - EXPECT_TRUE(sink.is_writable()); + if (!sink.is_writable()) return false; + sink.os << "data_chunk"; return true; }); @@ -5366,4 +5367,66 @@ TEST_F(UnixSocketTest, abstract) { t.join(); } #endif + +TEST(SocketStream, is_writable_UNIX) { + int fd[2]; + ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fd)); + + const auto asSocketStream = + [&] (socket_t fd, std::function func) { + return detail::process_client_socket(fd, 0, 0, 0, 0, func); + }; + asSocketStream(fd[0], [&] (Stream &s0) { + EXPECT_EQ(s0.socket(), fd[0]); + EXPECT_TRUE(s0.is_writable()); + + EXPECT_EQ(0, close(fd[1])); + EXPECT_FALSE(s0.is_writable()); + + return true; + }); + EXPECT_EQ(0, close(fd[0])); +} + +TEST(SocketStream, is_writable_INET) { + sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT+1); + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + int disconnected_svr_sock = -1; + std::thread svr {[&] { + const int s = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_LE(0, s); + ASSERT_EQ(0, ::bind(s, reinterpret_cast(&addr), sizeof(addr))); + ASSERT_EQ(0, listen(s, 1)); + ASSERT_LE(0, disconnected_svr_sock = accept(s, nullptr, nullptr)); + ASSERT_EQ(0, close(s)); + }}; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + std::thread cli {[&] { + const int s = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_LE(0, s); + ASSERT_EQ(0, connect(s, reinterpret_cast(&addr), sizeof(addr))); + ASSERT_EQ(0, close(s)); + }}; + cli.join(); + svr.join(); + ASSERT_NE(disconnected_svr_sock, -1); + + const auto asSocketStream = + [&] (socket_t fd, std::function func) { + return detail::process_client_socket(fd, 0, 0, 0, 0, func); + }; + asSocketStream(disconnected_svr_sock, [&] (Stream &ss) { + EXPECT_EQ(ss.socket(), disconnected_svr_sock); + EXPECT_FALSE(ss.is_writable()); + + return true; + }); + + ASSERT_EQ(0, close(disconnected_svr_sock)); +} #endif // #ifndef _WIN32