]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : accept a list of devices to use to offload a model (#10497)
authorDiego Devesa <redacted>
Mon, 25 Nov 2024 18:30:06 +0000 (19:30 +0100)
committerGitHub <redacted>
Mon, 25 Nov 2024 18:30:06 +0000 (19:30 +0100)
* llama : accept a list of devices to use to offload a model

* accept `--dev none` to completely disable offloading

* fix dev list with dl backends

* rename env parameter to LLAMA_ARG_DEVICE for consistency

common/arg.cpp
common/common.cpp
common/common.h
examples/server/server.cpp
examples/speculative-simple/speculative-simple.cpp
examples/speculative/speculative.cpp
ggml/src/ggml-backend-reg.cpp
include/llama.h
src/llama.cpp

index 32240f21f2469cd6f09132aa07edece1e8d5d8b8..272492e50df1562f92748a2ca1c167ba5ae4f81d 100644 (file)
@@ -298,6 +298,27 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
     print_options(specific_options);
 }
 
+static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) {
+    std::vector<ggml_backend_dev_t> devices;
+    auto dev_names = string_split<std::string>(value, ',');
+    if (dev_names.empty()) {
+        throw std::invalid_argument("no devices specified");
+    }
+    if (dev_names.size() == 1 && dev_names[0] == "none") {
+        devices.push_back(nullptr);
+    } else {
+        for (const auto & device : dev_names) {
+            auto * dev = ggml_backend_dev_by_name(device.c_str());
+            if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
+                throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
+            }
+            devices.push_back(dev);
+        }
+        devices.push_back(nullptr);
+    }
+    return devices;
+}
+
 bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     auto ctx_arg = common_params_parser_init(params, ex, print_usage);
     const common_params params_org = ctx_arg.params; // the example can modify the default params
@@ -324,6 +345,9 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
 }
 
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
+    // load dynamic backends
+    ggml_backend_load_all();
+
     common_params_context ctx_arg(params);
     ctx_arg.print_usage = print_usage;
     ctx_arg.ex          = ex;
@@ -1312,6 +1336,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             else { throw std::invalid_argument("invalid value"); }
         }
     ).set_env("LLAMA_ARG_NUMA"));
+    add_opt(common_arg(
+        {"-dev", "--device"}, "<dev1,dev2,..>",
+        "comma-separated list of devices to use for offloading (none = don't offload)\n"
+        "use --list-devices to see a list of available devices",
+        [](common_params & params, const std::string & value) {
+            params.devices = parse_device_list(value);
+        }
+    ).set_env("LLAMA_ARG_DEVICE"));
+    add_opt(common_arg(
+        {"--list-devices"},
+        "print list of available devices and exit",
+        [](common_params &) {
+            printf("Available devices:\n");
+            for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                auto * dev = ggml_backend_dev_get(i);
+                if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                    size_t free, total;
+                    ggml_backend_dev_memory(dev, &free, &total);
+                    printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+                }
+            }
+            exit(0);
+        }
+    ));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",
@@ -1336,10 +1384,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             } else if (arg_next == "layer") {
                 params.split_mode = LLAMA_SPLIT_MODE_LAYER;
             } else if (arg_next == "row") {
-#ifdef GGML_USE_SYCL
-                fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
-                exit(1);
-#endif // GGML_USE_SYCL
                 params.split_mode = LLAMA_SPLIT_MODE_ROW;
             } else {
                 throw std::invalid_argument("invalid value");
@@ -2042,6 +2086,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.n_ctx = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"-devd", "--device-draft"}, "<dev1,dev2,..>",
+        "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
+        "use --list-devices to see a list of available devices",
+        [](common_params & params, const std::string & value) {
+            params.speculative.devices = parse_device_list(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
         "number of layers to store in VRAM for the draft model",
index 98524f7467ab4e42c152c90e2747b2e20c93a409..09ec9f2388afb61436b454f6d989341f99a25ab0 100644 (file)
@@ -377,9 +377,6 @@ void common_init() {
 #endif
 
     LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
-
-    // load dynamic backends
-    ggml_backend_load_all();
 }
 
 std::string common_params_get_system_info(const common_params & params) {
@@ -982,9 +979,12 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_l
     }
 }
 
-struct llama_model_params common_model_params_to_llama(const common_params & params) {
+struct llama_model_params common_model_params_to_llama(common_params & params) {
     auto mparams = llama_model_default_params();
 
+    if (!params.devices.empty()) {
+        mparams.devices = params.devices.data();
+    }
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
index 5c579b5abfe03f638d321f6645e2ffe77dbfbb50..286642db241587a8af8d737ff62ba04be2e1a944 100644 (file)
@@ -156,6 +156,7 @@ struct common_params_sampling {
 };
 
 struct common_params_speculative {
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
     int32_t n_ctx        =     0; // draft context size
     int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
@@ -178,9 +179,6 @@ struct common_params {
     int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel            =     1; // number of parallel sequences to decode
     int32_t n_sequences           =     1; // number of sequences to decode
-    int32_t n_gpu_layers          =    -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
-    float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
     int32_t grp_attn_n            =     1; // group-attention factor
     int32_t grp_attn_w            =   512; // group-attention width
     int32_t n_print               =    -1; // print token count every n tokens (-1 = disabled)
@@ -193,6 +191,13 @@ struct common_params {
     int32_t yarn_orig_ctx         =     0; // YaRN original context length
     float   defrag_thold          =  0.1f; // KV cache defragmentation threshold
 
+    // offload params
+    std::vector<ggml_backend_dev_t> devices;         // devices to use for offloading
+    int32_t n_gpu_layers                    =    -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu                        =     0; // the GPU that is used for scratch and small tensors
+    float   tensor_split[128]               =   {0}; // how split tensors should be distributed across GPUs
+    enum llama_split_mode        split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
 
@@ -201,7 +206,6 @@ struct common_params {
 
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
-    enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
@@ -462,7 +466,7 @@ struct common_init_result {
 
 struct common_init_result     common_init_from_params(common_params & params);
 
-struct llama_model_params     common_model_params_to_llama  (const common_params & params);
+struct llama_model_params     common_model_params_to_llama  (      common_params & params);
 struct llama_context_params   common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
 
index f9d20fee5a6f669466949b72d1954ce67170efbe..8684771e2d0bbd2f08aaa53179eef6e01102f8ff 100644 (file)
@@ -692,6 +692,7 @@ struct server_context {
 
             auto params_dft = params_base;
 
+            params_dft.devices      = params_base.speculative.devices;
             params_dft.model        = params_base.speculative.model;
             params_dft.n_ctx        = params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
index 1bc7f428cf3b5729f6a8c92ee21bcc2fb74e4a6a..7bf9056bf6db10fa9139ec8c4baa12d11081fe5b 100644 (file)
@@ -46,6 +46,7 @@ int main(int argc, char ** argv) {
     ctx_tgt   = llama_init_tgt.context;
 
     // load the draft model
+    params.devices      = params.speculative.devices;
     params.model        = params.speculative.model;
     params.n_ctx        = params.speculative.n_ctx;
     params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
index eb8bb2de54a268e1465814e720ae98edb1bde1dc..d4ad9751e813a89a4037e3e654359f2db656f568 100644 (file)
@@ -76,6 +76,7 @@ int main(int argc, char ** argv) {
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
+    params.devices = params.speculative.devices;
     params.model = params.speculative.model;
     params.n_gpu_layers = params.speculative.n_gpu_layers;
     if (params.speculative.cpuparams.n_threads > 0) {
index 43d03d7fa73856ba1ec928953210fcaeef716251..a0e0e2c5852f76b465ea466c85695e8e2d92dc84 100644 (file)
@@ -253,6 +253,15 @@ void ggml_backend_device_register(ggml_backend_dev_t device) {
 }
 
 // Backend (reg) enumeration
+static bool striequals(const char * a, const char * b) {
+    for (; *a && *b; a++, b++) {
+        if (std::tolower(*a) != std::tolower(*b)) {
+            return false;
+        }
+    }
+    return *a == *b;
+}
+
 size_t ggml_backend_reg_count() {
     return get_reg().backends.size();
 }
@@ -265,7 +274,7 @@ ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
 ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         ggml_backend_reg_t reg = ggml_backend_reg_get(i);
-        if (std::strcmp(ggml_backend_reg_name(reg), name) == 0) {
+        if (striequals(ggml_backend_reg_name(reg), name)) {
             return reg;
         }
     }
@@ -285,7 +294,7 @@ ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
 ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
     for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
+        if (striequals(ggml_backend_dev_name(dev), name)) {
             return dev;
         }
     }
index 90791d5f5ea126a53840030d9e95348739ba4360..ab5e376e6c7f21f025a4bdce61690a3bc826aa36 100644 (file)
@@ -272,6 +272,9 @@ extern "C" {
     };
 
     struct llama_model_params {
+        // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
+        ggml_backend_dev_t * devices;
+
         int32_t n_gpu_layers; // number of layers to store in VRAM
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs
 
index 83bbc10a57a432e9e97f029cb67ca2de53d853cf..571cb68e2885495865707ac9c30cee5099ba3399 100644 (file)
@@ -19364,6 +19364,7 @@ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
 //
 struct llama_model_params llama_model_default_params() {
     struct llama_model_params result = {
+        /*.devices                     =*/ nullptr,
         /*.n_gpu_layers                =*/ 0,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
@@ -19576,19 +19577,24 @@ struct llama_model * llama_load_model_from_file(
     }
 
     // create list of devices to use with this model
-    // currently, we use all available devices
-    // TODO: rework API to give user more control over device selection
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        switch (ggml_backend_dev_type(dev)) {
-            case GGML_BACKEND_DEVICE_TYPE_CPU:
-            case GGML_BACKEND_DEVICE_TYPE_ACCEL:
-                // skip CPU backends since they are handled separately
-                break;
+    if (params.devices) {
+        for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
+            model->devices.push_back(*dev);
+        }
+    } else {
+        // use all available devices
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            switch (ggml_backend_dev_type(dev)) {
+                case GGML_BACKEND_DEVICE_TYPE_CPU:
+                case GGML_BACKEND_DEVICE_TYPE_ACCEL:
+                    // skip CPU backends since they are handled separately
+                    break;
 
-            case GGML_BACKEND_DEVICE_TYPE_GPU:
-                model->devices.push_back(dev);
-                break;
+                case GGML_BACKEND_DEVICE_TYPE_GPU:
+                    model->devices.push_back(dev);
+                    break;
+            }
         }
     }