rpc : early register backend devices (#11262)
author Radoslav Gerganov <redacted>
Fri, 17 Jan 2025 08:57:09 +0000 (10:57 +0200)
committer GitHub <redacted>
Fri, 17 Jan 2025 08:57:09 +0000 (10:57 +0200)
Register RPC devices early and do not propagate RPC specifics into
the llama model structures.

ref: #10609
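
With this change, RPC endpoints become ordinary backend devices: they are
created and registered with the global backend registry up front (at
argument-parsing time in the tools), and the model only sees the generic,
NULL-terminated `llama_model_params::devices` list rather than a dedicated
`rpc_servers` string. A minimal sketch of the new flow, assuming a ggml build
with the RPC backend and using a placeholder endpoint:

    // Minimal sketch (not the committed code): register one RPC device up
    // front, then hand it to the model like any other backend device.
    // Assumes ggml was built with the RPC backend; the endpoint is a placeholder.
    #include "ggml-backend.h"
    #include "llama.h"
    #include <vector>

    // must outlive model loading: llama_model_params stores a raw pointer
    static std::vector<ggml_backend_dev_t> devices;

    static llama_model * load_with_rpc(const char * model_path) {
        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
        if (!rpc_reg) {
            return nullptr; // RPC backend not compiled in
        }
        typedef ggml_backend_dev_t (*add_device_t)(const char * endpoint);
        add_device_t add_device = (add_device_t)
            ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
        if (!add_device) {
            return nullptr;
        }
        ggml_backend_dev_t dev = add_device("192.168.88.10:50052"); // placeholder
        if (!dev) {
            return nullptr;
        }
        ggml_backend_device_register(dev); // now visible to generic enumeration

        devices.push_back(dev);
        devices.push_back(nullptr); // the device list is NULL-terminated

        llama_model_params mparams = llama_model_default_params();
        mparams.devices = devices.data(); // no rpc_servers field anymore
        return llama_model_load_from_file(model_path, mparams);
    }

Once registered, an RPC device is enumerated, named, and selected exactly like
a local GPU, which is what lets the RPC-specific state be deleted from
`common_params`, `llama_model_params`, and `llama_model` in the hunks below.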

common/arg.cpp
common/common.cpp
common/common.h
examples/llama-bench/llama-bench.cpp
ggml/include/ggml-backend.h
ggml/src/ggml-backend-impl.h
include/llama.h
src/llama-model.cpp
src/llama-model.h
src/llama.cpp

diff --git a/common/arg.cpp b/common/arg.cpp
index dd10b635259c6ed0fdc6639209d096f51f151640..9069950eb093936f5019613482e851faeba1adbc 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -376,6 +376,30 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     return devices;
 }
 
+static void add_rpc_devices(std::string servers) {
+    auto rpc_servers = string_split<std::string>(servers, ',');
+    if (rpc_servers.empty()) {
+        throw std::invalid_argument("no RPC servers specified");
+    }
+    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+    if (!rpc_reg) {
+        throw std::invalid_argument("failed to find RPC backend");
+    }
+    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+    if (!ggml_backend_rpc_add_device_fn) {
+        throw std::invalid_argument("failed to find RPC device add function");
+    }
+    for (const auto & server : rpc_servers) {
+        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+        if (dev) {
+            ggml_backend_device_register(dev);
+        } else {
+            throw std::invalid_argument("failed to register RPC device");
+        }
+    }
+}
+
 bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     auto ctx_arg = common_params_parser_init(params, ex, print_usage);
     const common_params params_org = ctx_arg.params; // the example can modify the default params
@@ -1385,7 +1409,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             {"--rpc"}, "SERVERS",
             "comma separated list of RPC servers",
             [](common_params & params, const std::string & value) {
-                params.rpc_servers = value;
+                add_rpc_devices(value);
+                GGML_UNUSED(params);
             }
         ).set_env("LLAMA_ARG_RPC"));
     }
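
With the hunk above, `--rpc` takes effect during argument parsing itself, e.g.
`llama-cli -m model.gguf --rpc 192.168.88.10:50052,192.168.88.11:50052`
(placeholder endpoints): the devices are registered immediately instead of
being stashed in `common_params` for later, and `GGML_UNUSED(params)` merely
silences the now-unused lambda parameter.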
diff --git a/common/common.cpp b/common/common.cpp
index a6f9252b27a9fb37388535d5d9a29de4e0af5699..451826d5d683e6eb0c50035918b70cf5286dc638 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1043,7 +1043,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
-    mparams.rpc_servers     = params.rpc_servers.c_str();
     mparams.main_gpu        = params.main_gpu;
     mparams.split_mode      = params.split_mode;
     mparams.tensor_split    = params.tensor_split;
diff --git a/common/common.h b/common/common.h
index 4fab1319a7c82e02a391643091689681ff249e30..691141d6b6b2cc31d9b9646e9e6aa20f09f63002 100644
--- a/common/common.h
+++ b/common/common.h
@@ -246,7 +246,6 @@ struct common_params {
     std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding           // NOLINT
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
     std::string logits_file          = ""; // file for saving *all* logits                                  // NOLINT
-    std::string rpc_servers          = ""; // comma separated list of RPC servers                           // NOLINT
 
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index a3b4c5ac83f802d9b8daf4893180db3c7ab2a4ad..4ac19ca86ec56cbec687b6daeb7292e35179c7d3 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -683,7 +683,7 @@ struct cmd_params_instance {
     bool               cpu_strict;
     int                poll;
     int                n_gpu_layers;
-    std::string        rpc_servers;
+    std::string        rpc_servers_str;
     llama_split_mode   split_mode;
     int                main_gpu;
     bool               no_kv_offload;
@@ -696,8 +696,37 @@ struct cmd_params_instance {
         llama_model_params mparams = llama_model_default_params();
 
         mparams.n_gpu_layers = n_gpu_layers;
-        if (!rpc_servers.empty()) {
-            mparams.rpc_servers = rpc_servers.c_str();
+        if (!rpc_servers_str.empty()) {
+            auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
+
+            // add RPC devices
+            if (!rpc_servers.empty()) {
+                ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+                if (!rpc_reg) {
+                    fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
+                    exit(1);
+                }
+
+                typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+                ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+                if (!ggml_backend_rpc_add_device_fn) {
+                    fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
+                    exit(1);
+                }
+                static std::vector<ggml_backend_dev_t> devices;
+                devices.clear();
+                for (const std::string & server : rpc_servers) {
+                    ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+                    if (dev) {
+                        devices.push_back(dev);
+                    } else {
+                        fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                        exit(1);
+                    }
+                }
+                devices.push_back(nullptr);
+                mparams.devices = devices.data();
+            }
         }
         mparams.split_mode   = split_mode;
         mparams.main_gpu     = main_gpu;
@@ -708,7 +737,7 @@ struct cmd_params_instance {
     }
 
     bool equal_mparams(const cmd_params_instance & other) const {
-        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers &&
+        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
                split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
                tensor_split == other.tensor_split;
     }
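
One design note on the llama-bench hunk: the `static
std::vector<ggml_backend_dev_t>` is a lifetime choice, not an accident.
`llama_model_params::devices` is a raw pointer, and the `mparams` built here
outlives the call that fills it, so the storage behind the pointer must
persist. A hypothetical counter-example showing what a function-local vector
would do:

    // BAD (hypothetical): the vector dies with the function,
    // leaving mparams.devices dangling after return.
    llama_model_params make_params_dangling(ggml_backend_dev_t dev) {
        std::vector<ggml_backend_dev_t> devices = { dev, nullptr };
        llama_model_params mparams = llama_model_default_params();
        mparams.devices = devices.data(); // pointer into dying storage
        return mparams;
    }

The trade-off is that the `static` vector is shared across all instances,
which is why it is cleared before being refilled.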
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 7221a08309274ad05c441d1a9cc9e0c176583c13..fc9571c82c959f61444552d539ab305d01efbd05 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -203,6 +203,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
     // Backend (reg) enumeration
     GGML_API size_t             ggml_backend_reg_count(void);
     GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 36d72e95f028c8336edcbba73ee85f7cbf3aba38..d1c2d76d8975eae52bf7671b2c7faf4e8387f5cb 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -208,7 +208,6 @@ extern "C" {
 
     // Internal backend registry API
     GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Add backend dynamic loading support to the backend
 
diff --git a/include/llama.h b/include/llama.h
index be6802eefbf373c630becb1c07e61f58e2eb90e7..298b8d1bc0fa2cb791b11766c517605cfa5db5bf 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -288,9 +288,6 @@ extern "C" {
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;
 
-        // comma separated list of RPC servers to use for offloading
-        const char * rpc_servers;
-
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
         // If the provided progress_callback returns true, model loading continues.
         // If it returns false, model loading is immediately aborted.
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f90f5e746077b2445e39ace31c340e8e900fa7de..c2d23a8d3a195f3e6a71bcdc1df1939a69d9879a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3717,7 +3717,6 @@ struct llama_model_params llama_model_default_params() {
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides                =*/ nullptr,
diff --git a/src/llama-model.h b/src/llama-model.h
index 4cc8abb753a4fce7865bbf6d61e972d8cb44e2a0..a7c30444786fdf5aeeb255c3c35d01d0fcf4cb7a 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -323,8 +323,6 @@ struct llama_model {
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
-    std::vector<std::string> rpc_servers;
-
     // list of devices used in this model
     std::vector<ggml_backend_dev_t> devices;
 
diff --git a/src/llama.cpp b/src/llama.cpp
index fede23d196b5f34433196027db17993eb37b1295..e8cfe5012819ce1bb5d11f5973c2dfb79b48f95c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9399,47 +9399,6 @@ static struct llama_model * llama_model_load_from_file_impl(
         };
     }
 
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(',')) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
-        }
-        model->rpc_servers.push_back(servers);
-    }
-
-    // add RPC devices
-    if (!model->rpc_servers.empty()) {
-        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
-        if (!rpc_reg) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
-            llama_model_free(model);
-            return nullptr;
-        }
-
-        typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-        ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-        if (!ggml_backend_rpc_add_device_fn) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
-            llama_model_free(model);
-            return nullptr;
-        }
-
-        for (const std::string & server : model->rpc_servers) {
-            ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (dev) {
-                model->devices.push_back(dev);
-            } else {
-                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
-                llama_model_free(model);
-                return nullptr;
-            }
-        }
-    }
-
     // create list of devices to use with this model
     if (params.devices) {
         for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {