#include "httplib.h"
namespace httplib {
+// httplib::any — type-erased value container (C++11 compatible)
+// On C++17+ builds, thin wrappers around std::any are provided.
/*
* Implementation that will be part of the .cc file if split into .h + .cc.
return 0;
}
+} // namespace detail
+
+namespace ws {
+namespace impl {
+
+bool is_valid_utf8(const std::string &s) {
+ size_t i = 0;
+ auto n = s.size();
+ while (i < n) {
+ auto c = static_cast<unsigned char>(s[i]);
+ size_t len;
+ uint32_t cp;
+ if (c < 0x80) {
+ i++;
+ continue;
+ } else if ((c & 0xE0) == 0xC0) {
+ len = 2;
+ cp = c & 0x1F;
+ } else if ((c & 0xF0) == 0xE0) {
+ len = 3;
+ cp = c & 0x0F;
+ } else if ((c & 0xF8) == 0xF0) {
+ len = 4;
+ cp = c & 0x07;
+ } else {
+ return false;
+ }
+ if (i + len > n) { return false; }
+ for (size_t j = 1; j < len; j++) {
+ auto b = static_cast<unsigned char>(s[i + j]);
+ if ((b & 0xC0) != 0x80) { return false; }
+ cp = (cp << 6) | (b & 0x3F);
+ }
+ // Overlong encoding check
+ if (len == 2 && cp < 0x80) { return false; }
+ if (len == 3 && cp < 0x800) { return false; }
+ if (len == 4 && cp < 0x10000) { return false; }
+ // Surrogate halves (U+D800..U+DFFF) and beyond U+10FFFF are invalid
+ if (cp >= 0xD800 && cp <= 0xDFFF) { return false; }
+ if (cp > 0x10FFFF) { return false; }
+ i += len;
+ }
+ return true;
+}
+
+} // namespace impl
+} // namespace ws
+
+namespace detail {
+
// NOTE: This code came up with the following stackoverflow post:
// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c
std::string base64_encode(const std::string &in) {
return out;
}
+std::string sha1(const std::string &input) {
+ // RFC 3174 SHA-1 implementation
+ auto left_rotate = [](uint32_t x, uint32_t n) -> uint32_t {
+ return (x << n) | (x >> (32 - n));
+ };
+
+ uint32_t h0 = 0x67452301;
+ uint32_t h1 = 0xEFCDAB89;
+ uint32_t h2 = 0x98BADCFE;
+ uint32_t h3 = 0x10325476;
+ uint32_t h4 = 0xC3D2E1F0;
+
+ // Pre-processing: adding padding bits
+ std::string msg = input;
+ uint64_t original_bit_len = static_cast<uint64_t>(msg.size()) * 8;
+ msg.push_back(static_cast<char>(0x80));
+ while (msg.size() % 64 != 56) {
+ msg.push_back(0);
+ }
+
+ // Append original length in bits as 64-bit big-endian
+ for (int i = 56; i >= 0; i -= 8) {
+ msg.push_back(static_cast<char>((original_bit_len >> i) & 0xFF));
+ }
+
+ // Process each 512-bit chunk
+ for (size_t offset = 0; offset < msg.size(); offset += 64) {
+ uint32_t w[80];
+
+ for (size_t i = 0; i < 16; i++) {
+ w[i] =
+ (static_cast<uint32_t>(static_cast<uint8_t>(msg[offset + i * 4]))
+ << 24) |
+ (static_cast<uint32_t>(static_cast<uint8_t>(msg[offset + i * 4 + 1]))
+ << 16) |
+ (static_cast<uint32_t>(static_cast<uint8_t>(msg[offset + i * 4 + 2]))
+ << 8) |
+ (static_cast<uint32_t>(
+ static_cast<uint8_t>(msg[offset + i * 4 + 3])));
+ }
+
+ for (int i = 16; i < 80; i++) {
+ w[i] = left_rotate(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1);
+ }
+
+ uint32_t a = h0, b = h1, c = h2, d = h3, e = h4;
+
+ for (int i = 0; i < 80; i++) {
+ uint32_t f, k;
+ if (i < 20) {
+ f = (b & c) | ((~b) & d);
+ k = 0x5A827999;
+ } else if (i < 40) {
+ f = b ^ c ^ d;
+ k = 0x6ED9EBA1;
+ } else if (i < 60) {
+ f = (b & c) | (b & d) | (c & d);
+ k = 0x8F1BBCDC;
+ } else {
+ f = b ^ c ^ d;
+ k = 0xCA62C1D6;
+ }
+
+ uint32_t temp = left_rotate(a, 5) + f + e + k + w[i];
+ e = d;
+ d = c;
+ c = left_rotate(b, 30);
+ b = a;
+ a = temp;
+ }
+
+ h0 += a;
+ h1 += b;
+ h2 += c;
+ h3 += d;
+ h4 += e;
+ }
+
+ // Produce the final hash as a 20-byte binary string
+ std::string hash(20, '\0');
+ for (size_t i = 0; i < 4; i++) {
+ hash[i] = static_cast<char>((h0 >> (24 - i * 8)) & 0xFF);
+ hash[4 + i] = static_cast<char>((h1 >> (24 - i * 8)) & 0xFF);
+ hash[8 + i] = static_cast<char>((h2 >> (24 - i * 8)) & 0xFF);
+ hash[12 + i] = static_cast<char>((h3 >> (24 - i * 8)) & 0xFF);
+ hash[16 + i] = static_cast<char>((h4 >> (24 - i * 8)) & 0xFF);
+ }
+ return hash;
+}
+
+std::string websocket_accept_key(const std::string &client_key) {
+ const std::string magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+ return base64_encode(sha1(client_key + magic));
+}
+
+bool is_websocket_upgrade(const Request &req) {
+ if (req.method != "GET") { return false; }
+
+ // Check Upgrade: websocket (case-insensitive)
+ auto upgrade_it = req.headers.find("Upgrade");
+ if (upgrade_it == req.headers.end()) { return false; }
+ auto upgrade_val = upgrade_it->second;
+ std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(),
+ ::tolower);
+ if (upgrade_val != "websocket") { return false; }
+
+ // Check Connection header contains "Upgrade"
+ auto connection_it = req.headers.find("Connection");
+ if (connection_it == req.headers.end()) { return false; }
+ auto connection_val = connection_it->second;
+ std::transform(connection_val.begin(), connection_val.end(),
+ connection_val.begin(), ::tolower);
+ if (connection_val.find("upgrade") == std::string::npos) { return false; }
+
+ // Check Sec-WebSocket-Key is a valid base64-encoded 16-byte value (24 chars)
+ // RFC 6455 Section 4.2.1
+ auto ws_key = req.get_header_value("Sec-WebSocket-Key");
+ if (ws_key.size() != 24 || ws_key[22] != '=' || ws_key[23] != '=') {
+ return false;
+ }
+ static const std::string b64chars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+ for (size_t i = 0; i < 22; i++) {
+ if (b64chars.find(ws_key[i]) == std::string::npos) { return false; }
+ }
+
+ // Check Sec-WebSocket-Version: 13
+ auto version = req.get_header_value("Sec-WebSocket-Version");
+ if (version != "13") { return false; }
+
+ return true;
+}
+
+bool write_websocket_frame(Stream &strm, ws::Opcode opcode,
+ const char *data, size_t len, bool fin,
+ bool mask) {
+ // First byte: FIN + opcode
+ uint8_t header[2];
+ header[0] = static_cast<uint8_t>((fin ? 0x80 : 0x00) |
+ (static_cast<uint8_t>(opcode) & 0x0F));
+
+ // Second byte: MASK + payload length
+ if (len < 126) {
+ header[1] = static_cast<uint8_t>(len);
+ if (mask) { header[1] |= 0x80; }
+ if (strm.write(reinterpret_cast<char *>(header), 2) < 0) { return false; }
+ } else if (len <= 0xFFFF) {
+ header[1] = 126;
+ if (mask) { header[1] |= 0x80; }
+ if (strm.write(reinterpret_cast<char *>(header), 2) < 0) { return false; }
+ uint8_t ext[2];
+ ext[0] = static_cast<uint8_t>((len >> 8) & 0xFF);
+ ext[1] = static_cast<uint8_t>(len & 0xFF);
+ if (strm.write(reinterpret_cast<char *>(ext), 2) < 0) { return false; }
+ } else {
+ header[1] = 127;
+ if (mask) { header[1] |= 0x80; }
+ if (strm.write(reinterpret_cast<char *>(header), 2) < 0) { return false; }
+ uint8_t ext[8];
+ for (int i = 7; i >= 0; i--) {
+ ext[7 - i] = static_cast<uint8_t>((len >> (i * 8)) & 0xFF);
+ }
+ if (strm.write(reinterpret_cast<char *>(ext), 8) < 0) { return false; }
+ }
+
+ if (mask) {
+ // Generate random mask key
+ thread_local std::mt19937 rng(std::random_device{}());
+ uint8_t mask_key[4];
+ auto r = rng();
+ std::memcpy(mask_key, &r, 4);
+ if (strm.write(reinterpret_cast<char *>(mask_key), 4) < 0) { return false; }
+
+ // Write masked payload in chunks
+ const size_t chunk_size = 4096;
+ std::vector<char> buf((std::min)(len, chunk_size));
+ for (size_t offset = 0; offset < len; offset += chunk_size) {
+ size_t n = (std::min)(chunk_size, len - offset);
+ for (size_t i = 0; i < n; i++) {
+ buf[i] =
+ data[offset + i] ^ static_cast<char>(mask_key[(offset + i) % 4]);
+ }
+ if (strm.write(buf.data(), n) < 0) { return false; }
+ }
+ } else {
+ if (len > 0) {
+ if (strm.write(data, len) < 0) { return false; }
+ }
+ }
+
+ return true;
+}
+
+} // namespace detail
+
+namespace ws {
+namespace impl {
+
+bool read_websocket_frame(Stream &strm, Opcode &opcode,
+ std::string &payload, bool &fin,
+ bool expect_masked, size_t max_len) {
+ // Read first 2 bytes
+ uint8_t header[2];
+ if (strm.read(reinterpret_cast<char *>(header), 2) != 2) { return false; }
+
+ fin = (header[0] & 0x80) != 0;
+
+ // RSV1, RSV2, RSV3 must be 0 when no extension is negotiated
+ if (header[0] & 0x70) { return false; }
+
+ opcode = static_cast<Opcode>(header[0] & 0x0F);
+ bool masked = (header[1] & 0x80) != 0;
+ uint64_t payload_len = header[1] & 0x7F;
+
+ // RFC 6455 Section 5.5: control frames MUST NOT be fragmented and
+ // MUST have a payload length of 125 bytes or less
+ bool is_control = (static_cast<uint8_t>(opcode) & 0x08) != 0;
+ if (is_control) {
+ if (!fin) { return false; }
+ if (payload_len > 125) { return false; }
+ }
+
+ if (masked != expect_masked) { return false; }
+
+ // Extended payload length
+ if (payload_len == 126) {
+ uint8_t ext[2];
+ if (strm.read(reinterpret_cast<char *>(ext), 2) != 2) { return false; }
+ payload_len = (static_cast<uint64_t>(ext[0]) << 8) | ext[1];
+ } else if (payload_len == 127) {
+ uint8_t ext[8];
+ if (strm.read(reinterpret_cast<char *>(ext), 8) != 8) { return false; }
+ // RFC 6455 Section 5.2: the most significant bit MUST be 0
+ if (ext[0] & 0x80) { return false; }
+ payload_len = 0;
+ for (int i = 0; i < 8; i++) {
+ payload_len = (payload_len << 8) | ext[i];
+ }
+ }
+
+ if (payload_len > max_len) { return false; }
+
+ // Read mask key if present
+ uint8_t mask_key[4] = {0};
+ if (masked) {
+ if (strm.read(reinterpret_cast<char *>(mask_key), 4) != 4) { return false; }
+ }
+
+ // Read payload
+ payload.resize(static_cast<size_t>(payload_len));
+ if (payload_len > 0) {
+ size_t total_read = 0;
+ while (total_read < payload_len) {
+ auto n = strm.read(&payload[total_read],
+ static_cast<size_t>(payload_len - total_read));
+ if (n <= 0) { return false; }
+ total_read += static_cast<size_t>(n);
+ }
+ }
+
+ // Unmask if needed
+ if (masked) {
+ for (size_t i = 0; i < payload.size(); i++) {
+ payload[i] ^= static_cast<char>(mask_key[i % 4]);
+ }
+ }
+
+ return true;
+}
+
+} // namespace impl
+} // namespace ws
+
+namespace detail {
+
bool is_valid_path(const std::string &path) {
size_t level = 0;
size_t i = 0;
void get_local_ip_and_port(std::string &ip, int &port) const override;
socket_t socket() const override;
time_t duration() const override;
+ void set_read_timeout(time_t sec, time_t usec = 0) override;
private:
socket_t sock_;
return true;
}
+bool read_websocket_upgrade_response(Stream &strm,
+ const std::string &expected_accept,
+ std::string &selected_subprotocol) {
+ // Read status line
+ const auto bufsiz = 2048;
+ char buf[bufsiz];
+ stream_line_reader line_reader(strm, buf, bufsiz);
+ if (!line_reader.getline()) { return false; }
+
+ // Check for "HTTP/1.1 101"
+ auto line = std::string(line_reader.ptr(), line_reader.size());
+ if (line.find("HTTP/1.1 101") == std::string::npos) { return false; }
+
+ // Parse headers using existing read_headers
+ Headers headers;
+ if (!read_headers(strm, headers)) { return false; }
+
+ // Verify Upgrade: websocket (case-insensitive)
+ auto upgrade_it = headers.find("Upgrade");
+ if (upgrade_it == headers.end()) { return false; }
+ auto upgrade_val = upgrade_it->second;
+ std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(),
+ ::tolower);
+ if (upgrade_val != "websocket") { return false; }
+
+ // Verify Connection header contains "Upgrade" (case-insensitive)
+ auto connection_it = headers.find("Connection");
+ if (connection_it == headers.end()) { return false; }
+ auto connection_val = connection_it->second;
+ std::transform(connection_val.begin(), connection_val.end(),
+ connection_val.begin(), ::tolower);
+ if (connection_val.find("upgrade") == std::string::npos) { return false; }
+
+ // Verify Sec-WebSocket-Accept header value
+ auto it = headers.find("Sec-WebSocket-Accept");
+ if (it == headers.end() || it->second != expected_accept) { return false; }
+
+ // Extract negotiated subprotocol
+ auto proto_it = headers.find("Sec-WebSocket-Protocol");
+ if (proto_it != headers.end()) { selected_subprotocol = proto_it->second; }
+
+ return true;
+}
+
enum class ReadContentResult {
Success, // Successfully read the content
PayloadTooLarge, // The content exceeds the specified payload limit
return body;
}
+size_t get_multipart_content_length(const UploadFormDataItems &items,
+ const std::string &boundary) {
+ size_t total = 0;
+ for (const auto &item : items) {
+ total += serialize_multipart_formdata_item_begin(item, boundary).size();
+ total += item.content.size();
+ total += serialize_multipart_formdata_item_end().size();
+ }
+ total += serialize_multipart_formdata_finish(boundary).size();
+ return total;
+}
+
+struct MultipartSegment {
+ const char *data;
+ size_t size;
+};
+
+// NOTE: items must outlive the returned ContentProvider
+// (safe for synchronous use inside Post/Put/Patch)
+ContentProvider
+make_multipart_content_provider(const UploadFormDataItems &items,
+ const std::string &boundary) {
+ // Own the per-item header strings and the finish string
+ std::vector<std::string> owned;
+ owned.reserve(items.size() + 1);
+ for (const auto &item : items)
+ owned.push_back(serialize_multipart_formdata_item_begin(item, boundary));
+ owned.push_back(serialize_multipart_formdata_finish(boundary));
+
+ // Flat segment list: [header, content, "\r\n"] * N + [finish]
+ std::vector<MultipartSegment> segs;
+ segs.reserve(items.size() * 3 + 1);
+ static const char crlf[] = "\r\n";
+ for (size_t i = 0; i < items.size(); i++) {
+ segs.push_back({owned[i].data(), owned[i].size()});
+ segs.push_back({items[i].content.data(), items[i].content.size()});
+ segs.push_back({crlf, 2});
+ }
+ segs.push_back({owned.back().data(), owned.back().size()});
+
+ struct MultipartState {
+ std::vector<std::string> owned;
+ std::vector<MultipartSegment> segs;
+ };
+ auto state = std::make_shared<MultipartState>();
+ state->owned = std::move(owned);
+ // `segs` holds raw pointers into owned strings; std::string move preserves
+ // the data pointer, so these pointers remain valid after the move above.
+ state->segs = std::move(segs);
+
+ return [state](size_t offset, size_t length, DataSink &sink) -> bool {
+ size_t pos = 0;
+ for (const auto &seg : state->segs) {
+ // Loop invariant: pos <= offset (proven by advancing pos only when
+ // offset - pos >= seg.size, i.e., the segment doesn't contain offset)
+ if (seg.size > 0 && offset - pos < seg.size) {
+ size_t seg_offset = offset - pos;
+ size_t available = seg.size - seg_offset;
+ size_t to_write = (std::min)(available, length);
+ return sink.write(seg.data + seg_offset, to_write);
+ }
+ pos += seg.size;
+ }
+ return true; // past end (shouldn't be reached when content_length is exact)
+ };
+}
+
void coalesce_ranges(Ranges &ranges, size_t content_length) {
if (ranges.size() <= 1) return;
return false;
}
-bool has_crlf(const std::string &s) {
- auto p = s.c_str();
- while (*p) {
- if (*p == '\r' || *p == '\n') { return true; }
- p++;
- }
- return false;
-}
-
#ifdef _WIN32
class WSInit {
public:
bool is_field_value(const std::string &s) { return is_field_content(s); }
} // namespace fields
+
+bool perform_websocket_handshake(Stream &strm, const std::string &host,
+ int port, const std::string &path,
+ const Headers &headers,
+ std::string &selected_subprotocol) {
+ // Validate path and host
+ if (!fields::is_field_value(path) || !fields::is_field_value(host)) {
+ return false;
+ }
+
+ // Validate user-provided headers
+ for (const auto &h : headers) {
+ if (!fields::is_field_name(h.first) || !fields::is_field_value(h.second)) {
+ return false;
+ }
+ }
+
+ // Generate random Sec-WebSocket-Key
+ thread_local std::mt19937 rng(std::random_device{}());
+ std::string key_bytes(16, '\0');
+ for (size_t i = 0; i < 16; i += 4) {
+ auto r = rng();
+ std::memcpy(&key_bytes[i], &r, (std::min)(size_t(4), size_t(16 - i)));
+ }
+ auto client_key = base64_encode(key_bytes);
+
+ // Build upgrade request
+ std::string req_str = "GET " + path + " HTTP/1.1\r\n";
+ req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n";
+ req_str += "Upgrade: websocket\r\n";
+ req_str += "Connection: Upgrade\r\n";
+ req_str += "Sec-WebSocket-Key: " + client_key + "\r\n";
+ req_str += "Sec-WebSocket-Version: 13\r\n";
+ for (const auto &h : headers) {
+ req_str += h.first + ": " + h.second + "\r\n";
+ }
+ req_str += "\r\n";
+
+ if (strm.write(req_str.data(), req_str.size()) < 0) { return false; }
+
+ // Verify 101 response and Sec-WebSocket-Accept header
+ auto expected_accept = websocket_accept_key(client_key);
+ return read_websocket_upgrade_response(strm, expected_accept,
+ selected_subprotocol);
+}
+
} // namespace detail
/*
void get_local_ip_and_port(std::string &ip, int &port) const override;
socket_t socket() const override;
time_t duration() const override;
+ void set_read_timeout(time_t sec, time_t usec = 0) override;
private:
socket_t sock_;
#endif
return hash_to_hex(hash);
}
+#elif defined(CPPHTTPLIB_WOLFSSL_SUPPORT)
+namespace {
+template <size_t N>
+std::string hash_to_hex(const unsigned char (&hash)[N]) {
+ std::stringstream ss;
+ for (size_t i = 0; i < N; ++i) {
+ ss << std::hex << std::setw(2) << std::setfill('0')
+ << static_cast<unsigned int>(hash[i]);
+ }
+ return ss.str();
+}
+} // namespace
+
+std::string MD5(const std::string &s) {
+ unsigned char hash[WC_MD5_DIGEST_SIZE];
+ wc_Md5Hash(reinterpret_cast<const unsigned char *>(s.c_str()),
+ static_cast<word32>(s.size()), hash);
+ return hash_to_hex(hash);
+}
+
+std::string SHA_256(const std::string &s) {
+ unsigned char hash[WC_SHA256_DIGEST_SIZE];
+ wc_Sha256Hash(reinterpret_cast<const unsigned char *>(s.c_str()),
+ static_cast<word32>(s.size()), hash);
+ return hash_to_hex(hash);
+}
+
+std::string SHA_512(const std::string &s) {
+ unsigned char hash[WC_SHA512_DIGEST_SIZE];
+ wc_Sha512Hash(reinterpret_cast<const unsigned char *>(s.c_str()),
+ static_cast<word32>(s.size()), hash);
+ return hash_to_hex(hash);
+}
#endif
bool is_ip_address(const std::string &host) {
}
#endif // _WIN32
+bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx,
+ tls::session_t &session, socket_t sock,
+ bool server_certificate_verification,
+ const std::string &ca_cert_file_path,
+ tls::ca_store_t ca_cert_store,
+ time_t timeout_sec, time_t timeout_usec) {
+ using namespace tls;
+
+ ctx = create_client_context();
+ if (!ctx) { return false; }
+
+ if (server_certificate_verification) {
+ if (!ca_cert_file_path.empty()) {
+ load_ca_file(ctx, ca_cert_file_path.c_str());
+ }
+ if (ca_cert_store) { set_ca_store(ctx, ca_cert_store); }
+ load_system_certs(ctx);
+ }
+
+ bool is_ip = is_ip_address(host);
+
+#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT
+ if (is_ip && server_certificate_verification) {
+ set_verify_client(ctx, false);
+ } else {
+ set_verify_client(ctx, server_certificate_verification);
+ }
+#endif
+
+ session = create_session(ctx, sock);
+ if (!session) { return false; }
+
+ // RFC 6066: SNI must not be set for IP addresses
+ if (!is_ip) { set_sni(session, host.c_str()); }
+ if (server_certificate_verification) { set_hostname(session, host.c_str()); }
+
+ if (!connect_nonblocking(session, sock, timeout_sec, timeout_usec, nullptr)) {
+ return false;
+ }
+
+ if (server_certificate_verification) {
+ if (get_verify_result(session) != 0) { return false; }
+ }
+
+ return true;
+}
+
} // namespace detail
#endif // CPPHTTPLIB_SSL_ENABLED
}
// ThreadPool implementation
-ThreadPool::ThreadPool(size_t n, size_t mqr)
- : shutdown_(false), max_queued_requests_(mqr) {
- threads_.reserve(n);
- while (n) {
- threads_.emplace_back(worker(*this));
- n--;
+ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr)
+ : base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0),
+ shutdown_(false) {
+#ifndef CPPHTTPLIB_NO_EXCEPTIONS
+ if (max_n != 0 && max_n < n) {
+ std::string msg = "max_threads must be >= base_threads";
+ throw std::invalid_argument(msg);
+ }
+#endif
+ max_thread_count_ = max_n == 0 ? n : max_n;
+ threads_.reserve(base_thread_count_);
+ for (size_t i = 0; i < base_thread_count_; i++) {
+ threads_.emplace_back(std::thread([this]() { worker(false); }));
}
}
bool ThreadPool::enqueue(std::function<void()> fn) {
{
std::unique_lock<std::mutex> lock(mutex_);
+ if (shutdown_) { return false; }
if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) {
return false;
}
jobs_.push_back(std::move(fn));
+
+ // Spawn a dynamic thread if no idle threads and under max
+ if (idle_thread_count_ == 0 &&
+ threads_.size() + dynamic_threads_.size() < max_thread_count_) {
+ cleanup_finished_threads();
+ dynamic_threads_.emplace_back(std::thread([this]() { worker(true); }));
+ }
}
cond_.notify_one();
}
void ThreadPool::shutdown() {
- // Stop all worker threads...
{
std::unique_lock<std::mutex> lock(mutex_);
shutdown_ = true;
cond_.notify_all();
- // Join...
for (auto &t : threads_) {
- t.join();
+ if (t.joinable()) { t.join(); }
+ }
+
+ // Move dynamic_threads_ to a local list under the lock to avoid racing
+ // with worker threads that call move_to_finished() concurrently.
+ std::list<std::thread> remaining_dynamic;
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ remaining_dynamic = std::move(dynamic_threads_);
+ }
+ for (auto &t : remaining_dynamic) {
+ if (t.joinable()) { t.join(); }
+ }
+
+ std::unique_lock<std::mutex> lock(mutex_);
+ cleanup_finished_threads();
+}
+
+void ThreadPool::move_to_finished(std::thread::id id) {
+ // Must be called with mutex_ held
+ for (auto it = dynamic_threads_.begin(); it != dynamic_threads_.end(); ++it) {
+ if (it->get_id() == id) {
+ finished_threads_.push_back(std::move(*it));
+ dynamic_threads_.erase(it);
+ return;
+ }
}
}
-ThreadPool::worker::worker(ThreadPool &pool) : pool_(pool) {}
+void ThreadPool::cleanup_finished_threads() {
+ // Must be called with mutex_ held
+ for (auto &t : finished_threads_) {
+ if (t.joinable()) { t.join(); }
+ }
+ finished_threads_.clear();
+}
-void ThreadPool::worker::operator()() {
+void ThreadPool::worker(bool is_dynamic) {
for (;;) {
std::function<void()> fn;
{
- std::unique_lock<std::mutex> lock(pool_.mutex_);
+ std::unique_lock<std::mutex> lock(mutex_);
+ idle_thread_count_++;
+
+ if (is_dynamic) {
+ auto has_work = cond_.wait_for(
+ lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT),
+ [&] { return !jobs_.empty() || shutdown_; });
+ if (!has_work) {
+ // Timed out with no work - exit this dynamic thread
+ idle_thread_count_--;
+ move_to_finished(std::this_thread::get_id());
+ break;
+ }
+ } else {
+ cond_.wait(lock, [&] { return !jobs_.empty() || shutdown_; });
+ }
- pool_.cond_.wait(lock,
- [&] { return !pool_.jobs_.empty() || pool_.shutdown_; });
+ idle_thread_count_--;
- if (pool_.shutdown_ && pool_.jobs_.empty()) { break; }
+ if (shutdown_ && jobs_.empty()) { break; }
- fn = pool_.jobs_.front();
- pool_.jobs_.pop_front();
+ fn = std::move(jobs_.front());
+ jobs_.pop_front();
}
assert(true == static_cast<bool>(fn));
fn();
+
+ // Dynamic thread: exit if queue is empty after task completion
+ if (is_dynamic) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (jobs_.empty()) {
+ move_to_finished(std::this_thread::get_id());
+ break;
+ }
+ }
}
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \
.count();
}
+void SocketStream::set_read_timeout(time_t sec, time_t usec) {
+ read_timeout_sec_ = sec;
+ read_timeout_usec_ = usec;
+}
+
// Buffer stream implementation
bool BufferStream::is_readable() const { return true; }
.count();
}
+void SSLSocketStream::set_read_timeout(time_t sec, time_t usec) {
+ read_timeout_sec_ = sec;
+ read_timeout_usec_ = usec;
+}
+
} // namespace detail
#endif // CPPHTTPLIB_SSL_ENABLED
// HTTP server implementation
Server::Server()
- : new_task_queue(
- [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) {
+ : new_task_queue([] {
+ return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT,
+ CPPHTTPLIB_THREAD_POOL_MAX_COUNT);
+ }) {
#ifndef _WIN32
signal(SIGPIPE, SIG_IGN);
#endif
return *this;
}
+Server &Server::WebSocket(const std::string &pattern,
+ WebSocketHandler handler) {
+ websocket_handlers_.push_back(
+ {make_matcher(pattern), std::move(handler), nullptr});
+ return *this;
+}
+
+Server &Server::WebSocket(const std::string &pattern,
+ WebSocketHandler handler,
+ SubProtocolSelector sub_protocol_selector) {
+ websocket_handlers_.push_back({make_matcher(pattern), std::move(handler),
+ std::move(sub_protocol_selector)});
+ return *this;
+}
+
bool Server::set_base_dir(const std::string &dir,
const std::string &mount_point) {
return set_mount_point(mount_point, dir);
int remote_port, const std::string &local_addr,
int local_port, bool close_connection,
bool &connection_closed,
- const std::function<void(Request &)> &setup_request) {
+ const std::function<void(Request &)> &setup_request,
+ bool *websocket_upgraded) {
std::array<char, 2048> buf{};
detail::stream_line_reader line_reader(strm, buf.data(), buf.size());
return !detail::is_socket_alive(sock);
};
+ // WebSocket upgrade
+ // Check pre_routing_handler_ before upgrading so that authentication
+ // and other middleware can reject the request with an HTTP response
+ // (e.g., 401) before the protocol switches.
+ if (detail::is_websocket_upgrade(req)) {
+ if (pre_routing_handler_ &&
+ pre_routing_handler_(req, res) == HandlerResponse::Handled) {
+ if (res.status == -1) { res.status = StatusCode::OK_200; }
+ return write_response(strm, close_connection, req, res);
+ }
+ // Find matching WebSocket handler
+ for (const auto &entry : websocket_handlers_) {
+ if (entry.matcher->match(req)) {
+ // Compute accept key
+ auto client_key = req.get_header_value("Sec-WebSocket-Key");
+ auto accept_key = detail::websocket_accept_key(client_key);
+
+ // Negotiate subprotocol
+ std::string selected_subprotocol;
+ if (entry.sub_protocol_selector) {
+ auto protocol_header = req.get_header_value("Sec-WebSocket-Protocol");
+ if (!protocol_header.empty()) {
+ std::vector<std::string> protocols;
+ std::istringstream iss(protocol_header);
+ std::string token;
+ while (std::getline(iss, token, ',')) {
+ // Trim whitespace
+ auto start = token.find_first_not_of(' ');
+ auto end = token.find_last_not_of(' ');
+ if (start != std::string::npos) {
+ protocols.push_back(token.substr(start, end - start + 1));
+ }
+ }
+ selected_subprotocol = entry.sub_protocol_selector(protocols);
+ }
+ }
+
+ // Send 101 Switching Protocols
+ std::string handshake_response = "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: " +
+ accept_key + "\r\n";
+ if (!selected_subprotocol.empty()) {
+ if (!detail::fields::is_field_value(selected_subprotocol)) {
+ return false;
+ }
+ handshake_response +=
+ "Sec-WebSocket-Protocol: " + selected_subprotocol + "\r\n";
+ }
+ handshake_response += "\r\n";
+ if (strm.write(handshake_response.data(), handshake_response.size()) <
+ 0) {
+ return false;
+ }
+
+ connection_closed = true;
+ if (websocket_upgraded) { *websocket_upgraded = true; }
+
+ {
+ // Use WebSocket-specific read timeout instead of HTTP timeout
+ strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0);
+ ws::WebSocket ws(strm, req, true);
+ entry.handler(req, ws);
+ }
+ return true;
+ }
+ }
+ // No matching handler - fall through to 404
+ }
+
// Routing
auto routed = false;
#ifdef CPPHTTPLIB_NO_EXCEPTIONS
int local_port = 0;
detail::get_local_ip_and_port(sock, local_addr, local_port);
+ bool websocket_upgraded = false;
auto ret = detail::process_server_socket(
svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_,
read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
[&](Stream &strm, bool close_connection, bool &connection_closed) {
return process_request(strm, remote_addr, remote_port, local_addr,
local_port, close_connection, connection_closed,
- nullptr);
+ nullptr, &websocket_upgraded);
});
detail::shutdown_socket(sock);
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
- const auto &body = detail::serialize_multipart_formdata(items, boundary);
- return Post(path, headers, body, content_type, progress);
+ auto content_length = detail::get_multipart_content_length(items, boundary);
+ return Post(path, headers, content_length,
+ detail::make_multipart_content_provider(items, boundary),
+ content_type, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
- const auto &body = detail::serialize_multipart_formdata(items, boundary);
- return Post(path, headers, body, content_type, progress);
+ auto content_length = detail::get_multipart_content_length(items, boundary);
+ return Post(path, headers, content_length,
+ detail::make_multipart_content_provider(items, boundary),
+ content_type, progress);
}
Result ClientImpl::Post(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
- const auto &body = detail::serialize_multipart_formdata(items, boundary);
- return Put(path, headers, body, content_type, progress);
+ auto content_length = detail::get_multipart_content_length(items, boundary);
+ return Put(path, headers, content_length,
+ detail::make_multipart_content_provider(items, boundary),
+ content_type, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
- const auto &body = detail::serialize_multipart_formdata(items, boundary);
- return Put(path, headers, body, content_type, progress);
+ auto content_length = detail::get_multipart_content_length(items, boundary);
+ return Put(path, headers, content_length,
+ detail::make_multipart_content_provider(items, boundary),
+ content_type, progress);
}
Result ClientImpl::Put(const std::string &path, const Headers &headers,
const auto &boundary = detail::make_multipart_data_boundary();
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
- const auto &body = detail::serialize_multipart_formdata(items, boundary);
- return Patch(path, headers, body, content_type, progress);
+ auto content_length = detail::get_multipart_content_length(items, boundary);
+ return Patch(path, headers, content_length,
+ detail::make_multipart_content_provider(items, boundary),
+ content_type, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
const auto &content_type =
detail::serialize_multipart_formdata_get_content_type(boundary);
- const auto &body = detail::serialize_multipart_formdata(items, boundary);
- return Patch(path, headers, body, content_type, progress);
+ auto content_length = detail::get_multipart_content_length(items, boundary);
+ return Patch(path, headers, content_length,
+ detail::make_multipart_content_provider(items, boundary),
+ content_type, progress);
}
Result ClientImpl::Patch(const std::string &path, const Headers &headers,
// Use scope_exit to ensure cleanup on all paths (including exceptions)
bool handshake_done = false;
bool ret = false;
+ bool websocket_upgraded = false;
auto cleanup = detail::scope_exit([&] {
- // Shutdown gracefully if handshake succeeded and processing was successful
- if (handshake_done) { shutdown(session, ret); }
+ if (handshake_done) { shutdown(session, !websocket_upgraded && ret); }
free_session(session);
detail::shutdown_socket(sock);
detail::close_socket(sock);
read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_,
[&](Stream &strm, bool close_connection, bool &connection_closed) {
- return process_request(strm, remote_addr, remote_port, local_addr,
- local_port, close_connection, connection_closed,
- [&](Request &req) { req.ssl = session; });
+ return process_request(
+ strm, remote_addr, remote_port, local_addr, local_port,
+ close_connection, connection_closed,
+ [&](Request &req) { req.ssl = session; }, &websocket_upgraded);
});
return ret;
bool is_ip = detail::is_ip_address(host_);
-#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT
- // MbedTLS needs explicit verification mode (OpenSSL uses SSL_VERIFY_NONE
- // by default and performs all verification post-handshake).
+#if defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT)
+ // MbedTLS/wolfSSL need explicit verification mode (OpenSSL uses
+ // SSL_VERIFY_NONE by default and performs all verification post-handshake).
// For IP addresses with verification enabled, use OPTIONAL mode since
- // MbedTLS requires hostname for VERIFY_REQUIRED.
+ // these backends require hostname for strict verification.
if (is_ip && server_certificate_verification_) {
set_verify_client(ctx_, false);
} else {
return callback;
}
-} // namespace impl
+// Check if a string is an IPv4 address
+bool is_ipv4_address(const std::string &str) {
+ int dots = 0;
+ for (char c : str) {
+ if (c == '.') {
+ dots++;
+ } else if (!isdigit(static_cast<unsigned char>(c))) {
+ return false;
+ }
+ }
+ return dots == 3;
+}
-bool set_client_ca_file(ctx_t ctx, const char *ca_file,
- const char *ca_dir) {
- if (!ctx) { return false; }
+// Parse IPv4 address string to bytes
+bool parse_ipv4(const std::string &str, unsigned char *out) {
+ int parts[4];
+ if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2],
+ &parts[3]) != 4) {
+ return false;
+ }
+ for (int i = 0; i < 4; i++) {
+ if (parts[i] < 0 || parts[i] > 255) return false;
+ out[i] = static_cast<unsigned char>(parts[i]);
+ }
+ return true;
+}
+
+#ifdef _WIN32
+// Enumerate Windows system certificates and call callback with DER data
+template <typename Callback>
+bool enumerate_windows_system_certs(Callback cb) {
+ bool loaded = false;
+ static const wchar_t *store_names[] = {L"ROOT", L"CA"};
+ for (auto store_name : store_names) {
+ HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name);
+ if (hStore) {
+ PCCERT_CONTEXT pContext = nullptr;
+ while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) !=
+ nullptr) {
+ if (cb(pContext->pbCertEncoded, pContext->cbCertEncoded)) {
+ loaded = true;
+ }
+ }
+ CertCloseStore(hStore, 0);
+ }
+ }
+ return loaded;
+}
+#endif
+
+#if defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
+// Enumerate macOS Keychain certificates and call callback with DER data
+template <typename Callback>
+bool enumerate_macos_keychain_certs(Callback cb) {
+ bool loaded = false;
+ CFArrayRef certs = nullptr;
+ OSStatus status = SecTrustCopyAnchorCertificates(&certs);
+ if (status == errSecSuccess && certs) {
+ CFIndex count = CFArrayGetCount(certs);
+ for (CFIndex i = 0; i < count; i++) {
+ SecCertificateRef cert =
+ (SecCertificateRef)CFArrayGetValueAtIndex(certs, i);
+ CFDataRef data = SecCertificateCopyData(cert);
+ if (data) {
+ if (cb(CFDataGetBytePtr(data),
+ static_cast<size_t>(CFDataGetLength(data)))) {
+ loaded = true;
+ }
+ CFRelease(data);
+ }
+ }
+ CFRelease(certs);
+ }
+ return loaded;
+}
+#endif
+
+#if !defined(_WIN32) && !(defined(__APPLE__) && \
+ defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN))
+// Common CA certificate file paths on Linux/Unix
+const char **system_ca_paths() {
+ static const char *paths[] = {
+ "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu
+ "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS
+ "/etc/ssl/ca-bundle.pem", // OpenSUSE
+ "/etc/pki/tls/cacert.pem", // OpenELEC
+ "/etc/ssl/cert.pem", // Alpine, FreeBSD
+ nullptr};
+ return paths;
+}
+
+// Common CA certificate directory paths on Linux/Unix
+const char **system_ca_dirs() {
+ static const char *dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu
+ "/etc/pki/tls/certs", // RHEL/CentOS
+ "/usr/share/ca-certificates", // Other
+ nullptr};
+ return dirs;
+}
+#endif
+
+} // namespace impl
+
+bool set_client_ca_file(ctx_t ctx, const char *ca_file,
+ const char *ca_dir) {
+ if (!ctx) { return false; }
bool success = true;
if (ca_file && *ca_file) {
int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt,
int cert_depth, uint32_t *flags);
-// Check if a string is an IPv4 address
-bool is_ipv4_address(const std::string &str) {
- int dots = 0;
- for (char c : str) {
- if (c == '.') {
- dots++;
- } else if (!isdigit(static_cast<unsigned char>(c))) {
- return false;
- }
- }
- return dots == 3;
-}
-
-// Parse IPv4 address string to bytes
-bool parse_ipv4(const std::string &str, unsigned char *out) {
- int parts[4];
- if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2],
- &parts[3]) != 4) {
- return false;
- }
- for (int i = 0; i < 4; i++) {
- if (parts[i] < 0 || parts[i] > 255) return false;
- out[i] = static_cast<unsigned char>(parts[i]);
- }
- return true;
-}
-
// MbedTLS verify callback wrapper
int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt,
int cert_depth, uint32_t *flags) {
bool loaded = false;
#ifdef _WIN32
- // Load from Windows certificate store (ROOT and CA)
- static const wchar_t *store_names[] = {L"ROOT", L"CA"};
- for (auto store_name : store_names) {
- HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name);
- if (hStore) {
- PCCERT_CONTEXT pContext = nullptr;
- while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) !=
- nullptr) {
- int ret = mbedtls_x509_crt_parse_der(
- &mctx->ca_chain, pContext->pbCertEncoded, pContext->cbCertEncoded);
- if (ret == 0) { loaded = true; }
- }
- CertCloseStore(hStore, 0);
- }
- }
+ loaded = impl::enumerate_windows_system_certs(
+ [&](const unsigned char *data, size_t len) {
+ return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0;
+ });
#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
- // Load from macOS Keychain
- CFArrayRef certs = nullptr;
- OSStatus status = SecTrustCopyAnchorCertificates(&certs);
- if (status == errSecSuccess && certs) {
- CFIndex count = CFArrayGetCount(certs);
- for (CFIndex i = 0; i < count; i++) {
- SecCertificateRef cert =
- (SecCertificateRef)CFArrayGetValueAtIndex(certs, i);
- CFDataRef data = SecCertificateCopyData(cert);
- if (data) {
- int ret = mbedtls_x509_crt_parse_der(
- &mctx->ca_chain, CFDataGetBytePtr(data),
- static_cast<size_t>(CFDataGetLength(data)));
- if (ret == 0) { loaded = true; }
- CFRelease(data);
- }
- }
- CFRelease(certs);
- }
+ loaded = impl::enumerate_macos_keychain_certs(
+ [&](const unsigned char *data, size_t len) {
+ return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0;
+ });
#else
- // Try common CA certificate locations on Linux/Unix
- static const char *ca_paths[] = {
- "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu
- "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS
- "/etc/ssl/ca-bundle.pem", // OpenSUSE
- "/etc/pki/tls/cacert.pem", // OpenELEC
- "/etc/ssl/cert.pem", // Alpine, FreeBSD
- nullptr};
-
- for (const char **path = ca_paths; *path; ++path) {
- int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path);
- if (ret >= 0) {
+ for (auto path = impl::system_ca_paths(); *path; ++path) {
+ if (mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path) >= 0) {
loaded = true;
break;
}
}
- // Also try the CA directory
if (!loaded) {
- static const char *ca_dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu
- "/etc/pki/tls/certs", // RHEL/CentOS
- "/usr/share/ca-certificates", nullptr};
-
- for (const char **dir = ca_dirs; *dir; ++dir) {
- int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir);
- if (ret >= 0) {
+ for (auto dir = impl::system_ca_dirs(); *dir; ++dir) {
+ if (mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir) >= 0) {
loaded = true;
break;
}
return false;
}
+ // Verify that the certificate and private key match
+#ifdef CPPHTTPLIB_MBEDTLS_V3
+ ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key,
+ mbedtls_ctr_drbg_random, &mctx->ctr_drbg);
+#else
+ ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key);
+#endif
+ if (ret != 0) {
+ impl::mbedtls_last_error() = ret;
+ return false;
+ }
+
ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key);
if (ret != 0) {
impl::mbedtls_last_error() = ret;
return false;
}
+ // Verify that the certificate and private key match
+#ifdef CPPHTTPLIB_MBEDTLS_V3
+ ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key,
+ mbedtls_ctr_drbg_random, &mctx->ctr_drbg);
+#else
+ ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key);
+#endif
+ if (ret != 0) {
+ impl::mbedtls_last_error() = ret;
+ return false;
+ }
+
ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key);
if (ret != 0) {
impl::mbedtls_last_error() = ret;
#endif // CPPHTTPLIB_MBEDTLS_SUPPORT
+/*
+ * Group 10: TLS abstraction layer - wolfSSL backend
+ */
+
+/*
+ * wolfSSL Backend Implementation
+ */
+
+#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
+namespace tls {
+
+namespace impl {
+
+// wolfSSL session wrapper
+struct WolfSSLSession {
+ WOLFSSL *ssl = nullptr;
+ socket_t sock = INVALID_SOCKET;
+ std::string hostname; // For client: set via set_sni
+ std::string sni_hostname; // For server: received from client via SNI callback
+
+ WolfSSLSession() = default;
+
+ ~WolfSSLSession() {
+ if (ssl) { wolfSSL_free(ssl); }
+ }
+
+ WolfSSLSession(const WolfSSLSession &) = delete;
+ WolfSSLSession &operator=(const WolfSSLSession &) = delete;
+};
+
+// Thread-local error code accessor for wolfSSL
+uint64_t &wolfssl_last_error() {
+ static thread_local uint64_t err = 0;
+ return err;
+}
+
+// Helper to map wolfSSL error to ErrorCode.
+// ssl_error is the value from wolfSSL_get_error().
+// raw_ret is the raw return value from the wolfSSL call (for low-level error).
+ErrorCode map_wolfssl_error(WOLFSSL *ssl, int ssl_error,
+ int &out_errno) {
+ switch (ssl_error) {
+ case SSL_ERROR_NONE: return ErrorCode::Success;
+ case SSL_ERROR_WANT_READ: return ErrorCode::WantRead;
+ case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite;
+ case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed;
+ case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError;
+ default:
+ if (ssl) {
+ // wolfSSL stores the low-level error code as a negative value.
+ // DOMAIN_NAME_MISMATCH (-322) indicates hostname verification failure.
+ int low_err = ssl_error; // wolfSSL_get_error returns the low-level code
+ if (low_err == DOMAIN_NAME_MISMATCH) {
+ return ErrorCode::HostnameMismatch;
+ }
+ // Check verify result to distinguish cert verification from generic SSL
+ // errors.
+ long vr = wolfSSL_get_verify_result(ssl);
+ if (vr != 0) { return ErrorCode::CertVerifyFailed; }
+ }
+ return ErrorCode::Fatal;
+ }
+}
+
+// WolfSSLContext constructor/destructor implementations
+WolfSSLContext::WolfSSLContext() { wolfSSL_Init(); }
+
+WolfSSLContext::~WolfSSLContext() {
+ if (ctx) { wolfSSL_CTX_free(ctx); }
+}
+
+// Thread-local storage for SNI captured during handshake
+std::string &wolfssl_pending_sni() {
+ static thread_local std::string sni;
+ return sni;
+}
+
+// SNI callback for wolfSSL server to capture client's SNI hostname
+int wolfssl_sni_callback(WOLFSSL *ssl, int *ret, void *exArg) {
+ (void)ret;
+ (void)exArg;
+
+ void *name_data = nullptr;
+ unsigned short name_len =
+ wolfSSL_SNI_GetRequest(ssl, WOLFSSL_SNI_HOST_NAME, &name_data);
+
+ if (name_data && name_len > 0) {
+ wolfssl_pending_sni().assign(static_cast<const char *>(name_data),
+ name_len);
+ } else {
+ wolfssl_pending_sni().clear();
+ }
+ return 0; // Continue regardless
+}
+
+// wolfSSL verify callback wrapper
+int wolfssl_verify_callback(int preverify_ok,
+ WOLFSSL_X509_STORE_CTX *x509_ctx) {
+ auto &callback = get_verify_callback();
+ if (!callback) { return preverify_ok; }
+
+ WOLFSSL_X509 *cert = wolfSSL_X509_STORE_CTX_get_current_cert(x509_ctx);
+ int depth = wolfSSL_X509_STORE_CTX_get_error_depth(x509_ctx);
+ int err = wolfSSL_X509_STORE_CTX_get_error(x509_ctx);
+
+ // Get the WOLFSSL object from the X509_STORE_CTX
+ WOLFSSL *ssl = static_cast<WOLFSSL *>(wolfSSL_X509_STORE_CTX_get_ex_data(
+ x509_ctx, wolfSSL_get_ex_data_X509_STORE_CTX_idx()));
+
+ VerifyContext verify_ctx;
+ verify_ctx.session = static_cast<session_t>(ssl);
+ verify_ctx.cert = static_cast<cert_t>(cert);
+ verify_ctx.depth = depth;
+ verify_ctx.preverify_ok = (preverify_ok != 0);
+ verify_ctx.error_code = static_cast<long>(err);
+
+ if (err != 0) {
+ verify_ctx.error_string = wolfSSL_X509_verify_cert_error_string(err);
+ } else {
+ verify_ctx.error_string = nullptr;
+ }
+
+ bool accepted = callback(verify_ctx);
+ return accepted ? 1 : 0;
+}
+
+void set_wolfssl_password_cb(WOLFSSL_CTX *ctx, const char *password) {
+ wolfSSL_CTX_set_default_passwd_cb_userdata(ctx, const_cast<char *>(password));
+ wolfSSL_CTX_set_default_passwd_cb(
+ ctx, [](char *buf, int size, int /*rwflag*/, void *userdata) -> int {
+ auto *pwd = static_cast<const char *>(userdata);
+ if (!pwd) return 0;
+ auto len = static_cast<int>(strlen(pwd));
+ if (len > size) len = size;
+ memcpy(buf, pwd, static_cast<size_t>(len));
+ return len;
+ });
+}
+
+} // namespace impl
+
+ctx_t create_client_context() {
+ auto ctx = new (std::nothrow) impl::WolfSSLContext();
+ if (!ctx) { return nullptr; }
+
+ ctx->is_server = false;
+
+ WOLFSSL_METHOD *method = wolfTLSv1_2_client_method();
+ if (!method) {
+ delete ctx;
+ return nullptr;
+ }
+
+ ctx->ctx = wolfSSL_CTX_new(method);
+ if (!ctx->ctx) {
+ delete ctx;
+ return nullptr;
+ }
+
+ // Default: verify peer certificate
+ wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_PEER, nullptr);
+
+ return static_cast<ctx_t>(ctx);
+}
+
+ctx_t create_server_context() {
+ auto ctx = new (std::nothrow) impl::WolfSSLContext();
+ if (!ctx) { return nullptr; }
+
+ ctx->is_server = true;
+
+ WOLFSSL_METHOD *method = wolfTLSv1_2_server_method();
+ if (!method) {
+ delete ctx;
+ return nullptr;
+ }
+
+ ctx->ctx = wolfSSL_CTX_new(method);
+ if (!ctx->ctx) {
+ delete ctx;
+ return nullptr;
+ }
+
+ // Default: don't verify client
+ wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_NONE, nullptr);
+
+ // Enable SNI on server
+ wolfSSL_CTX_SNI_SetOptions(ctx->ctx, WOLFSSL_SNI_HOST_NAME,
+ WOLFSSL_SNI_CONTINUE_ON_MISMATCH);
+ wolfSSL_CTX_set_servername_callback(ctx->ctx, impl::wolfssl_sni_callback);
+
+ return static_cast<ctx_t>(ctx);
+}
+
+void free_context(ctx_t ctx) {
+ if (ctx) { delete static_cast<impl::WolfSSLContext *>(ctx); }
+}
+
+bool set_min_version(ctx_t ctx, Version version) {
+ if (!ctx) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ int min_ver = WOLFSSL_TLSV1_2;
+ if (version >= Version::TLS1_3) { min_ver = WOLFSSL_TLSV1_3; }
+
+ return wolfSSL_CTX_SetMinVersion(wctx->ctx, min_ver) == WOLFSSL_SUCCESS;
+}
+
+bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) {
+ if (!ctx || !pem) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ int ret = wolfSSL_CTX_load_verify_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(pem),
+ static_cast<long>(len), SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+ wctx->ca_pem_data_.append(pem, len);
+ return true;
+}
+
+bool load_ca_file(ctx_t ctx, const char *file_path) {
+ if (!ctx || !file_path) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, file_path, nullptr);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+ return true;
+}
+
+bool load_ca_dir(ctx_t ctx, const char *dir_path) {
+ if (!ctx || !dir_path) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, dir_path);
+ // wolfSSL may fail if the directory doesn't contain properly hashed certs.
+ // Unlike OpenSSL which lazily loads certs from directories, wolfSSL scans
+ // immediately. Return true even on failure since the CA file may have
+ // already been loaded, matching OpenSSL's lenient behavior.
+ (void)ret;
+ return true;
+}
+
+bool load_system_certs(ctx_t ctx) {
+ if (!ctx) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+ bool loaded = false;
+
+#ifdef _WIN32
+ loaded = impl::enumerate_windows_system_certs(
+ [&](const unsigned char *data, size_t len) {
+ return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data,
+ static_cast<long>(len),
+ SSL_FILETYPE_ASN1) == SSL_SUCCESS;
+ });
+#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
+ loaded = impl::enumerate_macos_keychain_certs(
+ [&](const unsigned char *data, size_t len) {
+ return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data,
+ static_cast<long>(len),
+ SSL_FILETYPE_ASN1) == SSL_SUCCESS;
+ });
+#else
+ for (auto path = impl::system_ca_paths(); *path; ++path) {
+ if (wolfSSL_CTX_load_verify_locations(wctx->ctx, *path, nullptr) ==
+ SSL_SUCCESS) {
+ loaded = true;
+ break;
+ }
+ }
+
+ if (!loaded) {
+ for (auto dir = impl::system_ca_dirs(); *dir; ++dir) {
+ if (wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, *dir) ==
+ SSL_SUCCESS) {
+ loaded = true;
+ break;
+ }
+ }
+ }
+#endif
+
+ return loaded;
+}
+
+bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key,
+ const char *password) {
+ if (!ctx || !cert || !key) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ // Load certificate
+ int ret = wolfSSL_CTX_use_certificate_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(cert),
+ static_cast<long>(strlen(cert)), SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ // Set password callback if password is provided
+ if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); }
+
+ // Load private key
+ ret = wolfSSL_CTX_use_PrivateKey_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(key),
+ static_cast<long>(strlen(key)), SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ // Verify that the certificate and private key match
+ return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS;
+}
+
+bool set_client_cert_file(ctx_t ctx, const char *cert_path,
+ const char *key_path, const char *password) {
+ if (!ctx || !cert_path || !key_path) { return false; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ // Load certificate file
+ int ret =
+ wolfSSL_CTX_use_certificate_file(wctx->ctx, cert_path, SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ // Set password callback if password is provided
+ if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); }
+
+ // Load private key file
+ ret = wolfSSL_CTX_use_PrivateKey_file(wctx->ctx, key_path, SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ // Verify that the certificate and private key match
+ return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS;
+}
+
+void set_verify_client(ctx_t ctx, bool require) {
+ if (!ctx) { return; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+ wctx->verify_client = require;
+ if (require) {
+ wolfSSL_CTX_set_verify(
+ wctx->ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
+ wctx->has_verify_callback ? impl::wolfssl_verify_callback : nullptr);
+ } else {
+ if (wctx->has_verify_callback) {
+ wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER,
+ impl::wolfssl_verify_callback);
+ } else {
+ wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_NONE, nullptr);
+ }
+ }
+}
+
+session_t create_session(ctx_t ctx, socket_t sock) {
+ if (!ctx || sock == INVALID_SOCKET) { return nullptr; }
+ auto wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ auto session = new (std::nothrow) impl::WolfSSLSession();
+ if (!session) { return nullptr; }
+
+ session->sock = sock;
+ session->ssl = wolfSSL_new(wctx->ctx);
+ if (!session->ssl) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ delete session;
+ return nullptr;
+ }
+
+ wolfSSL_set_fd(session->ssl, static_cast<int>(sock));
+
+ return static_cast<session_t>(session);
+}
+
+void free_session(session_t session) {
+ if (session) { delete static_cast<impl::WolfSSLSession *>(session); }
+}
+
+bool set_sni(session_t session, const char *hostname) {
+ if (!session || !hostname) { return false; }
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+
+ int ret = wolfSSL_UseSNI(wsession->ssl, WOLFSSL_SNI_HOST_NAME, hostname,
+ static_cast<word16>(strlen(hostname)));
+ if (ret != WOLFSSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ // Also set hostname for verification
+ wolfSSL_check_domain_name(wsession->ssl, hostname);
+
+ wsession->hostname = hostname;
+ return true;
+}
+
+bool set_hostname(session_t session, const char *hostname) {
+ // In wolfSSL, set_hostname also sets up hostname verification
+ return set_sni(session, hostname);
+}
+
+TlsError connect(session_t session) {
+ TlsError err;
+ if (!session) {
+ err.code = ErrorCode::Fatal;
+ return err;
+ }
+
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+ int ret = wolfSSL_connect(wsession->ssl);
+
+ if (ret == SSL_SUCCESS) {
+ err.code = ErrorCode::Success;
+ } else {
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno);
+ err.backend_code = static_cast<uint64_t>(ssl_error);
+ impl::wolfssl_last_error() = err.backend_code;
+ }
+
+ return err;
+}
+
+TlsError accept(session_t session) {
+ TlsError err;
+ if (!session) {
+ err.code = ErrorCode::Fatal;
+ return err;
+ }
+
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+ int ret = wolfSSL_accept(wsession->ssl);
+
+ if (ret == SSL_SUCCESS) {
+ err.code = ErrorCode::Success;
+ // Capture SNI from thread-local storage after successful handshake
+ wsession->sni_hostname = std::move(impl::wolfssl_pending_sni());
+ impl::wolfssl_pending_sni().clear();
+ } else {
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno);
+ err.backend_code = static_cast<uint64_t>(ssl_error);
+ impl::wolfssl_last_error() = err.backend_code;
+ }
+
+ return err;
+}
+
+bool connect_nonblocking(session_t session, socket_t sock,
+ time_t timeout_sec, time_t timeout_usec,
+ TlsError *err) {
+ if (!session) {
+ if (err) { err->code = ErrorCode::Fatal; }
+ return false;
+ }
+
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+
+ // Set socket to non-blocking mode
+ detail::set_nonblocking(sock, true);
+ auto cleanup =
+ detail::scope_exit([&]() { detail::set_nonblocking(sock, false); });
+
+ int ret;
+ while ((ret = wolfSSL_connect(wsession->ssl)) != SSL_SUCCESS) {
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ if (ssl_error == SSL_ERROR_WANT_READ) {
+ if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) {
+ continue;
+ }
+ } else if (ssl_error == SSL_ERROR_WANT_WRITE) {
+ if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) {
+ continue;
+ }
+ }
+
+ // Error or timeout
+ if (err) {
+ err->code =
+ impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno);
+ err->backend_code = static_cast<uint64_t>(ssl_error);
+ }
+ impl::wolfssl_last_error() = static_cast<uint64_t>(ssl_error);
+ return false;
+ }
+
+ if (err) { err->code = ErrorCode::Success; }
+ return true;
+}
+
+bool accept_nonblocking(session_t session, socket_t sock,
+ time_t timeout_sec, time_t timeout_usec,
+ TlsError *err) {
+ if (!session) {
+ if (err) { err->code = ErrorCode::Fatal; }
+ return false;
+ }
+
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+
+ // Set socket to non-blocking mode
+ detail::set_nonblocking(sock, true);
+ auto cleanup =
+ detail::scope_exit([&]() { detail::set_nonblocking(sock, false); });
+
+ int ret;
+ while ((ret = wolfSSL_accept(wsession->ssl)) != SSL_SUCCESS) {
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ if (ssl_error == SSL_ERROR_WANT_READ) {
+ if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) {
+ continue;
+ }
+ } else if (ssl_error == SSL_ERROR_WANT_WRITE) {
+ if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) {
+ continue;
+ }
+ }
+
+ // Error or timeout
+ if (err) {
+ err->code =
+ impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno);
+ err->backend_code = static_cast<uint64_t>(ssl_error);
+ }
+ impl::wolfssl_last_error() = static_cast<uint64_t>(ssl_error);
+ return false;
+ }
+
+ if (err) { err->code = ErrorCode::Success; }
+
+ // Capture SNI from thread-local storage after successful handshake
+ wsession->sni_hostname = std::move(impl::wolfssl_pending_sni());
+ impl::wolfssl_pending_sni().clear();
+
+ return true;
+}
+
+ssize_t read(session_t session, void *buf, size_t len, TlsError &err) {
+ if (!session || !buf) {
+ err.code = ErrorCode::Fatal;
+ return -1;
+ }
+
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+ int ret = wolfSSL_read(wsession->ssl, buf, static_cast<int>(len));
+
+ if (ret > 0) {
+ err.code = ErrorCode::Success;
+ return static_cast<ssize_t>(ret);
+ }
+
+ if (ret == 0) {
+ err.code = ErrorCode::PeerClosed;
+ return 0;
+ }
+
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno);
+ err.backend_code = static_cast<uint64_t>(ssl_error);
+ impl::wolfssl_last_error() = err.backend_code;
+ return -1;
+}
+
+ssize_t write(session_t session, const void *buf, size_t len,
+ TlsError &err) {
+ if (!session || !buf) {
+ err.code = ErrorCode::Fatal;
+ return -1;
+ }
+
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+ int ret = wolfSSL_write(wsession->ssl, buf, static_cast<int>(len));
+
+ if (ret > 0) {
+ err.code = ErrorCode::Success;
+ return static_cast<ssize_t>(ret);
+ }
+
+ // wolfSSL_write returns 0 when the peer has sent a close_notify.
+ // Treat this as an error (return -1) so callers don't spin in a
+ // write loop adding zero to the offset.
+ if (ret == 0) {
+ err.code = ErrorCode::PeerClosed;
+ return -1;
+ }
+
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno);
+ err.backend_code = static_cast<uint64_t>(ssl_error);
+ impl::wolfssl_last_error() = err.backend_code;
+ return -1;
+}
+
+int pending(const_session_t session) {
+ if (!session) { return 0; }
+ auto wsession =
+ static_cast<impl::WolfSSLSession *>(const_cast<void *>(session));
+ return wolfSSL_pending(wsession->ssl);
+}
+
+void shutdown(session_t session, bool graceful) {
+ if (!session) { return; }
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+
+ if (graceful) {
+ int ret;
+ int attempts = 0;
+ while ((ret = wolfSSL_shutdown(wsession->ssl)) != SSL_SUCCESS &&
+ attempts < 3) {
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ if (ssl_error != SSL_ERROR_WANT_READ &&
+ ssl_error != SSL_ERROR_WANT_WRITE) {
+ break;
+ }
+ attempts++;
+ }
+ } else {
+ wolfSSL_shutdown(wsession->ssl);
+ }
+}
+
+bool is_peer_closed(session_t session, socket_t sock) {
+ if (!session || sock == INVALID_SOCKET) { return true; }
+ auto wsession = static_cast<impl::WolfSSLSession *>(session);
+
+ // Check if there's already decrypted data available
+ if (wolfSSL_pending(wsession->ssl) > 0) { return false; }
+
+ // Set socket to non-blocking to avoid blocking on read
+ detail::set_nonblocking(sock, true);
+ auto cleanup =
+ detail::scope_exit([&]() { detail::set_nonblocking(sock, false); });
+
+ // Peek 1 byte to check connection status without consuming data
+ unsigned char buf;
+ int ret = wolfSSL_peek(wsession->ssl, &buf, 1);
+
+ // If we got data or WANT_READ (would block), connection is alive
+ if (ret > 0) { return false; }
+
+ int ssl_error = wolfSSL_get_error(wsession->ssl, ret);
+ if (ssl_error == SSL_ERROR_WANT_READ) { return false; }
+
+ return ssl_error == SSL_ERROR_ZERO_RETURN || ssl_error == SSL_ERROR_SYSCALL ||
+ ret == 0;
+}
+
+cert_t get_peer_cert(const_session_t session) {
+ if (!session) { return nullptr; }
+ auto wsession =
+ static_cast<impl::WolfSSLSession *>(const_cast<void *>(session));
+
+ WOLFSSL_X509 *cert = wolfSSL_get_peer_certificate(wsession->ssl);
+ return static_cast<cert_t>(cert);
+}
+
+void free_cert(cert_t cert) {
+ if (cert) { wolfSSL_X509_free(static_cast<WOLFSSL_X509 *>(cert)); }
+}
+
+bool verify_hostname(cert_t cert, const char *hostname) {
+ if (!cert || !hostname) { return false; }
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+ std::string host_str(hostname);
+
+ // Check if hostname is an IP address
+ bool is_ip = impl::is_ipv4_address(host_str);
+ unsigned char ip_bytes[4];
+ if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); }
+
+ // Check Subject Alternative Names
+ auto *san_names = static_cast<WOLF_STACK_OF(WOLFSSL_GENERAL_NAME) *>(
+ wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr));
+
+ if (san_names) {
+ int san_count = wolfSSL_sk_num(san_names);
+ for (int i = 0; i < san_count; i++) {
+ auto *names =
+ static_cast<WOLFSSL_GENERAL_NAME *>(wolfSSL_sk_value(san_names, i));
+ if (!names) continue;
+
+ if (!is_ip && names->type == WOLFSSL_GEN_DNS) {
+ // DNS name
+ unsigned char *dns_name = nullptr;
+ int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, names->d.dNSName);
+ if (dns_name && dns_len > 0) {
+ std::string san_name(reinterpret_cast<char *>(dns_name),
+ static_cast<size_t>(dns_len));
+ XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL);
+ if (detail::match_hostname(san_name, host_str)) {
+ wolfSSL_sk_free(san_names);
+ return true;
+ }
+ }
+ } else if (is_ip && names->type == WOLFSSL_GEN_IPADD) {
+ // IP address
+ unsigned char *ip_data = wolfSSL_ASN1_STRING_data(names->d.iPAddress);
+ int ip_len = wolfSSL_ASN1_STRING_length(names->d.iPAddress);
+ if (ip_data && ip_len == 4 && memcmp(ip_data, ip_bytes, 4) == 0) {
+ wolfSSL_sk_free(san_names);
+ return true;
+ }
+ }
+ }
+ wolfSSL_sk_free(san_names);
+ }
+
+ // Fallback: Check Common Name (CN) in subject
+ WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509);
+ if (subject) {
+ char cn[256] = {};
+ int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn,
+ sizeof(cn));
+ if (cn_len > 0) {
+ std::string cn_str(cn, static_cast<size_t>(cn_len));
+ if (detail::match_hostname(cn_str, host_str)) { return true; }
+ }
+ }
+
+ return false;
+}
+
+uint64_t hostname_mismatch_code() {
+ return static_cast<uint64_t>(DOMAIN_NAME_MISMATCH);
+}
+
+long get_verify_result(const_session_t session) {
+ if (!session) { return -1; }
+ auto wsession =
+ static_cast<impl::WolfSSLSession *>(const_cast<void *>(session));
+ long result = wolfSSL_get_verify_result(wsession->ssl);
+ return result;
+}
+
+std::string get_cert_subject_cn(cert_t cert) {
+ if (!cert) return "";
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+
+ WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509);
+ if (!subject) return "";
+
+ char cn[256] = {};
+ int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn,
+ sizeof(cn));
+ if (cn_len <= 0) return "";
+ return std::string(cn, static_cast<size_t>(cn_len));
+}
+
+std::string get_cert_issuer_name(cert_t cert) {
+ if (!cert) return "";
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+
+ WOLFSSL_X509_NAME *issuer = wolfSSL_X509_get_issuer_name(x509);
+ if (!issuer) return "";
+
+ char *name_str = wolfSSL_X509_NAME_oneline(issuer, nullptr, 0);
+ if (!name_str) return "";
+
+ std::string result(name_str);
+ XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL);
+ return result;
+}
+
+bool get_cert_sans(cert_t cert, std::vector<SanEntry> &sans) {
+ sans.clear();
+ if (!cert) return false;
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+
+ auto *san_names = static_cast<WOLF_STACK_OF(WOLFSSL_GENERAL_NAME) *>(
+ wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr));
+ if (!san_names) return true; // No SANs is not an error
+
+ int count = wolfSSL_sk_num(san_names);
+ for (int i = 0; i < count; i++) {
+ auto *name =
+ static_cast<WOLFSSL_GENERAL_NAME *>(wolfSSL_sk_value(san_names, i));
+ if (!name) continue;
+
+ SanEntry entry;
+ switch (name->type) {
+ case WOLFSSL_GEN_DNS: {
+ entry.type = SanType::DNS;
+ unsigned char *dns_name = nullptr;
+ int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, name->d.dNSName);
+ if (dns_name && dns_len > 0) {
+ entry.value = std::string(reinterpret_cast<char *>(dns_name),
+ static_cast<size_t>(dns_len));
+ XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL);
+ }
+ break;
+ }
+ case WOLFSSL_GEN_IPADD: {
+ entry.type = SanType::IP;
+ unsigned char *ip_data = wolfSSL_ASN1_STRING_data(name->d.iPAddress);
+ int ip_len = wolfSSL_ASN1_STRING_length(name->d.iPAddress);
+ if (ip_data && ip_len == 4) {
+ char buf[16];
+ snprintf(buf, sizeof(buf), "%d.%d.%d.%d", ip_data[0], ip_data[1],
+ ip_data[2], ip_data[3]);
+ entry.value = buf;
+ } else if (ip_data && ip_len == 16) {
+ char buf[64];
+ snprintf(buf, sizeof(buf),
+ "%02x%02x:%02x%02x:%02x%02x:%02x%02x:"
+ "%02x%02x:%02x%02x:%02x%02x:%02x%02x",
+ ip_data[0], ip_data[1], ip_data[2], ip_data[3], ip_data[4],
+ ip_data[5], ip_data[6], ip_data[7], ip_data[8], ip_data[9],
+ ip_data[10], ip_data[11], ip_data[12], ip_data[13],
+ ip_data[14], ip_data[15]);
+ entry.value = buf;
+ }
+ break;
+ }
+ case WOLFSSL_GEN_EMAIL:
+ entry.type = SanType::EMAIL;
+ {
+ unsigned char *email = nullptr;
+ int email_len = wolfSSL_ASN1_STRING_to_UTF8(&email, name->d.rfc822Name);
+ if (email && email_len > 0) {
+ entry.value = std::string(reinterpret_cast<char *>(email),
+ static_cast<size_t>(email_len));
+ XFREE(email, nullptr, DYNAMIC_TYPE_OPENSSL);
+ }
+ }
+ break;
+ case WOLFSSL_GEN_URI:
+ entry.type = SanType::URI;
+ {
+ unsigned char *uri = nullptr;
+ int uri_len = wolfSSL_ASN1_STRING_to_UTF8(
+ &uri, name->d.uniformResourceIdentifier);
+ if (uri && uri_len > 0) {
+ entry.value = std::string(reinterpret_cast<char *>(uri),
+ static_cast<size_t>(uri_len));
+ XFREE(uri, nullptr, DYNAMIC_TYPE_OPENSSL);
+ }
+ }
+ break;
+ default: entry.type = SanType::OTHER; break;
+ }
+
+ if (!entry.value.empty()) { sans.push_back(std::move(entry)); }
+ }
+ wolfSSL_sk_free(san_names);
+ return true;
+}
+
+bool get_cert_validity(cert_t cert, time_t ¬_before,
+ time_t ¬_after) {
+ if (!cert) return false;
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+
+ const WOLFSSL_ASN1_TIME *nb = wolfSSL_X509_get_notBefore(x509);
+ const WOLFSSL_ASN1_TIME *na = wolfSSL_X509_get_notAfter(x509);
+
+ if (!nb || !na) return false;
+
+ // wolfSSL_ASN1_TIME_to_tm is available
+ struct tm tm_nb = {}, tm_na = {};
+ if (wolfSSL_ASN1_TIME_to_tm(nb, &tm_nb) != WOLFSSL_SUCCESS) return false;
+ if (wolfSSL_ASN1_TIME_to_tm(na, &tm_na) != WOLFSSL_SUCCESS) return false;
+
+#ifdef _WIN32
+ not_before = _mkgmtime(&tm_nb);
+ not_after = _mkgmtime(&tm_na);
+#else
+ not_before = timegm(&tm_nb);
+ not_after = timegm(&tm_na);
+#endif
+ return true;
+}
+
+std::string get_cert_serial(cert_t cert) {
+ if (!cert) return "";
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+
+ WOLFSSL_ASN1_INTEGER *serial_asn1 = wolfSSL_X509_get_serialNumber(x509);
+ if (!serial_asn1) return "";
+
+ // Get the serial number data
+ int len = serial_asn1->length;
+ unsigned char *data = serial_asn1->data;
+ if (!data || len <= 0) return "";
+
+ std::string result;
+ result.reserve(static_cast<size_t>(len) * 2);
+ for (int i = 0; i < len; i++) {
+ char hex[3];
+ snprintf(hex, sizeof(hex), "%02X", data[i]);
+ result += hex;
+ }
+ return result;
+}
+
+bool get_cert_der(cert_t cert, std::vector<unsigned char> &der) {
+ if (!cert) return false;
+ auto x509 = static_cast<WOLFSSL_X509 *>(cert);
+
+ int der_len = 0;
+ const unsigned char *der_data = wolfSSL_X509_get_der(x509, &der_len);
+ if (!der_data || der_len <= 0) return false;
+
+ der.assign(der_data, der_data + der_len);
+ return true;
+}
+
+const char *get_sni(const_session_t session) {
+ if (!session) return nullptr;
+ auto wsession = static_cast<const impl::WolfSSLSession *>(session);
+
+ // For server: return SNI received from client during handshake
+ if (!wsession->sni_hostname.empty()) {
+ return wsession->sni_hostname.c_str();
+ }
+
+ // For client: return the hostname set via set_sni
+ if (!wsession->hostname.empty()) { return wsession->hostname.c_str(); }
+
+ return nullptr;
+}
+
+uint64_t peek_error() {
+ return static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+}
+
+uint64_t get_error() {
+ uint64_t err = impl::wolfssl_last_error();
+ impl::wolfssl_last_error() = 0;
+ return err;
+}
+
+std::string error_string(uint64_t code) {
+ char buf[256];
+ wolfSSL_ERR_error_string(static_cast<unsigned long>(code), buf);
+ return std::string(buf);
+}
+
+ca_store_t create_ca_store(const char *pem, size_t len) {
+ if (!pem || len == 0) { return nullptr; }
+ // Validate by attempting to load into a temporary ctx
+ WOLFSSL_CTX *tmp_ctx = wolfSSL_CTX_new(wolfTLSv1_2_client_method());
+ if (!tmp_ctx) { return nullptr; }
+ int ret = wolfSSL_CTX_load_verify_buffer(
+ tmp_ctx, reinterpret_cast<const unsigned char *>(pem),
+ static_cast<long>(len), SSL_FILETYPE_PEM);
+ wolfSSL_CTX_free(tmp_ctx);
+ if (ret != SSL_SUCCESS) { return nullptr; }
+ return static_cast<ca_store_t>(
+ new impl::WolfSSLCAStore{std::string(pem, len)});
+}
+
+void free_ca_store(ca_store_t store) {
+ delete static_cast<impl::WolfSSLCAStore *>(store);
+}
+
+bool set_ca_store(ctx_t ctx, ca_store_t store) {
+ if (!ctx || !store) { return false; }
+ auto *wctx = static_cast<impl::WolfSSLContext *>(ctx);
+ auto *ca = static_cast<impl::WolfSSLCAStore *>(store);
+ int ret = wolfSSL_CTX_load_verify_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(ca->pem_data.data()),
+ static_cast<long>(ca->pem_data.size()), SSL_FILETYPE_PEM);
+ if (ret == SSL_SUCCESS) { wctx->ca_pem_data_ += ca->pem_data; }
+ return ret == SSL_SUCCESS;
+}
+
+size_t get_ca_certs(ctx_t ctx, std::vector<cert_t> &certs) {
+ certs.clear();
+ if (!ctx) { return 0; }
+ auto *wctx = static_cast<impl::WolfSSLContext *>(ctx);
+ if (wctx->ca_pem_data_.empty()) { return 0; }
+
+ const std::string &pem = wctx->ca_pem_data_;
+ const std::string begin_marker = "-----BEGIN CERTIFICATE-----";
+ const std::string end_marker = "-----END CERTIFICATE-----";
+ size_t pos = 0;
+ while ((pos = pem.find(begin_marker, pos)) != std::string::npos) {
+ size_t end_pos = pem.find(end_marker, pos);
+ if (end_pos == std::string::npos) { break; }
+ end_pos += end_marker.size();
+ std::string cert_pem = pem.substr(pos, end_pos - pos);
+ WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer(
+ reinterpret_cast<const unsigned char *>(cert_pem.data()),
+ static_cast<int>(cert_pem.size()), WOLFSSL_FILETYPE_PEM);
+ if (x509) { certs.push_back(static_cast<cert_t>(x509)); }
+ pos = end_pos;
+ }
+ return certs.size();
+}
+
+std::vector<std::string> get_ca_names(ctx_t ctx) {
+ std::vector<std::string> names;
+ if (!ctx) { return names; }
+ auto *wctx = static_cast<impl::WolfSSLContext *>(ctx);
+ if (wctx->ca_pem_data_.empty()) { return names; }
+
+ const std::string &pem = wctx->ca_pem_data_;
+ const std::string begin_marker = "-----BEGIN CERTIFICATE-----";
+ const std::string end_marker = "-----END CERTIFICATE-----";
+ size_t pos = 0;
+ while ((pos = pem.find(begin_marker, pos)) != std::string::npos) {
+ size_t end_pos = pem.find(end_marker, pos);
+ if (end_pos == std::string::npos) { break; }
+ end_pos += end_marker.size();
+ std::string cert_pem = pem.substr(pos, end_pos - pos);
+ WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer(
+ reinterpret_cast<const unsigned char *>(cert_pem.data()),
+ static_cast<int>(cert_pem.size()), WOLFSSL_FILETYPE_PEM);
+ if (x509) {
+ WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509);
+ if (subject) {
+ char *name_str = wolfSSL_X509_NAME_oneline(subject, nullptr, 0);
+ if (name_str) {
+ names.push_back(name_str);
+ XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL);
+ }
+ }
+ wolfSSL_X509_free(x509);
+ }
+ pos = end_pos;
+ }
+ return names;
+}
+
+bool update_server_cert(ctx_t ctx, const char *cert_pem,
+ const char *key_pem, const char *password) {
+ if (!ctx || !cert_pem || !key_pem) { return false; }
+ auto *wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ // Load new certificate
+ int ret = wolfSSL_CTX_use_certificate_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(cert_pem),
+ static_cast<long>(strlen(cert_pem)), SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ // Set password if provided
+ if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); }
+
+ // Load new private key
+ ret = wolfSSL_CTX_use_PrivateKey_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(key_pem),
+ static_cast<long>(strlen(key_pem)), SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+
+ return true;
+}
+
+bool update_server_client_ca(ctx_t ctx, const char *ca_pem) {
+ if (!ctx || !ca_pem) { return false; }
+ auto *wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ int ret = wolfSSL_CTX_load_verify_buffer(
+ wctx->ctx, reinterpret_cast<const unsigned char *>(ca_pem),
+ static_cast<long>(strlen(ca_pem)), SSL_FILETYPE_PEM);
+ if (ret != SSL_SUCCESS) {
+ impl::wolfssl_last_error() =
+ static_cast<uint64_t>(wolfSSL_ERR_peek_last_error());
+ return false;
+ }
+ return true;
+}
+
+bool set_verify_callback(ctx_t ctx, VerifyCallback callback) {
+ if (!ctx) { return false; }
+ auto *wctx = static_cast<impl::WolfSSLContext *>(ctx);
+
+ impl::get_verify_callback() = std::move(callback);
+ wctx->has_verify_callback = static_cast<bool>(impl::get_verify_callback());
+
+ if (wctx->has_verify_callback) {
+ wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER,
+ impl::wolfssl_verify_callback);
+ } else {
+ wolfSSL_CTX_set_verify(
+ wctx->ctx,
+ wctx->verify_client
+ ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT)
+ : SSL_VERIFY_NONE,
+ nullptr);
+ }
+ return true;
+}
+
+long get_verify_error(const_session_t session) {
+ if (!session) { return -1; }
+ auto *wsession =
+ static_cast<impl::WolfSSLSession *>(const_cast<void *>(session));
+ return wolfSSL_get_verify_result(wsession->ssl);
+}
+
+std::string verify_error_string(long error_code) {
+ if (error_code == 0) { return ""; }
+ const char *str =
+ wolfSSL_X509_verify_cert_error_string(static_cast<int>(error_code));
+ return str ? std::string(str) : std::string();
+}
+
+} // namespace tls
+
+#endif // CPPHTTPLIB_WOLFSSL_SUPPORT
+
+// WebSocket implementation
+namespace ws {
+
+bool WebSocket::send_frame(Opcode op, const char *data, size_t len,
+ bool fin) {
+ std::lock_guard<std::mutex> lock(write_mutex_);
+ if (closed_) { return false; }
+ return detail::write_websocket_frame(strm_, op, data, len, fin, !is_server_);
+}
+
+ReadResult WebSocket::read(std::string &msg) {
+ while (!closed_) {
+ Opcode opcode;
+ std::string payload;
+ bool fin;
+
+ if (!impl::read_websocket_frame(strm_, opcode, payload, fin, is_server_,
+ CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) {
+ closed_ = true;
+ return Fail;
+ }
+
+ switch (opcode) {
+ case Opcode::Ping: {
+ std::lock_guard<std::mutex> lock(write_mutex_);
+ detail::write_websocket_frame(strm_, Opcode::Pong, payload.data(),
+ payload.size(), true, !is_server_);
+ continue;
+ }
+ case Opcode::Pong: continue;
+ case Opcode::Close: {
+ if (!closed_.exchange(true)) {
+ // Echo close frame back
+ std::lock_guard<std::mutex> lock(write_mutex_);
+ detail::write_websocket_frame(strm_, Opcode::Close, payload.data(),
+ payload.size(), true, !is_server_);
+ }
+ return Fail;
+ }
+ case Opcode::Text:
+ case Opcode::Binary: {
+ auto result = opcode == Opcode::Text ? Text : Binary;
+ msg = std::move(payload);
+
+ // Handle fragmentation
+ if (!fin) {
+ while (true) {
+ Opcode cont_opcode;
+ std::string cont_payload;
+ bool cont_fin;
+ if (!impl::read_websocket_frame(
+ strm_, cont_opcode, cont_payload, cont_fin, is_server_,
+ CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) {
+ closed_ = true;
+ return Fail;
+ }
+ if (cont_opcode == Opcode::Ping) {
+ std::lock_guard<std::mutex> lock(write_mutex_);
+ detail::write_websocket_frame(
+ strm_, Opcode::Pong, cont_payload.data(), cont_payload.size(),
+ true, !is_server_);
+ continue;
+ }
+ if (cont_opcode == Opcode::Pong) { continue; }
+ if (cont_opcode == Opcode::Close) {
+ if (!closed_.exchange(true)) {
+ std::lock_guard<std::mutex> lock(write_mutex_);
+ detail::write_websocket_frame(
+ strm_, Opcode::Close, cont_payload.data(),
+ cont_payload.size(), true, !is_server_);
+ }
+ return Fail;
+ }
+ // RFC 6455: continuation frames must use opcode 0x0
+ if (cont_opcode != Opcode::Continuation) {
+ closed_ = true;
+ return Fail;
+ }
+ msg += cont_payload;
+ if (msg.size() > CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH) {
+ closed_ = true;
+ return Fail;
+ }
+ if (cont_fin) { break; }
+ }
+ }
+ // RFC 6455 Section 5.6: text frames must contain valid UTF-8
+ if (result == Text && !impl::is_valid_utf8(msg)) {
+ close(CloseStatus::InvalidPayload, "invalid UTF-8");
+ return Fail;
+ }
+ return result;
+ }
+ default: closed_ = true; return Fail;
+ }
+ }
+ return Fail;
+}
+
+bool WebSocket::send(const std::string &data) {
+ return send_frame(Opcode::Text, data.data(), data.size());
+}
+
+bool WebSocket::send(const char *data, size_t len) {
+ return send_frame(Opcode::Binary, data, len);
+}
+
+void WebSocket::close(CloseStatus status, const std::string &reason) {
+ if (closed_.exchange(true)) { return; }
+ ping_cv_.notify_all();
+ std::string payload;
+ auto code = static_cast<uint16_t>(status);
+ payload.push_back(static_cast<char>((code >> 8) & 0xFF));
+ payload.push_back(static_cast<char>(code & 0xFF));
+ // RFC 6455 Section 5.5: control frame payload must not exceed 125 bytes
+ // Close frame has 2-byte status code, so reason is limited to 123 bytes
+ payload += reason.substr(0, 123);
+ {
+ std::lock_guard<std::mutex> lock(write_mutex_);
+ detail::write_websocket_frame(strm_, Opcode::Close, payload.data(),
+ payload.size(), true, !is_server_);
+ }
+
+ // RFC 6455 Section 7.1.1: after sending a Close frame, wait for the peer's
+ // Close response before closing the TCP connection. Use a short timeout to
+ // avoid hanging if the peer doesn't respond.
+ strm_.set_read_timeout(CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND, 0);
+ Opcode op;
+ std::string resp;
+ bool fin;
+ while (impl::read_websocket_frame(strm_, op, resp, fin, is_server_, 125)) {
+ if (op == Opcode::Close) { break; }
+ }
+}
+
+WebSocket::~WebSocket() {
+ {
+ std::lock_guard<std::mutex> lock(ping_mutex_);
+ closed_ = true;
+ }
+ ping_cv_.notify_all();
+ if (ping_thread_.joinable()) { ping_thread_.join(); }
+}
+
+void WebSocket::start_heartbeat() {
+ ping_thread_ = std::thread([this]() {
+ std::unique_lock<std::mutex> lock(ping_mutex_);
+ while (!closed_) {
+ ping_cv_.wait_for(lock, std::chrono::seconds(
+ CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND));
+ if (closed_) { break; }
+ lock.unlock();
+ if (!send_frame(Opcode::Ping, nullptr, 0)) {
+ closed_ = true;
+ break;
+ }
+ lock.lock();
+ }
+ });
+}
+
+const Request &WebSocket::request() const { return req_; }
+
+bool WebSocket::is_open() const { return !closed_; }
+
+// WebSocketClient implementation
+WebSocketClient::WebSocketClient(
+ const std::string &scheme_host_port_path, const Headers &headers)
+ : headers_(headers) {
+ const static std::regex re(
+ R"(([a-z]+):\/\/(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?(\/.*))");
+
+ std::smatch m;
+ if (std::regex_match(scheme_host_port_path, m, re)) {
+ auto scheme = m[1].str();
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+ if (scheme != "ws" && scheme != "wss") {
+#else
+ if (scheme != "ws") {
+#endif
+#ifndef CPPHTTPLIB_NO_EXCEPTIONS
+ std::string msg = "'" + scheme + "' scheme is not supported.";
+ throw std::invalid_argument(msg);
+#endif
+ return;
+ }
+
+ auto is_ssl = scheme == "wss";
+
+ host_ = m[2].str();
+ if (host_.empty()) { host_ = m[3].str(); }
+
+ auto port_str = m[4].str();
+ port_ = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80);
+
+ path_ = m[5].str();
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+ is_ssl_ = is_ssl;
+#else
+ if (is_ssl) { return; }
+#endif
+
+ is_valid_ = true;
+ }
+}
+
+WebSocketClient::~WebSocketClient() { shutdown_and_close(); }
+
+bool WebSocketClient::is_valid() const { return is_valid_; }
+
+void WebSocketClient::shutdown_and_close() {
+#ifdef CPPHTTPLIB_SSL_ENABLED
+ if (is_ssl_) {
+ if (tls_session_) {
+ tls::shutdown(tls_session_, true);
+ tls::free_session(tls_session_);
+ tls_session_ = nullptr;
+ }
+ if (tls_ctx_) {
+ tls::free_context(tls_ctx_);
+ tls_ctx_ = nullptr;
+ }
+ }
+#endif
+ if (ws_ && ws_->is_open()) { ws_->close(); }
+ ws_.reset();
+ if (sock_ != INVALID_SOCKET) {
+ detail::shutdown_socket(sock_);
+ detail::close_socket(sock_);
+ sock_ = INVALID_SOCKET;
+ }
+}
+
+bool WebSocketClient::create_stream(std::unique_ptr<Stream> &strm) {
+#ifdef CPPHTTPLIB_SSL_ENABLED
+ if (is_ssl_) {
+ if (!detail::setup_client_tls_session(
+ host_, tls_ctx_, tls_session_, sock_,
+ server_certificate_verification_, ca_cert_file_path_,
+ ca_cert_store_, read_timeout_sec_, read_timeout_usec_)) {
+ return false;
+ }
+
+ strm = std::unique_ptr<Stream>(new detail::SSLSocketStream(
+ sock_, tls_session_, read_timeout_sec_, read_timeout_usec_,
+ write_timeout_sec_, write_timeout_usec_));
+ return true;
+ }
+#endif
+ strm = std::unique_ptr<Stream>(
+ new detail::SocketStream(sock_, read_timeout_sec_, read_timeout_usec_,
+ write_timeout_sec_, write_timeout_usec_));
+ return true;
+}
+
+bool WebSocketClient::connect() {
+ if (!is_valid_) { return false; }
+ shutdown_and_close();
+
+ Error error;
+ sock_ = detail::create_client_socket(
+ host_, std::string(), port_, AF_UNSPEC, false, false, nullptr, 5, 0,
+ read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
+ write_timeout_usec_, std::string(), error);
+
+ if (sock_ == INVALID_SOCKET) { return false; }
+
+ std::unique_ptr<Stream> strm;
+ if (!create_stream(strm)) {
+ shutdown_and_close();
+ return false;
+ }
+
+ std::string selected_subprotocol;
+ if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_,
+ selected_subprotocol)) {
+ shutdown_and_close();
+ return false;
+ }
+ subprotocol_ = std::move(selected_subprotocol);
+
+ Request req;
+ req.method = "GET";
+ req.path = path_;
+ ws_ = std::unique_ptr<WebSocket>(new WebSocket(std::move(strm), req, false));
+ return true;
+}
+
+ReadResult WebSocketClient::read(std::string &msg) {
+ if (!ws_) { return Fail; }
+ return ws_->read(msg);
+}
+
+bool WebSocketClient::send(const std::string &data) {
+ if (!ws_) { return false; }
+ return ws_->send(data);
+}
+
+bool WebSocketClient::send(const char *data, size_t len) {
+ if (!ws_) { return false; }
+ return ws_->send(data, len);
+}
+
+void WebSocketClient::close(CloseStatus status,
+ const std::string &reason) {
+ if (ws_) { ws_->close(status, reason); }
+}
+
+bool WebSocketClient::is_open() const { return ws_ && ws_->is_open(); }
+
+const std::string &WebSocketClient::subprotocol() const {
+ return subprotocol_;
+}
+
+void WebSocketClient::set_read_timeout(time_t sec, time_t usec) {
+ read_timeout_sec_ = sec;
+ read_timeout_usec_ = usec;
+}
+
+void WebSocketClient::set_write_timeout(time_t sec, time_t usec) {
+ write_timeout_sec_ = sec;
+ write_timeout_usec_ = usec;
+}
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+
+void WebSocketClient::set_ca_cert_path(const std::string &path) {
+ ca_cert_file_path_ = path;
+}
+
+void WebSocketClient::set_ca_cert_store(tls::ca_store_t store) {
+ ca_cert_store_ = store;
+}
+
+void
+WebSocketClient::enable_server_certificate_verification(bool enabled) {
+ server_certificate_verification_ = enabled;
+}
+
+#endif // CPPHTTPLIB_SSL_ENABLED
+
+} // namespace ws
+
} // namespace httplib