]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add option to override model tensor buffers (#11397) upstream/0.0.5028
authorDiego Devesa <redacted>
Wed, 2 Apr 2025 12:52:01 +0000 (14:52 +0200)
committerGitHub <redacted>
Wed, 2 Apr 2025 12:52:01 +0000 (14:52 +0200)
* llama : add option to override tensor buffers

* ggml : fix possible underflow in ggml_nbytes

12 files changed:
common/arg.cpp
common/common.cpp
common/common.h
ggml/src/ggml.c
include/llama.h
src/llama-context.cpp
src/llama-model-loader.cpp
src/llama-model-loader.h
src/llama-model.cpp
src/llama-model.h
src/llama-quant.cpp
src/llama.cpp

index 47c26955ea37459454e773d4b25ef7317d59743a..fa22e86cd14e663d44dad4d1afff6ccca191046b 100644 (file)
@@ -1,6 +1,7 @@
 #include "gguf.h" // for reading GGUF splits
 #include "arg.h"
 
+#include "common.h"
 #include "log.h"
 #include "sampling.h"
 #include "chat.h"
@@ -848,6 +849,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.kv_overrides.back().key[0] = 0;
     }
 
+    if (!params.tensor_buft_overrides.empty()) {
+        params.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
     if (params.reranking && params.embedding) {
         throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
     }
@@ -2180,6 +2185,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             exit(0);
         }
     ));
+    add_opt(common_arg(
+        {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type", [](common_params & params, const std::string & value) {
+            /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+            if (buft_list.empty()) {
+                // enumerate all the devices and add their buffer types to the list
+                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                    auto * dev = ggml_backend_dev_get(i);
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    if (buft) {
+                        buft_list[ggml_backend_buft_name(buft)] = buft;
+                    }
+                }
+            }
+
+            for (const auto & override : string_split<std::string>(value, ',')) {
+                std::string::size_type pos = override.find('=');
+                if (pos == std::string::npos) {
+                    throw std::invalid_argument("invalid value");
+                }
+                std::string tensor_name = override.substr(0, pos);
+                std::string buffer_type = override.substr(pos + 1);
+
+                if (buft_list.find(buffer_type) == buft_list.end()) {
+                    printf("Available buffer types:\n");
+                    for (const auto & it : buft_list) {
+                        printf("  %s\n", ggml_backend_buft_name(it.second));
+                    }
+                    throw std::invalid_argument("unknown buffer type");
+                }
+                // FIXME: this leaks memory
+                params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+            }
+        }
+    ));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",
index e7269ead4f94e7355a15e544f2fb70103f3ece8d..d4882c5123cce4c6afc0b0bda68f150c3d027d07 100644 (file)
@@ -1042,15 +1042,18 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     if (!params.devices.empty()) {
         mparams.devices = params.devices.data();
     }
+
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
+
     mparams.main_gpu        = params.main_gpu;
     mparams.split_mode      = params.split_mode;
     mparams.tensor_split    = params.tensor_split;
     mparams.use_mmap        = params.use_mmap;
     mparams.use_mlock       = params.use_mlock;
     mparams.check_tensors   = params.check_tensors;
+
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
@@ -1058,6 +1061,13 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
         mparams.kv_overrides = params.kv_overrides.data();
     }
 
+    if (params.tensor_buft_overrides.empty()) {
+        mparams.tensor_buft_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+        mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
+    }
+
     return mparams;
 }
 
index ea7aef99d918acec20e145f6867918aef622882f..725b5123d24f989ed7f3ca76af431ed69b1711db 100644 (file)
@@ -279,6 +279,7 @@ struct common_params {
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
 
     bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
     std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
index 161dd3fa94547c9a7b2b83372e65920b5019d65d..3e274d6ae39614b106270de27a4b158c95db072a 100644 (file)
@@ -1159,6 +1159,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
     size_t nbytes;
     const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
index 468ab1fa485dab3d33a439f659ab634fee341051..fca2b034ba27031fd3e063985966e8c8e2837912 100644 (file)
@@ -280,10 +280,18 @@ extern "C" {
         };
     };
 
+    struct llama_model_tensor_buft_override {
+        const char * pattern;
+        ggml_backend_buffer_type_t buft;
+    };
+
     struct llama_model_params {
         // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
         ggml_backend_dev_t * devices;
 
+        // NULL-terminated list of buffer types to use for tensors that match a pattern
+        const struct llama_model_tensor_buft_override * tensor_buft_overrides;
+
         int32_t n_gpu_layers; // number of layers to store in VRAM
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs
 
index 7d067afbe73995425fd47a636e5d250bbfa15153..3927079432d941a4bf7c601bb8bc396c09d7688a 100644 (file)
@@ -255,7 +255,8 @@ llama_context::llama_context(
             model.n_devices() > 1 &&
             model.params.n_gpu_layers > (int) model.hparams.n_layer &&
             model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-            cparams.offload_kqv;
+            cparams.offload_kqv &&
+            !model.has_tensor_overrides();
 
         // pipeline parallelism requires support for async compute and events in all devices
         if (pipeline_parallel) {
index 1be0f2d6d6c20a852369e9df5779dfd85d228f06..ec1d78e3144eb1c78ca790b481018b904cd506d2 100644 (file)
@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
         std::vector<std::string> & splits,
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p) {
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
         }
     }
 
+    tensor_buft_overrides = param_tensor_buft_overrides_p;
+
     // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {
index fe35404b26889ff76ce54bcd9cd65c5057a18fe2..0f52b011b698624e560c767d4ad5d6cd3140343c 100644 (file)
@@ -77,8 +77,9 @@ struct llama_model_loader {
 
     llama_mmaps mappings;
 
-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;
 
     gguf_context_ptr meta;
     std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p);
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
 
     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type
index 8d525e1bec4e06f4a92322e06c2b0d1d90544896..ca6e3ab2caeb162c7759789570dae3c2b8da3c6e 100644 (file)
@@ -17,6 +17,7 @@
 #include <cmath>
 #include <functional>
 #include <map>
+#include <regex>
 #include <sstream>
 #include <stdexcept>
 
@@ -378,9 +379,12 @@ struct llama_model::impl {
     layer_dev dev_input = {};
     layer_dev dev_output = {};
     std::vector<layer_dev> dev_layer;
+
+    bool has_tensor_overrides;
 };
 
 llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+    pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
 llama_model::~llama_model() {}
@@ -1571,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
             }
 
-            ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            ggml_backend_buffer_type_t buft = nullptr;
+
+            // check overrides
+            if (ml.tensor_buft_overrides) {
+                std::string tensor_name = tn.str();
+                for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                    std::regex pattern(overrides->pattern);
+                    if (std::regex_search(tensor_name, pattern)) {
+                        LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                        buft = overrides->buft;
+                        break;
+                    }
+                }
+            }
+
             if (!buft) {
-                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+                buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+                if (!buft) {
+                    throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+                }
             }
 
             // avoid using a host buffer when using mmap
@@ -4151,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
             });
 }
 
+bool llama_model::has_tensor_overrides() const {
+    return pimpl->has_tensor_overrides;
+}
+
 const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
             [name](const std::pair<std::string, ggml_tensor *> & it) {
@@ -12319,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
 llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices                     =*/ nullptr,
+        /*.tensor_buft_overrides       =*/ nullptr,
         /*.n_gpu_layers                =*/ 0,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
index f1bf0df3a4ef63bcd352c9e88f63744ea3b1bf38..91e6e8725acd28efeb8d92a100c00bfbeb865e9b 100644 (file)
@@ -382,6 +382,8 @@ struct llama_model {
 
     ggml_backend_buffer_type_t select_buft(int il) const;
 
+    bool has_tensor_overrides() const;
+
     const struct ggml_tensor * get_tensor(const char * name) const;
 
     // TODO: move this to new llm_arch_model_i interface
index 09eb570779ce538ffaaf4c4385d388ce63fa87bf..e3e10fa6cf77f17b768e0c7514158df6a9a5d161 100644 (file)
@@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     }
 
     std::vector<std::string> splits = {};
-    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching
 
     llama_model model(llama_model_default_params());
index 81e1dd1d0873a1ecc9617ccdc60b25082711f41e..d5164720b21962869b73483f21e49d4e3264272f 100644 (file)
@@ -92,7 +92,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
     model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);
 
         ml.print_info();