#include <string>
#include <vector>
#include <memory>
+#include <mutex>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
sockfd_t fd;
socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() {
+ GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
closesocket(this->fd);
#else
}
struct ggml_backend_rpc_buffer_type_context {
- std::shared_ptr<socket_t> sock;
+ std::string endpoint;
std::string name;
size_t alignment;
size_t max_size;
// Per-backend state for an RPC backend instance.
// After this change the context holds only the endpoint string and a display name.
// The removed `sock` and `buft` members are no longer stored here: code in this file
// obtains the socket with get_socket(endpoint) and the buffer type with
// ggml_backend_rpc_buffer_type(endpoint) when they are needed.
struct ggml_backend_rpc_context {
std::string endpoint;
std::string name;
- std::shared_ptr<socket_t> sock;
- ggml_backend_buffer_type_t buft;
};
struct ggml_backend_rpc_buffer_context {
return true;
}
// Splits an endpoint string at its first ':' into `host` (the text before the colon)
// and `port` (the text after the colon, converted with std::stoi).
// Returns false, leaving `host` and `port` untouched, when the string has no ':';
// returns true otherwise.
// This change switches the parameter from `const char *` to `const std::string &`,
// which removes the local std::string copy the old version made.
// NOTE(review): std::stoi throws std::invalid_argument or std::out_of_range when the
// port text is not a valid int (e.g. "host:" or "host:abc"). Nothing in this function
// catches that, so such input surfaces as an exception instead of a `false` return.
// NOTE(review): because the split is at the first ':', an endpoint containing more than
// one colon (such as a literal IPv6 address) is not handled specially — confirm whether
// that needs support.
-static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
- std::string str(endpoint);
- size_t pos = str.find(':');
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+ size_t pos = endpoint.find(':');
if (pos == std::string::npos) {
return false;
}
- host = str.substr(0, pos);
- port = std::stoi(str.substr(pos + 1));
+ host = endpoint.substr(0, pos);
+ port = std::stoi(endpoint.substr(pos + 1));
return true;
}
// RPC client-side implementation
// Returns a shared connection to `endpoint`, reusing a live one when available.
// A function-local map keeps one std::weak_ptr per endpoint string, so this cache does
// not itself keep a socket alive: once every shared_ptr holder has released it, the
// entry expires and the next call reconnects and overwrites that entry.
// The whole body runs under a function-local static mutex, including the call to
// socket_connect, so concurrent callers are serialized for the duration of a connect.
// Returns nullptr when the endpoint cannot be parsed, when WSAStartup fails (Windows
// builds only), or when socket_connect returns nullptr.
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+ static bool initialized = false;
+
// Reuse path: hand back the cached socket if its weak_ptr can still be locked.
+ auto it = sockets.find(endpoint);
+ if (it != sockets.end()) {
+ if (auto sock = it->second.lock()) {
+ return sock;
+ }
+ }
+ std::string host;
+ int port;
+ if (!parse_endpoint(endpoint, host, port)) {
+ return nullptr;
+ }
+#ifdef _WIN32
// Windows only: WSAStartup is attempted on each call until it succeeds once;
// `initialized` records that success so later calls skip it.
+ if (!initialized) {
+ WSADATA wsaData;
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0) {
+ return nullptr;
+ }
+ initialized = true;
+ }
+#else
// `initialized` is only read in the Windows branch; UNUSED is defined elsewhere and
// presumably suppresses the unused-variable warning here.
+ UNUSED(initialized);
+#endif
+ auto sock = socket_connect(host.c_str(), port);
+ if (sock == nullptr) {
+ return nullptr;
+ }
+ GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
// The map stores a weak_ptr, so this assignment records the connection without
// owning it; it also replaces any expired entry for the same endpoint.
+ sockets[endpoint] = sock;
+ return sock;
+}
+
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
return ctx->name.c_str();
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
+ auto sock = get_socket(buft_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
if (remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
- new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
remote_size);
return buffer;
} else {
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- return buft_ctx->sock == rpc_ctx->sock;
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
}
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
/* .is_host = */ NULL,
};
-
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
// Destroys an RPC backend: deletes its ggml_backend_rpc_context and then the backend
// object itself.
// The removed lines used to delete the buffer type and its context as well. After this
// change the backend context has no `buft` member, so no buffer type is freed here.
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
- delete buft_ctx;
- delete rpc_ctx->buft;
delete rpc_ctx;
delete backend;
}
// Returns the buffer type for this backend's endpoint by delegating to
// ggml_backend_rpc_buffer_type() with the endpoint stored in the backend context.
// Before this change the value was read from the (now removed) `ctx->buft` member.
// NOTE: ggml_backend_rpc_buffer_type() returns nullptr when it cannot obtain a socket
// for an endpoint it has not cached yet, so this function can return nullptr as well.
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
- return ctx->buft;
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
}
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
std::vector<uint8_t> input;
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
+ auto sock = get_socket(rpc_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
/* .event_synchronize = */ NULL,
};
-static std::unordered_map<std::string, ggml_backend_t> instances;
-
// Returns the buffer type associated with `endpoint`, creating it on first use.
// Buffer types are cached in a function-local map keyed by the endpoint string and
// guarded by a function-local static mutex; a cached entry is returned without calling
// get_socket().
// On a cache miss a socket is requested from get_socket(). If that yields nullptr this
// function returns nullptr and caches nothing. Otherwise get_alignment(sock) and
// get_max_size(sock) are called, and their results are stored in a new
// ggml_backend_rpc_buffer_type_context together with the endpoint and the name
// "RPC[<endpoint>]".
// The context stores the endpoint rather than the socket, so the local `sock` is the only
// reference this function takes and it is released on return.
// The removed lines are the previous body of this function (which delegated to
// ggml_backend_rpc_init) and the opening part of the old ggml_backend_rpc_init
// (lookup in the global `instances` map, WSAStartup, endpoint parsing, connect).
// NOTE(review): `endpoint` is converted to std::string for the map lookup and the
// context, so it must be a non-null, NUL-terminated C string.
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
-}
-
-GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
- std::string endpoint_str(endpoint);
- if (instances.find(endpoint_str) != instances.end()) {
- return instances[endpoint_str];
- }
-#ifdef _WIN32
- {
- WSADATA wsaData;
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
- if (res != 0) {
- return nullptr;
- }
- }
-#endif
- fprintf(stderr, "Connecting to %s\n", endpoint);
- std::string host;
- int port;
- if (!parse_endpoint(endpoint, host, port)) {
- return nullptr;
- }
- auto sock = socket_connect(host.c_str(), port);
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ // NOTE: buffer types are allocated and never freed; this is by design
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+ auto it = buft_map.find(endpoint);
+ if (it != buft_map.end()) {
+ return it->second;
+ }
+ auto sock = get_socket(endpoint);
if (sock == nullptr) {
return nullptr;
}
size_t alignment = get_alignment(sock);
size_t max_size = get_max_size(sock);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
- /* .sock = */ sock,
- /* .name = */ "RPC" + std::to_string(sock->fd),
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
/* .alignment = */ alignment,
- /* .max_size = */ max_size
+ /* .max_size = */ max_size
};
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
/* .context = */ buft_ctx
};
// Cache the new buffer type so later calls for the same endpoint return this object.
+ buft_map[endpoint] = buft;
+ return buft;
+}
// Creates a new RPC backend object for `endpoint`.
// After this change every call allocates a fresh context and a fresh backend: the
// global `instances` cache is gone, and the context stores only the endpoint string and
// the fixed name "RPC".
// No connection is attempted here, so an unreachable or malformed endpoint is not
// detected by this function; code in this file that needs a socket calls
// get_socket(endpoint) later.
// ggml_backend_rpc_free() deletes both the returned backend and its context.
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC" + std::to_string(sock->fd),
- /* .sock = */ sock,
- /* .buft = */ buft
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC",
};
- instances[endpoint] = new ggml_backend {
+ ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .interface = */ ggml_backend_rpc_interface,
/* .context = */ ctx
};
-
- return instances[endpoint];
+ return backend;
}
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
}
// Fills `*free` and `*total` for the given `endpoint`.
// A socket is requested from get_socket(); when that returns nullptr both outputs are
// set to 0 and the function returns. Otherwise the call is forwarded to
// get_device_memory(sock, free, total).
// Before this change the socket came from a backend created by ggml_backend_rpc_init();
// the query no longer creates a backend object.
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
- if (backend == nullptr) {
+ auto sock = get_socket(endpoint);
+ if (sock == nullptr) {
*free = 0;
*total = 0;
return;
}
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
- get_device_memory(ctx->sock, free, total);
+ get_device_memory(sock, free, total);
}
// RPC server-side implementation