common : add --override-tensor-draft, --cpu-moe-draft and --n-cpu-moe-draft parameters
author      Copilot <redacted>
            Wed, 13 Aug 2025 10:44:40 +0000 (12:44 +0200)
committer   GitHub <redacted>
            Wed, 13 Aug 2025 10:44:40 +0000 (12:44 +0200)
* Checkpoint from VS Code for coding agent session

* Initial plan

* Fix typo in --override-tensor-draft flag implementation

* Add null termination for speculative tensor buffer overrides

* Apply suggestions from code review

* Apply suggestions from code review

* Extract tensor override parsing logic to common function (addresses @slaren's feedback)

* Apply suggestions from code review

* Apply suggestions

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Diego Devesa <redacted>
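
In short: --override-tensor-draft (-otd) applies explicit buffer-type overrides to the draft model used for speculative decoding, --cpu-moe-draft (-cmoed) pins all of the draft model's MoE expert weights to CPU, and --n-cpu-moe-draft N (-ncmoed) does so for the first N layers only. As a minimal sketch of the per-layer pattern expansion performed by -ncmoed (mirroring the arg.cpp hunk below; the standalone program is illustrative only):

    // Illustrative only: prints the CPU override patterns that
    // "-ncmoed 2" registers for the draft model (see arg.cpp below).
    #include <cstdio>

    int main() {
        const int n_layers = 2; // value passed to -ncmoed
        for (int i = 0; i < n_layers; ++i) {
            std::printf("blk\\.%d\\.ffn_(up|down|gate)_exps -> CPU\n", i);
        }
        return 0;
    }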
common/arg.cpp
common/common.h
examples/speculative-simple/speculative-simple.cpp
examples/speculative/speculative.cpp
tools/server/server.cpp

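The draft flags reuse the <tensor name pattern>=<buffer type> syntax of the existing --override-tensor flag. A minimal sketch of what one parsed entry looks like, assuming the llama_model_tensor_buft_override type and ggml buffer-type API from the headers used below (the surrounding program is hypothetical):

    // Hypothetical sketch: --override-tensor-draft "blk\.0\.ffn_up_exps=CPU"
    // becomes one pattern/buffer-type pair in the overrides vector.
    #include <vector>

    #include "ggml-backend.h"
    #include "llama.h"

    int main() {
        std::vector<llama_model_tensor_buft_override> draft_overrides;
        draft_overrides.push_back({
            "blk\\.0\\.ffn_up_exps",        // tensor name pattern
            ggml_backend_cpu_buffer_type()  // buffer type named "CPU"
        });
        return 0;
    }
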
diff --git a/common/arg.cpp b/common/arg.cpp
index 3d18aaa171ce41237455c1e42524d062cc11fcb2..4e4c52b5f8748db54cc5dc244d5fe2d3a8c72ed7 100644
@@ -749,6 +749,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 // utils
 //
 
+// Helper function to parse tensor buffer override strings
+static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
+    std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        auto * dev = ggml_backend_dev_get(i);
+        auto * buft = ggml_backend_dev_buffer_type(dev);
+        if (buft) {
+            buft_list[ggml_backend_buft_name(buft)] = buft;
+        }
+    }
+
+    for (const auto & override : string_split<std::string>(value, ',')) {
+        std::string::size_type pos = override.find('=');
+        if (pos == std::string::npos) {
+            throw std::invalid_argument("invalid value");
+        }
+        std::string tensor_name = override.substr(0, pos);
+        std::string buffer_type = override.substr(pos + 1);
+
+        if (buft_list.find(buffer_type) == buft_list.end()) {
+            printf("Available buffer types:\n");
+            for (const auto & it : buft_list) {
+                printf("  %s\n", ggml_backend_buft_name(it.second));
+            }
+            throw std::invalid_argument("unknown buffer type");
+        }
+        // keep strings alive and avoid leaking memory by storing them in a static vector
+        static std::list<std::string> buft_overrides;
+        buft_overrides.push_back(tensor_name);
+        overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
+    }
+}
+
 struct handle_model_result {
     bool found_mmproj = false;
     common_params_model mmproj;
@@ -993,6 +1026,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.tensor_buft_overrides.push_back({nullptr, nullptr});
     }
 
+    if (!params.speculative.tensor_buft_overrides.empty()) {
+        params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
     if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
         throw std::runtime_error(string_format(
             "error: the supplied chat template is not supported: %s%s\n",
@@ -2349,40 +2386,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     add_opt(common_arg(
         {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
         "override tensor buffer type", [](common_params & params, const std::string & value) {
-            /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
-            if (buft_list.empty()) {
-                // enumerate all the devices and add their buffer types to the list
-                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-                    auto * dev = ggml_backend_dev_get(i);
-                    auto * buft = ggml_backend_dev_buffer_type(dev);
-                    if (buft) {
-                        buft_list[ggml_backend_buft_name(buft)] = buft;
-                    }
-                }
-            }
-
-            for (const auto & override : string_split<std::string>(value, ',')) {
-                std::string::size_type pos = override.find('=');
-                if (pos == std::string::npos) {
-                    throw std::invalid_argument("invalid value");
-                }
-                std::string tensor_name = override.substr(0, pos);
-                std::string buffer_type = override.substr(pos + 1);
-
-                if (buft_list.find(buffer_type) == buft_list.end()) {
-                    printf("Available buffer types:\n");
-                    for (const auto & it : buft_list) {
-                        printf("  %s\n", ggml_backend_buft_name(it.second));
-                    }
-                    throw std::invalid_argument("unknown buffer type");
-                }
-                // keep strings alive and avoid leaking memory by storing them in a static vector
-                static std::list<std::string> buft_overrides;
-                buft_overrides.push_back(tensor_name);
-                params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
-            }
+            parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
         }
     ));
+    add_opt(common_arg(
+        {"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
+            parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--cpu-moe", "-cmoe"},
         "keep all Mixture of Experts (MoE) weights in the CPU",
@@ -2405,6 +2417,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_N_CPU_MOE"));
+    add_opt(common_arg(
+        {"--cpu-moe-draft", "-cmoed"},
+        "keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
+        [](common_params & params) {
+            params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
+    add_opt(common_arg(
+        {"--n-cpu-moe-draft", "-ncmoed"}, "N",
+        "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
+        [](common_params & params, int value) {
+            if (value < 0) {
+                throw std::invalid_argument("invalid value");
+            }
+            for (int i = 0; i < value; ++i) {
+                static std::list<std::string> buft_overrides_draft;
+                buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
+                params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",
diff --git a/common/common.h b/common/common.h
index 5eab199af559ea68351d25a60c623c3ca4e46763..c09509b669e54d66fc746f97027e220e5037cc52 100644
@@ -202,6 +202,7 @@ struct common_params_speculative {
     float   p_split      =  0.1f; // speculative decoding split probability
     float   p_min        = 0.75f; // minimum speculative decoding probability (greedy)
     std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
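
The new member mirrors common_params::tensor_buft_overrides. Both vectors are ultimately consumed as a C array terminated by the {nullptr, nullptr} sentinel that common_params_parse_ex appends (see the arg.cpp hunk above). A minimal sketch of that handoff, assuming the tensor_buft_overrides field of llama_model_params from llama.h:

    // Sketch (assumed llama.h API): a non-empty override vector already
    // ends in {nullptr, nullptr}, so data() can be passed through directly.
    #include <vector>

    #include "llama.h"

    void apply_overrides(llama_model_params & mparams,
                         std::vector<llama_model_tensor_buft_override> & overrides) {
        if (!overrides.empty()) {
            mparams.tensor_buft_overrides = overrides.data();
        }
    }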
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 722cd7f40f088b1a535493c7c82fe13f4ecd1df5..a8e53f28eb59730c81105490442677b7cbea6129 100644
@@ -59,6 +59,8 @@ int main(int argc, char ** argv) {
     }
 
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    params.tensor_buft_overrides     = params.speculative.tensor_buft_overrides;
+
     common_init_result llama_init_dft = common_init_from_params(params);
 
     //model_dft = llama_init_dft.model.get();
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 0adffdb006bcfb58b1ed14dcb189b4b0fe251283..8449406a6d27a490d9e9ad2b5883736120d8f43c 100644
@@ -85,6 +85,8 @@ int main(int argc, char ** argv) {
     }
 
     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    params.tensor_buft_overrides     = params.speculative.tensor_buft_overrides;
+
     common_init_result llama_init_dft = common_init_from_params(params);
 
     model_dft = llama_init_dft.model.get();
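
Both standalone examples follow the same pattern: the draft-specific overrides are swapped into the shared params struct just before common_init_from_params builds the draft model. Roughly (setup and error handling elided; see the diffs above for the real context):

    // Simplified flow shared by examples/speculative and speculative-simple.
    common_params params;  // filled in by CLI parsing (elided)

    // target model: uses the overrides from --override-tensor / -ot
    common_init_result llama_init_tgt = common_init_from_params(params);

    // draft model: substitute the -otd / -cmoed / -ncmoed overrides first
    params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
    common_init_result llama_init_dft = common_init_from_params(params);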
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 45f50e93769e1dfaf54550f960506213d78ad450..f549cda476657bbb018388749a5ab66e8662172b 100644
@@ -2015,6 +2015,8 @@ struct server_context {
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
             params_dft.cache_type_v = params_base.speculative.cache_type_v;
 
+            params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
+
             llama_init_dft = common_init_from_params(params_dft);
 
             model_dft = llama_init_dft.model.get();
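
The server takes the same approach but on a dedicated params_dft struct, so the base parameters stay untouched while the draft model receives its own override list. Schematically (the construction of params_dft is simplified here; only the override copy is taken from the hunk above):

    // Schematic: server-side draft-model setup with its own overrides.
    common_params params_dft = params_base;  // simplified; see server.cpp
    params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
    llama_init_dft = common_init_from_params(params_dft);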