]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
rpc : add backend registry / device interfaces (llama/9812)
authorDiego Devesa <redacted>
Thu, 10 Oct 2024 18:14:55 +0000 (20:14 +0200)
committerGeorgi Gerganov <redacted>
Wed, 16 Oct 2024 08:28:39 +0000 (11:28 +0300)
* rpc : add backend registry / device interfaces

* llama : add llama_supports_rpc API

* ggml_backend_rpc_start_rpc_server -> ggml_backend_rpc_start_server

include/ggml-rpc.h
src/ggml-backend.cpp
src/ggml-rpc.cpp

index 64cde7f13d391795051dea9612101417399448f0..d5796736821f4badd3e573e414b89935afbde06f 100644 (file)
@@ -17,7 +17,11 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
 
 GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
 
-GGML_API void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+GGML_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+GGML_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
+
+GGML_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
 
 #ifdef  __cplusplus
 }
index 627b4dbc7873213b2c2a923d6572ded397afdc6a..fb1d3ead3be69a27905552493ed80a279d17dae3 100644 (file)
@@ -542,6 +542,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-blas.h"
 #endif
 
+#ifdef GGML_USE_RPC
+#include "ggml-rpc.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;
@@ -556,6 +560,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_BLAS
         register_backend(ggml_backend_blas_reg());
 #endif
+#ifdef GGML_USE_RPC
+        register_backend(ggml_backend_rpc_reg());
+#endif
 
         // TODO: sycl, vulkan, kompute, cann
 
index ab7298cbae0e691eb9c6ff01f6eac2eec4c17048..13c7dd4364c33130251e98040644fd7b0c28d960 100644 (file)
@@ -25,7 +25,7 @@
 #  include <netdb.h>
 #  include <unistd.h>
 #endif
-#include <string.h>
+#include <cstring>
 
 #define UNUSED GGML_UNUSED
 
@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     return (enum ggml_status)output[0];
 }
 
-static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    UNUSED(backend);
-    UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
-}
-
-static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->endpoint == rpc_ctx->endpoint;
-}
-
 static ggml_backend_i ggml_backend_rpc_interface = {
     /* .get_name                = */ ggml_backend_rpc_name,
     /* .free                    = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
-    /* .supports_op             = */ ggml_backend_rpc_supports_op,
-    /* .supports_buft           = */ ggml_backend_rpc_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ nullptr,
+        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
         /* .context = */ buft_ctx
     };
     buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
         /* .context   = */ ctx
     };
     return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
     }
 }
 
-void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
     WSACleanup();
 #endif
 }
+
+// device interface
+
+struct ggml_backend_rpc_device_context {
+    std::string endpoint;
+    std::string name;
+};
+
+static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
+
+    UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
+    // TODO: obtain value from the server
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+    UNUSED(dev);
+}
+
+static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_rpc_device_get_name(dev);
+    props->description = ggml_backend_rpc_device_get_description(dev);
+    props->type        = ggml_backend_rpc_device_get_type(dev);
+    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_init(ctx->endpoint.c_str());
+
+    UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+
+    UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    UNUSED(dev);
+    UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    UNUSED(dev);
+    UNUSED(op);
+    //TODO: call the remote backend and cache the results
+    return true;
+}
+
+static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+        return false;
+    }
+    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
+    return buft_ctx->endpoint == dev_ctx->endpoint;
+}
+
+static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
+    /* .get_name             = */ ggml_backend_rpc_device_get_name,
+    /* .get_description      = */ ggml_backend_rpc_device_get_description,
+    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
+    /* .get_type             = */ ggml_backend_rpc_device_get_type,
+    /* .get_props            = */ ggml_backend_rpc_device_get_props,
+    /* .init_backend         = */ ggml_backend_rpc_device_init,
+    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    return "RPC";
+
+    UNUSED(reg);
+}
+
+static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 0;
+
+    UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
+
+    UNUSED(reg);
+    UNUSED(index);
+}
+
+static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
+        return (void *)ggml_backend_rpc_add_device;
+    }
+    return NULL;
+
+    UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_reg(void) {
+    static struct ggml_backend_reg ggml_backend_rpc_reg = {
+        /* .iface   = */ ggml_backend_rpc_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_rpc_reg;
+}
+
+ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (dev_map.find(endpoint) != dev_map.end()) {
+        return dev_map[endpoint];
+    }
+
+    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
+        /* .endpoint = */ endpoint,
+        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
+    };
+
+    ggml_backend_dev_t dev = new ggml_backend_device {
+        /* .iface   = */ ggml_backend_rpc_device_i,
+        /* .reg     = */ ggml_backend_rpc_reg(),
+        /* .context = */ ctx,
+    };
+
+    dev_map[endpoint] = dev;
+
+    return dev;
+}