]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tests: allow exporting graph ops from HF file without downloading weights (#21182)
authorRuben Ortlam <redacted>
Thu, 2 Apr 2026 16:19:20 +0000 (18:19 +0200)
committerGitHub <redacted>
Thu, 2 Apr 2026 16:19:20 +0000 (18:19 +0200)
* tests: allow exporting graph ops from HF file without downloading weights

* use unique_ptr for llama_context in HF metadata case

* fix missing non-required tensors falling back to type f32

* use unique pointers where possible

* use no_alloc instead of fixing f32 fallback

* fix missing space

common/arg.cpp
common/common.cpp
common/common.h
tests/CMakeLists.txt
tests/export-graph-ops.cpp
tests/gguf-model-data.cpp
tests/gguf-model-data.h

index 538d2a4b0a4695f99f3ce2a008fe46d6c83d75c4..649216b7f01995e6fc32fb6444b51153ebce7ad4 100644 (file)
@@ -537,9 +537,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     } catch (const std::exception & e) {
         LOG_WRN("HF cache migration failed: %s\n", e.what());
     }
+    // export_graph_ops loads only metadata
+    const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
 
     // maybe handle remote preset
-    if (!params.model.hf_repo.empty()) {
+    if (!params.model.hf_repo.empty() && !skip_model_download) {
         std::string cli_hf_repo = params.model.hf_repo;
         bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
 
@@ -570,7 +572,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // handle model and download
-    {
+    if (!skip_model_download) {
         auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
         if (params.no_mmproj) {
             params.mmproj = {};
@@ -591,7 +593,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
 
     // model is required (except for server)
     // TODO @ngxson : maybe show a list of available models in CLI in this case
-    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
+    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
         throw std::invalid_argument("error: --model is required\n");
     }
 
index 60396af1f838495d4b2e6fa18b90db75b10b7909..16f78debd0252192f79e1923d1a50acea4bc716e 100644 (file)
@@ -1442,6 +1442,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
 
     mparams.progress_callback           = params.load_progress_callback;
     mparams.progress_callback_user_data = params.load_progress_callback_user_data;
+    mparams.no_alloc                    = params.no_alloc;
 
     return mparams;
 }
index 17dc3fb23261a69a1eab925fbb67d534c87d24ce..31a337daa6e853875919caf34089ed9bd9ace6a6 100644 (file)
@@ -679,6 +679,7 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void *                  load_progress_callback_user_data = NULL;
+    bool no_alloc = false; // Don't allocate model buffers
 };
 
 // call once at the start of a program if it uses libcommon
index 9582164b580de8ac27a15eee70a1fa6235030952..8355c0807068500043cae2709abcb2a187638f5d 100644 (file)
@@ -287,3 +287,7 @@ target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
 
 llama_build(export-graph-ops.cpp)
 target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
+if (TARGET gguf-model-data)
+    target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
+    target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
+endif()
index cac3ff628e7c05ee02289fc791845a6cd654c485..64cf6dcea3533ec02a8410e75421db22051dce5a 100644 (file)
@@ -1,15 +1,26 @@
 #include "arg.h"
 #include "common.h"
 #include "log.h"
-#include "llama.h"
+#include "llama-cpp.h"
 #include "../src/llama-ext.h"
 #include "ggml.h"
+#include "gguf-model-data.h"
+#include "gguf.h"
+#include "ggml-backend.h"
+#include "download.h"
 
 #include <array>
 #include <vector>
 #include <set>
 #include <fstream>
 #include <iostream>
+#include <random>
+
+// Noop because weights are not needed
+static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) {
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(userdata);
+}
 
 struct input_tensor {
     ggml_type type;
@@ -132,9 +143,52 @@ int main(int argc, char ** argv) {
 
     params.warmup = false;
 
-    auto init_result = common_init_from_params(params);
+    llama_context * ctx;
+    common_init_result_ptr init_result;
+    llama_context_ptr ctx2;
+    llama_model_ptr model;
+
+    if (params.model.hf_repo.empty()) {
+        init_result = common_init_from_params(params);
+
+        ctx = init_result->context();
+    } else {
+#ifdef LLAMA_HF_FETCH
+        auto [hf_repo, hf_quant] = common_download_split_repo_tag(params.model.hf_repo);
+        if (hf_quant.empty() || hf_quant == "latest") {
+            hf_quant = "Q4_K_M";
+        }
+
+        gguf_context_ptr gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant);
+        if (!gguf_ctx) {
+            LOG_ERR("failed to fetch GGUF metadata from %s\n", hf_repo.c_str());
+            return 1;
+        }
+
+        llama_model_params model_params = llama_model_default_params();
+        model_params.devices = params.devices.data();
+        model_params.no_alloc = true;
+
+        model.reset(llama_model_init_from_user(gguf_ctx.get(), set_tensor_data, nullptr, model_params));
 
-    llama_context * ctx = init_result->context();
+        if (!model) {
+            LOG_ERR("failed to create llama_model from %s\n", hf_repo.c_str());
+            return 1;
+        }
+
+        llama_context_params ctx_params = llama_context_default_params();
+        ctx2.reset(llama_init_from_model(model.get(), ctx_params));
+        ctx = ctx2.get();
+
+        if (!ctx) {
+            LOG_ERR("failed to create llama_context\n");
+            return 1;
+        }
+#else
+        LOG_ERR("export-graph-ops compiled without HF fetch support\n");
+        return 1;
+#endif
+    }
 
     const uint32_t n_seqs  = llama_n_seq_max(ctx);
     const uint32_t n_tokens = std::min(llama_n_ctx(ctx), llama_n_ubatch(ctx));
@@ -143,13 +197,15 @@ int main(int argc, char ** argv) {
 
     auto * gf_pp = llama_graph_reserve(ctx, n_tokens, n_seqs, n_tokens);
     if (!gf_pp) {
-        throw std::runtime_error("failed to reserve prompt processing graph");
+        LOG_ERR("failed to reserve prompt processing graph\n");
+        return 1;
     }
     extract_graph_ops(gf_pp, "pp", tests);
 
     auto * gf_tg = llama_graph_reserve(ctx, n_seqs, n_seqs, n_seqs);
     if (!gf_tg) {
-        throw std::runtime_error("failed to reserve token generation graph");
+        LOG_ERR("failed to reserve token generation graph\n");
+        return 1;
     }
     extract_graph_ops(gf_tg, "tg", tests);
 
@@ -158,7 +214,8 @@ int main(int argc, char ** argv) {
     std::ofstream f(params.out_file);
 
     if (!f.is_open()) {
-        throw std::runtime_error("Unable to open output file");
+        LOG_ERR("unable to open output file: %s\n", params.out_file.c_str());
+        return 1;
     }
 
     for (const auto& test : tests) {
index 3bc82c88dac8c34ddc1ed24ea99097f3745821b0..adfd6bec68f65f22771805d5123a197cc4f78636 100644 (file)
@@ -4,6 +4,7 @@
 #include "gguf-model-data.h"
 
 #include "common.h"
+#include "ggml-cpp.h"
 #include "gguf.h"
 
 #include <algorithm>
@@ -531,14 +532,18 @@ static std::optional<gguf_remote_model> fetch_and_parse(
     return std::nullopt;
 }
 
+static std::string get_cache_file_path(const std::string& cdir, const std::string& repo_part, const std::string& filename) {
+    std::string fname_part = sanitize_for_path(filename);
+    return cdir + "/" + repo_part + "--" + fname_part + ".partial";
+}
+
 // Try cache first, then fetch and parse a single GGUF shard.
 static std::optional<gguf_remote_model> fetch_or_cached(
         const std::string & repo,
         const std::string & filename,
         const std::string & cdir,
         const std::string & repo_part) {
-    std::string fname_part = sanitize_for_path(filename);
-    std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
+    std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
 
     {
         std::vector<char> cached;
@@ -611,3 +616,84 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
 
     return model_opt;
 }
+
+gguf_context_ptr gguf_fetch_gguf_ctx(
+        const std::string & repo,
+        const std::string & quant,
+        const std::string & cache_dir) {
+    std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
+    std::string repo_part = sanitize_for_path(repo);
+
+    std::string split_prefix;
+    std::string filename = detect_gguf_filename(repo, quant, split_prefix);
+
+    if (filename.empty()) {
+        return nullptr;
+    }
+
+    auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
+    if (!model_opt.has_value()) {
+        fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
+        return nullptr;
+    }
+
+    auto & model = model_opt.value();
+
+    const std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
+
+    ggml_context_ptr ggml_ctx_ptr;
+    ggml_context * ggml_ctx{};
+    gguf_init_params params{true, &ggml_ctx};
+    gguf_context_ptr ctx{gguf_init_from_file(cache_path.c_str(), params)};
+    ggml_ctx_ptr.reset(ggml_ctx);
+
+    if (ctx == nullptr) {
+        fprintf(stderr, "gguf_fetch: gguf_init_from_file failed\n");
+        return nullptr;
+    }
+
+    // If the model is split across multiple files we need to fetch the remaining shards metadata
+    if (model.n_split > 1) {
+        if (split_prefix.empty()) {
+            fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
+            return nullptr;
+        }
+
+        fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
+                model.n_split, model.n_split - 1);
+
+        for (int i = 2; i <= model.n_split; i++) {
+            char num_buf[6], total_buf[6];
+            snprintf(num_buf,   sizeof(num_buf),   "%05d", i);
+            snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
+            std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
+
+            auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
+            if (!shard.has_value()) {
+                fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
+                return nullptr;
+            }
+
+            // Load tensors from shard and add to main gguf_context
+            const std::string shard_path = get_cache_file_path(cdir, repo_part, shard_name);
+            ggml_context_ptr shard_ggml_ctx_ptr;
+            ggml_context * shard_ggml_ctx{};
+            gguf_init_params shard_params{true, &shard_ggml_ctx};
+            gguf_context_ptr shard_ctx{gguf_init_from_file(shard_path.c_str(), shard_params)};
+            shard_ggml_ctx_ptr.reset(shard_ggml_ctx);
+
+            if (shard_ctx == nullptr) {
+                fprintf(stderr, "gguf_fetch: shard gguf_init_from_file failed\n");
+                return nullptr;
+            }
+
+            for (ggml_tensor * t = ggml_get_first_tensor(shard_ggml_ctx); t; t = ggml_get_next_tensor(shard_ggml_ctx, t)) {
+                gguf_add_tensor(ctx.get(), t);
+            }
+        }
+
+        gguf_set_val_u16(ctx.get(), "split.count", 1);
+    }
+
+    return ctx;
+}
index ed433791ad7a8a08b3a9c913d8a742d24f37c6d5..61ce24bb051ae28aff060252f447648d36e0ee76 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
-#include "ggml.h"
+#include "ggml-cpp.h"
+#include "gguf.h"
 
 #include <cstdint>
 #include <optional>
@@ -40,3 +41,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
     const std::string & repo,
     const std::string & quant = "Q8_0",
     const std::string & cache_dir = "");  // empty = default
+
+gguf_context_ptr gguf_fetch_gguf_ctx(
+    const std::string & repo,
+    const std::string & quant = "Q8_0",
+    const std::string & cache_dir = "");