git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Adding support for the --numa argument for llama-bench. (#7080)
author: kunnis <redacted>
Sun, 5 May 2024 12:17:47 +0000 (07:17 -0500)
committer: GitHub <redacted>
Sun, 5 May 2024 12:17:47 +0000 (14:17 +0200)
examples/llama-bench/llama-bench.cpp

index 95c3095dd04da3ffd310eb40fcbc17586e54807e..40128ec444334797a119b6ea27e35e1ba3302c16 100644 (file)
@@ -178,6 +178,7 @@ struct cmd_params {
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
+    ggml_numa_strategy numa;
     int reps;
     bool verbose;
     output_formats output_format;
@@ -200,6 +201,7 @@ static const cmd_params cmd_params_defaults = {
     /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap      */ {true},
     /* embeddings    */ {false},
+    /* numa          */ GGML_NUMA_STRATEGY_DISABLED,
     /* reps          */ 5,
     /* verbose       */ false,
     /* output_format */ MARKDOWN
@@ -224,6 +226,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf("  -fa, --flash-attn <0|1>             (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
+    printf("  --numa <distribute|isolate|numactl> (default: disabled)\n");
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
     printf("  -ts, --tensor-split <ts0/ts1/..>    (default: 0)\n");
     printf("  -r, --repetitions <n>               (default: %d)\n", cmd_params_defaults.reps);
@@ -396,6 +399,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<bool>(argv[i], split_delim);
             params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+        } else if (arg == "--numa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            } else {
+                std::string value(argv[i]);
+                /**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
+                else if (value == "isolate")                    { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
+                else if (value == "numactl")                    { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
+                else { invalid_param = true; break; }
+            }
         } else if (arg == "-fa" || arg == "--flash-attn") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1215,6 +1229,7 @@ int main(int argc, char ** argv) {
         llama_log_set(llama_null_log_callback, NULL);
     }
     llama_backend_init();
+    llama_numa_init(params.numa);
 
     // initialize printer
     std::unique_ptr<printer> p;