This commit is contained in:
Leow Yong Zheng 2022-07-17 11:22:43 +08:00 committed by GitHub
commit 838ff597c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 496 additions and 34 deletions

50
example/WSdefs.h Normal file
View File

@ -0,0 +1,50 @@
#include <stdint.h>
#include <string>
namespace WSSPec {
struct WSFRAME {
bool FIN;
uint8_t RSV_FLAGS = 0;
uint8_t opcode;
bool masked;
uint64_t payload_len;
uint32_t masking_key = 0;
std::string payload = "";
};
namespace BYTE0_FLAGS {
enum MASKS : uint8_t {
FIN = 0b10000000,
RSV = 0b01110000,
OPCODE = 0b00001111
};
enum OPCODES : uint8_t {
CONTINUE = 0x00,
TEXT = 0x01,
BINARY = 0x02,
CLOSE = 0x08,
PING = 0x09,
PONG = 0x0A
};
}
namespace BYTE1_FLAGS {
enum MASKS : uint8_t {
IS_MASKED = 0b10000000,
PAYLOAD_LEN = 0b01111111
};
}
enum PAYLOAD_LEN_MODE : uint8_t {
NORMAL = 0,
EXT_16_BIT = 126,
EXT_64_BIT = 127
};
enum STATE {
IDLE,
READ_BYTE_0,
READ_BYTE_1,
READ_U16_LEN,
READ_U64_LEN,
READ_MASK,
READ_PAYLOAD,
PRINT_MESSAGE
};
}

View File

@ -0,0 +1,311 @@
#define CPPHTTPLIB_OPENSSL_SUPPORT
#include <httplib.h>
#include "WSdefs.h"
#include <thread>
/**
* NOT PRODUCTION READY!
* Use this code at your own risk! It "works", but may behave in unexpected ways
* if the server drops your connection without warning, you may need to terminate the program manually
* e.g. with ctrl-C
* TODO: handle unexpected disconnect
* TODO: find another websocket testing service
*/
std::mutex stream_send_mutex;
void sendTask(httplib::Stream &strm, const WSSPec::BYTE0_FLAGS::OPCODES frame_type, const std::string& payload = std::string(), bool is_mask = true);
void receiveThread(httplib::Stream &strm);
int main(int argc, char const *argv[])
{
httplib::CustomProtocolHandlers protc_handlers {
{
"websocket",
[](httplib::Stream &strm) {
std::cerr << "entered WebSocket handler" << std::endl;
bool stop = false;
auto terminal = [&strm, &stop]() {
for(std::string command = "p"; command != "q"; std::cin >> command) {
if(command == "p")
sendTask(strm, WSSPec::BYTE0_FLAGS::PING, "pingu");
else
sendTask(strm, WSSPec::BYTE0_FLAGS::TEXT, command);
}
stop = true;
sendTask(strm, WSSPec::BYTE0_FLAGS::CLOSE);
};
auto heartbeat = [&strm, &stop]() {
while(!stop) {
sendTask(strm, WSSPec::BYTE0_FLAGS::PING, "HB");
std::cerr << "sending HB" << std::endl;
std::this_thread::sleep_for(std::chrono::seconds(5));
}
};
std::thread receiver(receiveThread, std::ref(strm));
std::thread sender(terminal);
// std::thread hb(heartbeat);
receiver.join();
sender.join();
// hb.join();
return true;
}
}
};
httplib::Headers headers {
{ "Accept", "*/*" },
{ "Connection", "upgrade" },
{ "Upgrade", "websocket" },
{ "Sec-Fetch-Dest", "websocket" },
{ "Sec-Fetch-Mode", "websocket" },
{ "Sec-Fetch-Site", "same-origin" },
{ "Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==" },
{ "Sec-WebSocket-Version", "13" },
// { "Origin", "http://localhost:9090"}
{ "Origin", "https://websocketstest.com"}
// { "Origin", "https://echo.websocket.events" }
};
// httplib::Client c("localhost", 9090);
httplib::SSLClient c("websocketstest.com", 443);
// httplib::Client c("echo.websocket.events", 80);
std::cerr << "websocketstest.com accepts commands 'version,' 'echo,<message>' and 'timer,'" << std::endl;
std::cerr << "type a message, then hit enter to send" << std::endl;
std::cerr << "enter p to send ping and q to negotiate a disconnect" << std::endl;
auto res = c.Get(
// "/",
"/service",
headers,
protc_handlers
);
if(res) {
std::cerr << std::endl;
std::cerr << res->status << std::endl;
std::cerr << res->body << std::endl;
} else {
std::cerr << res.error() << std::endl;
}
return 0;
}
void sendTask(httplib::Stream &strm, const WSSPec::BYTE0_FLAGS::OPCODES frame_type, const std::string& payload, bool is_mask) {
std::lock_guard<std::mutex> stream_send_lock(stream_send_mutex);
uint64_t ext_payload_len = 0;
WSSPec::PAYLOAD_LEN_MODE payload_len_mode = WSSPec::NORMAL;
const uint8_t mask[] = {
0x47, 0x65, 0x33, 0xF3
};
// send frame headers
{
using namespace WSSPec::BYTE0_FLAGS;
uint8_t byte0 = FIN | frame_type;
strm.write(reinterpret_cast<const char*>(&byte0), sizeof(byte0));
}
{
using namespace WSSPec::BYTE1_FLAGS;
uint8_t byte1 = is_mask ? IS_MASKED : 0x0;
size_t real_payload_len = payload.length();
if (real_payload_len <= 125) {
byte1 |= real_payload_len & PAYLOAD_LEN;
} else if (real_payload_len < std::numeric_limits<uint16_t>::max()) {
byte1 |= WSSPec::EXT_16_BIT;
ext_payload_len = real_payload_len;
payload_len_mode = WSSPec::EXT_16_BIT;
} else {
byte1 |= WSSPec::EXT_64_BIT;
ext_payload_len = real_payload_len;
payload_len_mode = WSSPec::EXT_64_BIT;
}
strm.write(reinterpret_cast<const char*>(&byte1), sizeof(byte1));
}
switch(payload_len_mode) {
using namespace WSSPec;
case EXT_16_BIT:
strm.write(reinterpret_cast<const char *>(&ext_payload_len), sizeof(uint16_t));
break;
case EXT_64_BIT:
strm.write(reinterpret_cast<const char *>(&ext_payload_len), sizeof(uint64_t));
break;
case NORMAL:
default:
break;
}
if(is_mask)
strm.write(reinterpret_cast<const char *>(&mask), sizeof(mask));
for(size_t i = 0; i < payload.length(); i++) {
uint8_t byte = payload.at(i);
if(is_mask)
byte ^= mask[i % 4];
strm.write(reinterpret_cast<const char *>(&byte), sizeof(byte));
}
}
void receiveThread(httplib::Stream &strm) {
const size_t bufsize = 32;
char buffer[bufsize];
struct WSSPec::WSFRAME frame {};
do {
static bool waiting = false;
static enum WSSPec::STATE state = WSSPec::IDLE;
static size_t length_bytes_read = 0;
static size_t mask_bytes_read = 0;
static size_t payload_bytes_read = 0;
ssize_t readsize;
if(!httplib::detail::is_socket_alive(strm.socket())) {
std::cerr << "disconnected" << std::endl;
break;
}
if(!strm.is_readable()) {
std::this_thread::yield();
if(!waiting) {
waiting = true;
std::cerr << "websocket idle" << std::endl;
// sendTask(strm, WSSPec::BYTE0_FLAGS::PING, "pingu");
}
continue;
}
waiting = false;
{
std::lock_guard<std::mutex> stream_send_lock(stream_send_mutex);
readsize = strm.read(buffer, bufsize);
}
for (size_t i = 0; i < readsize; i++)
{
switch (state)
{
using namespace WSSPec;
case IDLE:
state = READ_BYTE_0;
case READ_BYTE_0: {
using namespace WSSPec::BYTE0_FLAGS;
frame.FIN = buffer[i] & FIN;
frame.RSV_FLAGS = buffer[i] & RSV;
frame.opcode = buffer[i] & OPCODE;
state = READ_BYTE_1;
break;
}
case READ_BYTE_1: {
using namespace WSSPec::BYTE1_FLAGS;
frame.masked = buffer[i] & IS_MASKED;
frame.payload_len = buffer[i] & PAYLOAD_LEN;
if (frame.payload_len == WSSPec::EXT_16_BIT) {
frame.payload_len = 0;
state = READ_U16_LEN;
}
else if (frame.payload_len == WSSPec::EXT_64_BIT) {
frame.payload_len = 0;
state = READ_U64_LEN;
}
else
state = frame.masked ? READ_MASK : READ_PAYLOAD;
break;
}
case READ_U16_LEN:
frame.payload_len = static_cast<uint8_t>(buffer[i]) + frame.payload_len << 8;
if (++length_bytes_read >= sizeof(uint16_t)) {
length_bytes_read = 0;
state = frame.masked ? READ_MASK : READ_PAYLOAD;
}
break;
case READ_U64_LEN:
frame.payload_len = static_cast<uint8_t>(buffer[i]) + frame.payload_len << 8;
if (++length_bytes_read >= sizeof(uint64_t)) {
length_bytes_read = 0;
state = frame.masked ? READ_MASK : READ_PAYLOAD;
}
break;
case READ_MASK:
if (mask_bytes_read == 0)
frame.masking_key = 0;
frame.masking_key = static_cast<uint8_t>(buffer[i]) + (frame.masking_key << 8);
if (++mask_bytes_read >= sizeof(frame.masking_key)) {
mask_bytes_read = 0;
state = READ_PAYLOAD;
}
break;
case READ_PAYLOAD:
if (payload_bytes_read == 0) {
std::cerr << "interpreted length: "
<< frame.payload_len << std::endl;
size_t capacity = frame.payload_len;
if (!frame.payload.empty())
capacity += frame.payload.capacity();
frame.payload.reserve(capacity);
std::cerr << "Received frame of type ";
switch(frame.opcode) {
using namespace WSSPec::BYTE0_FLAGS;
case PING:
std::cerr << "PING";
break;
case PONG:
std::cerr << "PONG";
break;
case TEXT:
std::cerr << "TEXT";
break;
case BINARY:
std::cerr << "BINARY";
break;
case CONTINUE:
std::cerr << "CONTINUE";
break;
case CLOSE:
std::cerr << "CLOSE";
break;
default:
std::cerr << std::hex << (int) frame.opcode << std::dec;
break;
}
std::cerr << std::endl;
}
if (frame.masked) {
uint8_t *mask = reinterpret_cast<uint8_t *>(&frame.masked);
buffer[i] = buffer[i] ^ mask[i % 4];
}
frame.payload.push_back(buffer[i]);
if (++payload_bytes_read >= frame.payload_len) {
payload_bytes_read = 0;
if (frame.FIN) {
state = PRINT_MESSAGE;
} else {
state = IDLE;
// break;
}
}
else break;
case PRINT_MESSAGE:
switch(frame.opcode) {
using namespace WSSPec::BYTE0_FLAGS;
case PING:
std::cerr << "received ping with payload "
<< frame.payload << ", will echo"
<< std::endl;
sendTask(strm, PONG, frame.payload);
break;
case PONG:
std::cerr << "received pong with payload "
<< frame.payload
<< std::endl;
break;
case TEXT:
std::cout << frame.payload << std::endl;
break;
}
frame.payload.clear();
state = IDLE;
break;
default:
break;
}
}
} while (frame.opcode != WSSPec::BYTE0_FLAGS::CLOSE);
std::cerr << "Quit" << std::endl;
}

169
httplib.h
View File

@ -376,6 +376,11 @@ using ContentReceiver =
using MultipartContentHeader =
std::function<bool(const MultipartFormData &file)>;
// forward declaration required to make typedefs work
class Stream;
using StreamManager = std::function<bool(Stream &strm)>;
using CustomProtocolHandlers =
std::multimap<std::string, StreamManager, detail::ci>;
class ContentReader {
public:
using Reader = std::function<bool(ContentReceiver receiver)>;
@ -423,6 +428,8 @@ struct Request {
ResponseHandler response_handler;
ContentReceiverWithProgress content_receiver;
Progress progress;
CustomProtocolHandlers alt_protocol_handlers;
std::string forced_alt_protocol;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
const SSL *ssl = nullptr;
#endif
@ -485,6 +492,9 @@ struct Response {
const char *content_type, ContentProviderWithoutLength provider,
ContentProviderResourceReleaser resource_releaser = nullptr);
void use_custom_protocol(
const char *protocol_name, StreamManager custom_protocol_handler);
Response() = default;
Response(const Response &) = default;
Response &operator=(const Response &) = default;
@ -500,6 +510,7 @@ struct Response {
size_t content_length_ = 0;
ContentProvider content_provider_;
ContentProviderResourceReleaser content_provider_resource_releaser_;
StreamManager custom_protocol_handler_;
bool is_chunked_content_provider_ = false;
bool content_provider_success_ = false;
};
@ -863,6 +874,11 @@ public:
Result Get(const char *path);
Result Get(const char *path, const Headers &headers);
Result Get(const char *path, const Headers &headers,
CustomProtocolHandlers &protocol_handlers, const char *force_protocol = "");
Result Get(const char *path, const Headers &headers,
ResponseHandler response_handler,
CustomProtocolHandlers &protocol_handlers, const char *force_protocol = "");
Result Get(const char *path, Progress progress);
Result Get(const char *path, const Headers &headers, Progress progress);
Result Get(const char *path, ContentReceiver content_receiver);
@ -1172,7 +1188,7 @@ private:
std::string adjust_host_string(const std::string &host) const;
virtual bool process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback);
StreamManager callback);
virtual bool is_ssl() const;
};
@ -1200,6 +1216,11 @@ public:
Result Get(const char *path);
Result Get(const char *path, const Headers &headers);
Result Get(const char *path, const Headers &headers,
CustomProtocolHandlers &protocol_handlers, const char *force_protocol = "");
Result Get(const char *path, const Headers &headers,
ResponseHandler response_handler,
CustomProtocolHandlers &protocol_handlers, const char *force_protocol = "");
Result Get(const char *path, Progress progress);
Result Get(const char *path, const Headers &headers, Progress progress);
Result Get(const char *path, ContentReceiver content_receiver);
@ -1443,7 +1464,7 @@ private:
void shutdown_ssl_impl(Socket &socket, bool shutdown_socket);
bool process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback) override;
StreamManager callback) override;
bool is_ssl() const override;
bool connect_with_proxy(Socket &sock, Response &res, bool &success,
@ -4583,6 +4604,24 @@ inline void Response::set_chunked_content_provider(
is_chunked_content_provider_ = true;
}
inline void Response::use_custom_protocol(
const char *protocol_name, StreamManager custom_protocol_handler) {
std::string connection_header = "Upgrade";
if (has_header("Connection")) {
std::string header_value = get_header_value("Connection");
if (header_value.find(connection_header) == std::string::npos) {
connection_header.append(", " + header_value);
} else {
connection_header = std::move(header_value);
}
}
if (status == -1) { status = 101; }
set_header("Connection", std::move(connection_header));
set_header("Upgrade", protocol_name);
custom_protocol_handler_ = custom_protocol_handler;
}
// Result implementation
inline bool Result::has_request_header(const char *key) const {
return request_headers_.find(key) != request_headers_.end();
@ -5100,6 +5139,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
res.content_provider_success_ = false;
ret = false;
}
} else if (res.custom_protocol_handler_) {
ret = res.custom_protocol_handler_(strm);
}
}
@ -6423,38 +6464,66 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req,
}
}
auto out =
req.content_receiver
? static_cast<ContentReceiverWithProgress>(
[&](const char *buf, size_t n, uint64_t off, uint64_t len) {
if (redirect) { return true; }
auto ret = req.content_receiver(buf, n, off, len);
if (!ret) { error = Error::Canceled; }
return ret;
})
: static_cast<ContentReceiverWithProgress>(
[&](const char *buf, size_t n, uint64_t /*off*/,
uint64_t /*len*/) {
if (res.body.size() + n > res.body.max_size()) {
return false;
}
res.body.append(buf, n);
return true;
});
if(!req.forced_alt_protocol.empty()
&& !redirect
&& req.alt_protocol_handlers.find(req.forced_alt_protocol) != req.alt_protocol_handlers.end()) {
if(!req.alt_protocol_handlers.find(req.forced_alt_protocol)->second(strm)) {
error = Error::Canceled;
return false;
}
}
else if(res.status == 101 && res.has_header("Upgrade")) {
std::stringstream parse_upgrade_header(res.get_header_value("Upgrade"));
bool protocol_negotiated = false;
while(parse_upgrade_header.good()) {
std::string protocol_name;
std::getline(parse_upgrade_header, protocol_name, ',');
if(req.alt_protocol_handlers.find(protocol_name) != req.alt_protocol_handlers.end()) {
protocol_negotiated = true;
if(!req.alt_protocol_handlers.find(protocol_name)->second(strm)) {
error = Error::Canceled;
return false;
}
}
}
if(!protocol_negotiated) {
error = Error::Canceled;
return false;
}
} else {
auto out =
req.content_receiver
? static_cast<ContentReceiverWithProgress>(
[&](const char *buf, size_t n, uint64_t off, uint64_t len) {
if (redirect) { return true; }
auto ret = req.content_receiver(buf, n, off, len);
if (!ret) { error = Error::Canceled; }
return ret;
})
: static_cast<ContentReceiverWithProgress>(
[&](const char *buf, size_t n, uint64_t /*off*/,
uint64_t /*len*/) {
if (res.body.size() + n > res.body.max_size()) {
return false;
}
res.body.append(buf, n);
return true;
});
auto progress = [&](uint64_t current, uint64_t total) {
if (!req.progress || redirect) { return true; }
auto ret = req.progress(current, total);
if (!ret) { error = Error::Canceled; }
return ret;
};
auto progress = [&](uint64_t current, uint64_t total) {
if (!req.progress || redirect) { return true; }
auto ret = req.progress(current, total);
if (!ret) { error = Error::Canceled; }
return ret;
};
int dummy_status;
if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(),
dummy_status, std::move(progress), std::move(out),
decompress_)) {
if (error != Error::Canceled) { error = Error::Read; }
return false;
int dummy_status;
if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(),
dummy_status, std::move(progress), std::move(out),
decompress_)) {
if (error != Error::Canceled) { error = Error::Read; }
return false;
}
}
}
@ -6484,7 +6553,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req,
inline bool
ClientImpl::process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback) {
StreamManager callback) {
return detail::process_client_socket(
socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_, std::move(callback));
@ -6504,6 +6573,27 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers) {
return Get(path, headers, Progress());
}
inline Result ClientImpl::Get(const char *path, const Headers &headers,
CustomProtocolHandlers &protocol_handlers,
const char *force_protocol) {
return Get(path, headers, nullptr, protocol_handlers, force_protocol);
};
inline Result ClientImpl::Get(const char *path, const Headers &headers,
ResponseHandler response_handler,
CustomProtocolHandlers &protocol_handlers,
const char *force_protocol) {
Request req;
req.method = "GET";
req.path = path;
req.headers = headers;
req.alt_protocol_handlers = protocol_handlers;
req.forced_alt_protocol = force_protocol;
req.response_handler = std::move(response_handler);
return send_(std::move(req));
}
inline Result ClientImpl::Get(const char *path, const Headers &headers,
Progress progress) {
Request req;
@ -7650,7 +7740,7 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket,
inline bool
SSLClient::process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback) {
StreamManager callback) {
assert(socket.ssl);
return detail::process_client_socket_ssl(
socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_,
@ -7854,6 +7944,17 @@ inline Result Client::Get(const char *path) { return cli_->Get(path); }
inline Result Client::Get(const char *path, const Headers &headers) {
return cli_->Get(path, headers);
}
inline Result Client::Get(const char *path, const Headers &headers,
CustomProtocolHandlers &protocol_handlers,
const char *force_protocol) {
return cli_->Get(path, headers, protocol_handlers, force_protocol);
}
inline Result Client::Get(const char *path, const Headers &headers,
ResponseHandler response_handler,
CustomProtocolHandlers &protocol_handlers,
const char *force_protocol) {
return cli_->Get(path, headers, std::move(response_handler), protocol_handlers, force_protocol);
}
inline Result Client::Get(const char *path, Progress progress) {
return cli_->Get(path, std::move(progress));
}