]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add `llama_model_load_from_splits` (#11255)
authorXuan Son Nguyen <redacted>
Thu, 16 Jan 2025 12:54:08 +0000 (13:54 +0100)
committerGitHub <redacted>
Thu, 16 Jan 2025 12:54:08 +0000 (13:54 +0100)
* llama : add `llama_model_load_from_splits`

* update

include/llama.h
src/llama-model-loader.cpp
src/llama-model-loader.h
src/llama-quant.cpp
src/llama.cpp

index a184884c77a51d67b3c824e922329b66130b3a93..352c3417ecb5bb0d8cc931db781dcb51f74ce40b 100644 (file)
@@ -418,10 +418,20 @@ extern "C" {
               struct llama_model_params   params),
             "use llama_model_load_from_file instead");
 
+    // Load the model from a file
+    // If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
+    // If the split file name does not follow this pattern, use llama_model_load_from_splits
     LLAMA_API struct llama_model * llama_model_load_from_file(
                              const char * path_model,
               struct llama_model_params   params);
 
+    // Load the model from multiple splits (support custom naming scheme)
+    // The paths must be in the correct order
+    LLAMA_API struct llama_model * llama_model_load_from_splits(
+                             const char ** paths,
+                                 size_t    n_paths,
+              struct llama_model_params    params);
+
     DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
             "use llama_model_free instead");
 
index 53175f0e069a6d2041ad5136b9f8fa66d42e4fbf..75073bf610ac3323359c78d898415e47f9f708c7 100644 (file)
@@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
     }
 }
 
+// return a list of splits for a given path
+// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
+static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
+    std::vector<std::string> paths;
+    std::string split_prefix;
+    std::vector<char> buf(llama_path_max(), 0);
+
+    {
+        int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
+        if (!ret) {
+            throw std::runtime_error(format("invalid split file name: %s", path.c_str()));
+        }
+        split_prefix = std::string(buf.data(), ret);
+    }
+
+    if (split_prefix.empty()) {
+        throw std::runtime_error(format("invalid split file: %s", path.c_str()));
+    }
+
+    for (int idx = 0; idx < n_split; ++idx) {
+        int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
+        paths.push_back(std::string(buf.data(), ret));
+    }
+
+    return paths;
+}
+
 namespace GGUFMeta {
     template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)>
     struct GKV_Base_Type {
@@ -413,7 +440,12 @@ namespace GGUFMeta {
     template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
     template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
 
-llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
+llama_model_loader::llama_model_loader(
+        const std::string & fname,
+        std::vector<std::string> & splits,
+        bool use_mmap,
+        bool check_tensors,
+        const struct llama_model_kv_override * param_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -425,6 +457,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
         }
     }
 
+    // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {
         /*.no_alloc = */ true,
@@ -460,35 +493,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
 
     // Load additional GGML contexts
     if (n_split > 1) {
+        // make sure the main file is loaded first
         uint16_t idx = 0;
-        get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
+        const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
+        get_key(kv_split_no, idx);
         if (idx != 0) {
-            throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
+            throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
+        }
+
+        // generate list of splits if needed
+        if (splits.empty()) {
+            splits = llama_get_list_splits(fname, idx, n_split);
         }
 
-        std::vector<char> split_prefix(llama_path_max(), 0);
-        if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
-            throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
+        // in case user give a custom list of splits, check if it matches the expected number
+        if (n_split != (uint16_t)splits.size()) {
+            throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
         }
 
         if (trace > 0) {
             LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
         }
 
-        std::vector<char> split_path(llama_path_max(), 0);
+        // load other splits
         for (idx = 1; idx < n_split; idx++) {
-            llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
+            const char * fname_split = splits[idx].c_str();
 
             struct gguf_init_params split_params = {
                 /*.no_alloc = */ true,
                 /*.ctx      = */ &ctx,
             };
-            gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
+            gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
             if (!ctx_gguf) {
-                throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
+                throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
+            }
+
+            // check idx
+            {
+                const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
+                if (kid < 0) {
+                    throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
+                }
+                int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
+                if (idx_gguf != idx) {
+                    throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
+                }
             }
 
-            files.emplace_back(new llama_file(split_path.data(), "rb"));
+            files.emplace_back(new llama_file(fname_split, "rb"));
             contexts.emplace_back(ctx);
 
             // Save tensors data offset info of the shard.
index b63d158d982dd192cc22586f40d6a878ee30d7fd..fe35404b26889ff76ce54bcd9cd65c5057a18fe2 100644 (file)
@@ -90,7 +90,12 @@ struct llama_model_loader {
     size_t size_data = 0;
     std::vector<std::pair<size_t, size_t>> mmaps_used;
 
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p);
+    llama_model_loader(
+        const std::string & fname,
+        std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
+        bool use_mmap,
+        bool check_tensors,
+        const struct llama_model_kv_override * param_overrides_p);
 
     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type
index d4947a780c12f52e4c562ae79053f5ac94ed17d5..fb7982655a373f3cc9f4e65b72a5e819534a8478 100644 (file)
@@ -526,7 +526,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         kv_overrides = v->data();
     }
 
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
+    std::vector<std::string> splits = {};
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
     llama_model model(llama_model_default_params());
index 2e391b3b60d3db2d7daad1b89fdb4c18e4fbf245..fede23d196b5f34433196027db17993eb37b1295 100644 (file)
@@ -31,7 +31,7 @@
 #endif
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
+static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
     // loading time will be recalculated after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = 0;
@@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
     model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
 
         ml.print_info();
 
@@ -9374,14 +9374,9 @@ int64_t llama_time_us(void) {
     return ggml_time_us();
 }
 
-struct llama_model * llama_load_model_from_file(
-        const char * path_model,
-        struct llama_model_params params) {
-    return llama_model_load_from_file(path_model, params);
-}
-
-struct llama_model * llama_model_load_from_file(
-        const char * path_model,
+static struct llama_model * llama_model_load_from_file_impl(
+        const std::string & path_model,
+        std::vector<std::string> & splits,
         struct llama_model_params params) {
     ggml_time_init();
 
@@ -9485,7 +9480,7 @@ struct llama_model * llama_model_load_from_file(
         LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
     }
 
-    const int status = llama_model_load(path_model, *model, params);
+    const int status = llama_model_load(path_model, splits, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
         if (status == -1) {
@@ -9501,6 +9496,35 @@ struct llama_model * llama_model_load_from_file(
     return model;
 }
 
+// deprecated
+struct llama_model * llama_load_model_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    return llama_model_load_from_file(path_model, params);
+}
+
+struct llama_model * llama_model_load_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    std::vector<std::string> splits = {};
+    return llama_model_load_from_file_impl(path_model, splits, params);
+}
+
+struct llama_model * llama_model_load_from_splits(
+        const char ** paths,
+        size_t n_paths,
+        struct llama_model_params params) {
+    std::vector<std::string> splits;
+    if (n_paths == 0) {
+        LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
+        return nullptr;
+    }
+    for (size_t i = 0; i < n_paths; ++i) {
+        splits.push_back(paths[i]);
+    }
+    return llama_model_load_from_file_impl(splits.front(), splits, params);
+}
+
 struct llama_context * llama_init_from_model(
                  struct llama_model * model,
         struct llama_context_params   params) {