]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
rpc : resource management rework (llama/7562)
authorRadoslav Gerganov <redacted>
Tue, 28 May 2024 15:13:36 +0000 (18:13 +0300)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
* rpc : resource management rework

* address review comments

src/ggml-rpc.cpp

index cc1d3ace1ddac012b87d2ac31aac81111f4d15b6..49a20df4bd85e9511d7176256a7b7a641c07676a 100644 (file)
@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <mutex>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
     sockfd_t fd;
     socket_t(sockfd_t fd) : fd(fd) {}
     ~socket_t() {
+        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
 #ifdef _WIN32
         closesocket(this->fd);
 #else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
 }
 
 struct ggml_backend_rpc_buffer_type_context {
-    std::shared_ptr<socket_t> sock;
+    std::string endpoint;
     std::string name;
     size_t alignment;
     size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
 struct ggml_backend_rpc_context {
     std::string endpoint;
     std::string name;
-    std::shared_ptr<socket_t> sock;
-    ggml_backend_buffer_type_t buft;
 };
 
 struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }
 
-static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
-    std::string str(endpoint);
-    size_t pos = str.find(':');
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+    size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
         return false;
     }
-    host = str.substr(0, pos);
-    port = std::stoi(str.substr(pos + 1));
+    host = endpoint.substr(0, pos);
+    port = std::stoi(endpoint.substr(pos + 1));
     return true;
 }
 
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+    static bool initialized = false;
+
+    auto it = sockets.find(endpoint);
+    if (it != sockets.end()) {
+        if (auto sock = it->second.lock()) {
+            return sock;
+        }
+    }
+    std::string host;
+    int port;
+    if (!parse_endpoint(endpoint, host, port)) {
+        return nullptr;
+    }
+#ifdef _WIN32
+    if (!initialized) {
+        WSADATA wsaData;
+        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+        if (res != 0) {
+            return nullptr;
+        }
+        initialized = true;
+    }
+#else
+    UNUSED(initialized);
+#endif
+    auto sock = socket_connect(host.c_str(), port);
+    if (sock == nullptr) {
+        return nullptr;
+    }
+    GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+    sockets[endpoint] = sock;
+    return sock;
+}
+
 GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     return ctx->name.c_str();
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     std::vector<uint8_t> input(input_size, 0);
     memcpy(input.data(), &size, sizeof(size));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
+    auto sock = get_socket(buft_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
             remote_size);
         return buffer;
     } else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->sock == rpc_ctx->sock;
+    return buft_ctx->endpoint == rpc_ctx->endpoint;
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
     /* .is_host          = */ NULL,
 };
 
-
 GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
 
 GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
-    delete buft_ctx;
-    delete rpc_ctx->buft;
     delete rpc_ctx;
     delete backend;
 }
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ctx->buft;
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
 }
 
 GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
+    auto sock = get_socket(rpc_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 1);
     return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .event_synchronize       = */ NULL,
 };
 
-static std::unordered_map<std::string, ggml_backend_t> instances;
-
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
-    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
-    return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
-}
-
-GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
-    std::string endpoint_str(endpoint);
-    if (instances.find(endpoint_str) != instances.end()) {
-        return instances[endpoint_str];
-    }
-#ifdef _WIN32
-    {
-        WSADATA wsaData;
-        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (res != 0) {
-            return nullptr;
-        }
-    }
-#endif
-    fprintf(stderr, "Connecting to %s\n", endpoint);
-    std::string host;
-    int port;
-    if (!parse_endpoint(endpoint, host, port)) {
-        return nullptr;
-    }
-    auto sock = socket_connect(host.c_str(), port);
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    // NOTE: buffer types are allocated and never freed; this is by design
+    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+    auto it = buft_map.find(endpoint);
+    if (it != buft_map.end()) {
+        return it->second;
+    }
+    auto sock = get_socket(endpoint);
     if (sock == nullptr) {
         return nullptr;
     }
     size_t alignment = get_alignment(sock);
     size_t max_size = get_max_size(sock);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
-        /* .sock   = */ sock,
-        /* .name   = */ "RPC" + std::to_string(sock->fd),
+        /* .endpoint  = */ endpoint,
+        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
         /* .alignment = */ alignment,
-        /* .max_size = */ max_size
+        /* .max_size  = */ max_size
     };
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
         /* .context = */ buft_ctx
     };
+    buft_map[endpoint] = buft;
+    return buft;
+}
 
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC" + std::to_string(sock->fd),
-        /* .sock     = */ sock,
-        /* .buft     = */ buft
+        /* .endpoint  = */ endpoint,
+        /* .name      = */ "RPC",
     };
 
-    instances[endpoint] = new ggml_backend {
+    ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
         /* .context   = */ ctx
     };
-
-    return instances[endpoint];
+    return backend;
 }
 
 GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
 }
 
 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
-    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
-    if (backend == nullptr) {
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
         *free = 0;
         *total = 0;
         return;
     }
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    get_device_memory(ctx->sock, free, total);
+    get_device_memory(sock, free, total);
 }
 
 // RPC server-side implementation