]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common : add standard Hugging Face cache support (#20775)
authorAdrien Gallouët <redacted>
Tue, 24 Mar 2026 06:30:33 +0000 (07:30 +0100)
committerGitHub <redacted>
Tue, 24 Mar 2026 06:30:33 +0000 (07:30 +0100)
* common : add standard Hugging Face cache support

- Use HF API to find all files
- Migrate all manifests to hugging face cache at startup

Signed-off-by: Adrien Gallouët <redacted>
* Check with the quant tag

Signed-off-by: Adrien Gallouët <redacted>
* Cleanup

Signed-off-by: Adrien Gallouët <redacted>
* Improve error handling and report API errors

Signed-off-by: Adrien Gallouët <redacted>
* Restore common_cached_model_info and align mmproj filtering

Signed-off-by: Adrien Gallouët <redacted>
* Prefer main when getting cached ref

Signed-off-by: Adrien Gallouët <redacted>
* Use cached files when HF API fails

Signed-off-by: Adrien Gallouët <redacted>
* Use final_path..

Signed-off-by: Adrien Gallouët <redacted>
* Check all inputs

Signed-off-by: Adrien Gallouët <redacted>
---------

Signed-off-by: Adrien Gallouët <redacted>
common/CMakeLists.txt
common/arg.cpp
common/download.cpp
common/download.h
common/hf-cache.cpp [new file with mode: 0644]
common/hf-cache.h [new file with mode: 0644]
tools/llama-bench/llama-bench.cpp
tools/server/tests/unit/test_router.py

index 75c6366c7fa65e265a5b37e83cb2747c27d7db95..b313a7320e56ce6e26764055a8edc54a3b851462 100644 (file)
@@ -63,6 +63,8 @@ add_library(${TARGET} STATIC
     debug.h
     download.cpp
     download.h
+    hf-cache.cpp
+    hf-cache.h
     http.h
     json-partial.cpp
     json-partial.h
index c6a2dcbf2d86e78076d6bbe7d897a94c62ab5452..71f63ba63f2ea917808075770656156b91bf8f97 100644 (file)
@@ -3,6 +3,7 @@
 #include "chat.h"
 #include "common.h"
 #include "download.h"
+#include "hf-cache.h"
 #include "json-schema-to-grammar.h"
 #include "log.h"
 #include "sampling.h"
@@ -326,60 +327,48 @@ struct handle_model_result {
     common_params_model mmproj;
 };
 
-static handle_model_result common_params_handle_model(
-        struct common_params_model & model,
-        const std::string & bearer_token,
-        bool offline) {
+static handle_model_result common_params_handle_model(struct common_params_model & model,
+                                                      const std::string          & bearer_token,
+                                                      bool                         offline) {
     handle_model_result result;
-    // handle pre-fill default model path and url based on hf_repo and hf_file
-    {
-        if (!model.docker_repo.empty()) {  // Handle Docker URLs by resolving them to local paths
-            model.path = common_docker_resolve_model(model.docker_repo);
-            model.name = model.docker_repo; // set name for consistency
-        } else if (!model.hf_repo.empty()) {
-            // short-hand to avoid specifying --hf-file -> default it to --model
-            if (model.hf_file.empty()) {
-                if (model.path.empty()) {
-                    auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
-                    if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
-                        exit(1); // error message already printed
-                    }
-                    model.name    = model.hf_repo;      // repo name with tag
-                    model.hf_repo = auto_detected.repo; // repo name without tag
-                    model.hf_file = auto_detected.ggufFile;
-                    if (!auto_detected.mmprojFile.empty()) {
-                        result.found_mmproj   = true;
-                        result.mmproj.hf_repo = model.hf_repo;
-                        result.mmproj.hf_file = auto_detected.mmprojFile;
-                    }
-                } else {
-                    model.hf_file = model.path;
-                }
-            }
 
-            std::string model_endpoint = get_model_endpoint();
-            model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
-            // make sure model path is present (for caching purposes)
-            if (model.path.empty()) {
-                // this is to avoid different repo having same file name, or same file name in different subdirs
-                std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
-                model.path = fs_get_cache_file(filename);
-            }
+    if (!model.docker_repo.empty()) {
+        model.path = common_docker_resolve_model(model.docker_repo);
+        model.name = model.docker_repo;
+    } else if (!model.hf_repo.empty()) {
+        // If -m was used with -hf, treat the model "path" as the hf_file to download
+        if (model.hf_file.empty() && !model.path.empty()) {
+            model.hf_file = model.path;
+            model.path = "";
+        }
+        common_download_model_opts opts;
+        opts.download_mmproj = true;
+        opts.offline = offline;
+        auto download_result = common_download_model(model, bearer_token, opts);
+
+        if (download_result.model_path.empty()) {
+            LOG_ERR("error: failed to download model from Hugging Face\n");
+            exit(1);
+        }
 
-        } else if (!model.url.empty()) {
-            if (model.path.empty()) {
-                auto f = string_split<std::string>(model.url, '#').front();
-                f = string_split<std::string>(f, '?').front();
-                model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
-            }
+        model.name = model.hf_repo;
+        model.path = download_result.model_path;
 
+        if (!download_result.mmproj_path.empty()) {
+            result.found_mmproj = true;
+            result.mmproj.path  = download_result.mmproj_path;
+        }
+    } else if (!model.url.empty()) {
+        if (model.path.empty()) {
+            auto f = string_split<std::string>(model.url, '#').front();
+            f = string_split<std::string>(f, '?').front();
+            model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
-    }
 
-    // then, download it if needed
-    if (!model.url.empty()) {
-        bool ok = common_download_model(model, bearer_token, offline);
-        if (!ok) {
+        common_download_model_opts opts;
+        opts.offline = offline;
+        auto download_result = common_download_model(model, bearer_token, opts);
+        if (download_result.model_path.empty()) {
             LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
             exit(1);
         }
@@ -539,6 +528,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     // parse the first time to get -hf option (used for remote preset)
     parse_cli_args();
 
+    // TODO: Remove later
+    try {
+        hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
+    } catch (const std::exception & e) {
+        LOG_WRN("HF cache migration failed: %s\n", e.what());
+    }
+
     // maybe handle remote preset
     if (!params.model.hf_repo.empty()) {
         std::string cli_hf_repo = params.model.hf_repo;
@@ -1061,12 +1057,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-cl", "--cache-list"},
         "show list of models in cache",
         [](common_params &) {
-            printf("model cache directory: %s\n", fs_get_cache_directory().c_str());
             auto models = common_list_cached_models();
             printf("number of models in cache: %zu\n", models.size());
             for (size_t i = 0; i < models.size(); i++) {
-                auto & model = models[i];
-                printf("%4d. %s\n", (int) i + 1, model.to_string().c_str());
+                printf("%4zu. %s\n", i + 1, models[i].to_string().c_str());
             }
             exit(0);
         }
index 5ef60a420864834d9f740471976332fdf554a7eb..151517b19e510a4abdfffa31579d017d86e54c57 100644 (file)
@@ -1,9 +1,9 @@
 #include "arg.h"
 
 #include "common.h"
-#include "gguf.h" // for reading GGUF splits
 #include "log.h"
 #include "download.h"
+#include "hf-cache.h"
 
 #define JSON_ASSERT GGML_ASSERT
 #include <nlohmann/json.hpp>
@@ -15,6 +15,7 @@
 #include <map>
 #include <mutex>
 #include <regex>
+#include <unordered_set>
 #include <string>
 #include <thread>
 #include <vector>
@@ -35,8 +36,6 @@
 #endif
 #endif
 
-#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
-
 // isatty
 #if defined(_WIN32)
 #include <io.h>
@@ -51,31 +50,6 @@ using json = nlohmann::ordered_json;
 //
 
 // validate repo name format: owner/repo
-static bool validate_repo_name(const std::string & repo) {
-    static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
-    return std::regex_match(repo, repo_regex);
-}
-
-static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
-    // we use "=" to avoid clashing with other component, while still being allowed on windows
-    std::string fname = "manifest=" + repo + "=" + tag + ".json";
-    if (!validate_repo_name(repo)) {
-        throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
-    }
-    string_replace_all(fname, "/", "=");
-    return fs_get_cache_file(fname);
-}
-
-static std::string read_file(const std::string & fname) {
-    std::ifstream file(fname);
-    if (!file) {
-        throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
-    }
-    std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
-    file.close();
-    return content;
-}
-
 static void write_file(const std::string & fname, const std::string & content) {
     const std::string fname_tmp = fname + ".tmp";
     std::ofstream     file(fname_tmp);
@@ -132,7 +106,7 @@ static bool is_http_status_ok(int status) {
 
 std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
     auto parts = string_split<std::string>(hf_repo_with_tag, ':');
-    std::string tag = parts.size() > 1 ? parts.back() : "latest";
+    std::string tag = parts.size() > 1 ? parts.back() : "";
     std::string hf_repo = parts[0];
     if (string_split<std::string>(hf_repo, '/').size() != 2) {
         throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
@@ -290,7 +264,8 @@ static bool common_pull_file(httplib::Client & cli,
 static int common_download_file_single_online(const std::string        & url,
                                               const std::string        & path,
                                               const std::string        & bearer_token,
-                                              const common_header_list & custom_headers) {
+                                              const common_header_list & custom_headers,
+                                              bool                       skip_etag = false) {
     static const int max_attempts        = 3;
     static const int retry_delay_seconds = 2;
 
@@ -310,6 +285,11 @@ static int common_download_file_single_online(const std::string        & url,
 
     const bool file_exists = std::filesystem::exists(path);
 
+    if (file_exists && skip_etag) {
+        LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
+        return 304; // 304 Not Modified - fake cached response
+    }
+
     std::string last_etag;
     if (file_exists) {
         last_etag = read_etag(path);
@@ -361,6 +341,12 @@ static int common_download_file_single_online(const std::string        & url,
         }
     }
 
+    { // silent
+        std::error_code ec;
+        std::filesystem::path p(path);
+        std::filesystem::create_directories(p.parent_path(), ec);
+    }
+
     const std::string path_temporary = path + ".downloadInProgress";
     int delay = retry_delay_seconds;
 
@@ -391,7 +377,7 @@ static int common_download_file_single_online(const std::string        & url,
                 LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
                 return -1;
             }
-            if (!etag.empty()) {
+            if (!etag.empty() && !skip_etag) {
                 write_etag(path, etag);
             }
             return head->status;
@@ -440,9 +426,10 @@ int common_download_file_single(const std::string & url,
                                 const std::string & path,
                                 const std::string & bearer_token,
                                 bool offline,
-                                const common_header_list & headers) {
+                                const common_header_list & headers,
+                                bool skip_etag) {
     if (!offline) {
-        return common_download_file_single_online(url, path, bearer_token, headers);
+        return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
     }
 
     if (!std::filesystem::exists(path)) {
@@ -454,193 +441,293 @@ int common_download_file_single(const std::string & url,
     return 304; // Not Modified - fake cached response
 }
 
-// download multiple files from remote URLs to local paths
-// the input is a vector of pairs <url, path>
-static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
-                                          const std::string & bearer_token,
-                                          bool offline,
-                                          const common_header_list & headers) {
-    // Prepare download in parallel
-    std::vector<std::future<bool>> futures_download;
-    futures_download.reserve(urls.size());
-
-    for (auto const & item : urls) {
-        futures_download.push_back(
-            std::async(
-                std::launch::async,
-                [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
-                    const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
-                    return is_http_status_ok(http_status);
-                },
-                item
-            )
-        );
-    }
-
-    // Wait for all downloads to complete
-    for (auto & f : futures_download) {
-        if (!f.get()) {
-            return false;
+struct gguf_split_info {
+    std::string prefix; // tag included
+    std::string tag;
+    int index;
+    int count;
+};
+
+static gguf_split_info get_gguf_split_info(const std::string & path) {
+    static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
+    static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
+    std::smatch m;
+
+    std::string prefix = path;
+    string_remove_suffix(prefix, ".gguf");
+
+    int index = 1;
+    int count = 1;
+
+    if (std::regex_match(prefix, m, re_split)) {
+        prefix = m[1].str();
+        index = std::stoi(m[2].str());
+        count = std::stoi(m[3].str());
+    }
+
+    std::string tag;
+    if (std::regex_search(prefix, m, re_tag)) {
+        tag = m[1].str();
+        for (char & c : tag) {
+            c = std::toupper((unsigned char)c);
         }
     }
 
-    return true;
+    return {std::move(prefix), std::move(tag), index, count};
 }
 
-bool common_download_model(const common_params_model & model,
-                           const std::string & bearer_token,
-                           bool offline,
-                           const common_header_list & headers) {
-    // Basic validation of the model.url
-    if (model.url.empty()) {
-        LOG_ERR("%s: invalid model url\n", __func__);
-        return false;
+// Q4_0 -> 4, F16 -> 16, NVFP4 -> 4, Q8_K_M -> 8, etc
+static int extract_quant_bits(const std::string & filename) {
+    auto split = get_gguf_split_info(filename);
+
+    auto pos = split.tag.find_first_of("0123456789");
+    if (pos == std::string::npos) {
+        return 0;
     }
 
-    const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
-    if (!is_http_status_ok(http_status)) {
-        return false;
+    return std::stoi(split.tag.substr(pos));
+}
+
+static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
+                                          const hf_cache::hf_file  & file) {
+    auto split = get_gguf_split_info(file.path);
+
+    if (split.count <= 1) {
+        return {file};
     }
+    hf_cache::hf_files result;
 
-    // check for additional GGUFs split to download
-    int n_split = 0;
-    {
-        struct gguf_init_params gguf_params = {
-            /*.no_alloc = */ true,
-            /*.ctx      = */ NULL,
-        };
-        auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
-        if (!ctx_gguf) {
-            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, model.path.c_str());
-            return false;
+    for (const auto & f : files) {
+        auto split_f = get_gguf_split_info(f.path);
+        if (split_f.count == split.count && split_f.prefix == split.prefix) {
+            result.push_back(f);
+        }
+    }
+    return result;
+}
+
+static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
+                                          const std::string        & model) {
+    hf_cache::hf_file best;
+    size_t best_depth = 0;
+    int best_diff = 0;
+    bool found = false;
+
+    auto model_bits = extract_quant_bits(model);
+    auto model_parts = string_split<std::string>(model, '/');
+    auto model_dir = model_parts.end() - 1;
+
+    for (const auto & f : files) {
+        if (!string_ends_with(f.path, ".gguf") ||
+            f.path.find("mmproj") == std::string::npos) {
+            continue;
         }
 
-        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
-        if (key_n_split >= 0) {
-            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
+        auto mmproj_parts = string_split<std::string>(f.path, '/');
+        auto mmproj_dir = mmproj_parts.end() - 1;
+
+        auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
+                                      mmproj_parts.begin(), mmproj_dir);
+        if (dir != mmproj_dir) {
+            continue;
         }
 
-        gguf_free(ctx_gguf);
+        size_t depth = dir - mmproj_parts.begin();
+        auto bits = extract_quant_bits(f.path);
+        auto diff = std::abs(bits - model_bits);
+
+        if (!found || depth > best_depth || (depth == best_depth && diff < best_diff)) {
+            best = f;
+            best_depth = depth;
+            best_diff = diff;
+            found = true;
+        }
     }
+    return best;
+}
 
-    if (n_split > 1) {
-        char split_prefix[PATH_MAX] = {0};
-        char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
+static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
+                                         const std::string        & tag) {
+    std::vector<std::string> tags;
 
-        // Verify the first split file format
-        // and extract split URL and PATH prefixes
-        {
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
-                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
-                return false;
-            }
+    if (!tag.empty()) {
+        tags.push_back(tag);
+    } else {
+        tags = {"Q4_K_M", "Q4_0"};
+    }
 
-            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
-                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
-                return false;
+    for (const auto & t : tags) {
+        std::regex pattern(t + "[.-]", std::regex::icase);
+        for (const auto & f : files) {
+            if (string_ends_with(f.path, ".gguf") &&
+                f.path.find("mmproj") == std::string::npos &&
+                std::regex_search(f.path, pattern)) {
+                return f;
             }
         }
+    }
 
-        std::vector<std::pair<std::string, std::string>> urls;
-        for (int idx = 1; idx < n_split; idx++) {
-            char split_path[PATH_MAX] = {0};
-            llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
+    for (const auto & f : files) {
+        if (string_ends_with(f.path, ".gguf") &&
+            f.path.find("mmproj") == std::string::npos) {
+            return f;
+        }
+    }
 
-            char split_url[LLAMA_MAX_URL_LENGTH] = {0};
-            llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
+    return {};
+}
 
-            if (std::string(split_path) == model.path) {
-                continue; // skip the already downloaded file
-            }
+static void list_available_gguf_files(const hf_cache::hf_files & files) {
+    LOG_INF("Available GGUF files:\n");
+    for (const auto & f : files) {
+        if (string_ends_with(f.path, ".gguf")) {
+            LOG_INF(" - %s\n", f.path.c_str());
+        }
+    }
+}
+
+struct hf_plan {
+    hf_cache::hf_files model_files;
+    hf_cache::hf_file mmproj;
+};
 
-            urls.push_back({split_url, split_path});
+static hf_plan get_hf_plan(const common_params_model        & model,
+                           const std::string                & token,
+                           const common_download_model_opts & opts) {
+    hf_plan plan;
+    hf_cache::hf_files all;
+
+    auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
+
+    if (!opts.offline) {
+        all = hf_cache::get_repo_files(repo, token);
+    }
+    if (all.empty()) {
+        all = hf_cache::get_cached_files(repo);
+    }
+    if (all.empty()) {
+        return plan;
+    }
+
+    hf_cache::hf_file primary;
+
+    if (!model.hf_file.empty()) {
+        for (const auto & f : all) {
+            if (f.path == model.hf_file) {
+                primary = f;
+                break;
+            }
+        }
+        if (primary.path.empty()) {
+            LOG_ERR("%s: file '%s' not found in repository\n", __func__, model.hf_file.c_str());
+            list_available_gguf_files(all);
+            return plan;
         }
+    } else {
+        primary = find_best_model(all, tag);
+        if (primary.path.empty()) {
+            LOG_ERR("%s: no GGUF files found in repository %s\n", __func__, repo.c_str());
+            list_available_gguf_files(all);
+            return plan;
+        }
+    }
 
-        // Download in parallel
-        common_download_file_multiple(urls, bearer_token, offline, headers);
+    plan.model_files = get_split_files(all, primary);
+
+    if (opts.download_mmproj) {
+        plan.mmproj = find_best_mmproj(all, primary.path);
     }
 
-    return true;
+    return plan;
 }
 
-common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
-                                      const std::string & bearer_token,
-                                      bool offline,
-                                      const common_header_list & custom_headers) {
-    // the returned hf_repo is without tag
-    auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
+struct download_task {
+    std::string url;
+    std::string path;
+};
 
-    std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
+static std::vector<download_task> get_url_tasks(const common_params_model & model) {
+    auto split = get_gguf_split_info(model.url);
 
-    // headers
-    common_header_list headers = custom_headers;
-    headers.push_back({"Accept", "application/json"});
-    if (!bearer_token.empty()) {
-        headers.push_back({"Authorization", "Bearer " + bearer_token});
+    if (split.count <= 1) {
+        return {{model.url, model.path}};
     }
-    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
-    // User-Agent header is already set in common_remote_get_content, no need to set it here
 
-    // make the request
-    common_remote_params params;
-    params.headers = headers;
-    long res_code = 0;
-    std::string res_str;
-    bool use_cache = false;
-    std::string cached_response_path = get_manifest_path(hf_repo, tag);
-    if (!offline) {
-        try {
-            auto res = common_remote_get_content(url, params);
-            res_code = res.first;
-            res_str = std::string(res.second.data(), res.second.size());
-        } catch (const std::exception & e) {
-            LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
-        }
+    auto filename = split.prefix;
+    if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) {
+        filename = split.prefix.substr(pos + 1);
     }
-    if (res_code == 0) {
-        if (std::filesystem::exists(cached_response_path)) {
-            LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
-            res_str = read_file(cached_response_path);
-            res_code = 200;
-            use_cache = true;
-        } else {
-            throw std::runtime_error(
-                offline ? "error: failed to get manifest (offline mode)"
-                : "error: failed to get manifest (check your internet connection)");
-        }
+
+    auto parent_path = std::filesystem::path(model.path).parent_path();
+    auto prefix_path = (parent_path / filename).string();
+
+    std::vector<download_task> tasks;
+    for (int i = 1; i <= split.count; i++) {
+        auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
+        tasks.push_back({split.prefix + suffix, prefix_path + suffix});
     }
-    std::string ggufFile;
-    std::string mmprojFile;
+    return tasks;
+}
 
-    if (res_code == 200 || res_code == 304) {
-        try {
-            auto j = json::parse(res_str);
+common_download_model_result common_download_model(const common_params_model        & model,
+                                                   const std::string                & bearer_token,
+                                                   const common_download_model_opts & opts,
+                                                   const common_header_list         & headers) {
+    common_download_model_result result;
+    std::vector<download_task> tasks;
+    hf_plan hf;
 
-            if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
-                ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
-            }
-            if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
-                mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
-            }
-        } catch (const std::exception & e) {
-            throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
+    bool is_hf = !model.hf_repo.empty();
+
+    if (is_hf) {
+        hf = get_hf_plan(model, bearer_token, opts);
+        for (const auto & f : hf.model_files) {
+            tasks.push_back({f.url, f.local_path});
         }
-        if (!use_cache) {
-            // if not using cached response, update the cache file
-            write_file(cached_response_path, res_str);
+        if (!hf.mmproj.path.empty()) {
+            tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
         }
-    } else if (res_code == 401) {
-        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
+    } else if (!model.url.empty()) {
+        tasks = get_url_tasks(model);
     } else {
-        throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
+        result.model_path = model.path;
+        return result;
     }
 
-    // check response
-    if (ggufFile.empty()) {
-        throw std::runtime_error("error: model does not have ggufFile");
+    if (tasks.empty()) {
+        return result;
     }
 
-    return { hf_repo, ggufFile, mmprojFile };
+    std::vector<std::future<bool>> futures;
+    for (const auto & task : tasks) {
+        futures.push_back(std::async(std::launch::async,
+            [&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
+                int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
+                return is_http_status_ok(status);
+            }
+        ));
+    }
+
+    for (auto & f : futures) {
+        if (!f.get()) {
+            return {};
+        }
+    }
+
+    if (is_hf) {
+        for (const auto & f : hf.model_files) {
+            hf_cache::finalize_file(f);
+        }
+        result.model_path = hf.model_files[0].final_path;
+
+        if (!hf.mmproj.path.empty()) {
+            result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
+        }
+    } else {
+        result.model_path = model.path;
+    }
+
+    return result;
 }
 
 //
@@ -765,28 +852,21 @@ std::string common_docker_resolve_model(const std::string & docker) {
 }
 
 std::vector<common_cached_model_info> common_list_cached_models() {
-    std::vector<common_cached_model_info> models;
-    const std::string cache_dir = fs_get_cache_directory();
-    const std::vector<common_file_info> files = fs_list(cache_dir, false);
-    for (const auto & file : files) {
-        if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
-            common_cached_model_info model_info;
-            model_info.manifest_path = file.path;
-            std::string fname = file.name;
-            string_replace_all(fname, ".json", ""); // remove extension
-            auto parts = string_split<std::string>(fname, '=');
-            if (parts.size() == 4) {
-                // expect format: manifest=<user>=<model>=<tag>=<other>
-                model_info.user  = parts[1];
-                model_info.model = parts[2];
-                model_info.tag   = parts[3];
-            } else {
-                // invalid format
-                continue;
-            }
-            model_info.size = 0; // TODO: get GGUF size, not manifest size
-            models.push_back(model_info);
+    std::unordered_set<std::string> seen;
+    std::vector<common_cached_model_info> result;
+
+    auto files = hf_cache::get_cached_files();
+
+    for (const auto & f : files) {
+        auto split = get_gguf_split_info(f.path);
+        if (split.index != 1 || split.tag.empty() ||
+            split.prefix.find("mmproj") != std::string::npos) {
+            continue;
+        }
+        if (seen.insert(f.repo_id + ":" + split.tag).second) {
+            result.push_back({f.repo_id, split.tag});
         }
     }
-    return models;
+
+    return result;
 }
index 1c1d8e6db595f6458c6e9a635caf21a76191a8fa..0a933521faaf040891e601f6c4e5d3113136325e 100644 (file)
@@ -17,54 +17,60 @@ struct common_remote_params {
 // get remote file content, returns <http_code, raw_response_body>
 std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
 
-// split HF repo with tag into <repo, tag>
-// for example: "user/model:tag" -> <"user/model", "tag">
-// if tag is not present, default to "latest"
-// example: "user/model" -> <"user/model", "latest">
+// split HF repo with tag into <repo, tag>, for example:
+// - "ggml-org/models:F16" -> <"ggml-org/models", "F16">
+// tag is optional and can be empty
 std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
 
+// Result of common_list_cached_models
 struct common_cached_model_info {
-    std::string manifest_path;
-    std::string user;
-    std::string model;
+    std::string repo;
     std::string tag;
-    size_t      size = 0; // GGUF size in bytes
-    // return string representation like "user/model:tag"
-    // if tag is "latest", it will be omitted
     std::string to_string() const {
-        return user + "/" + model + (tag == "latest" ? "" : ":" + tag);
+        return repo + ":" + tag;
     }
 };
 
-struct common_hf_file_res {
-    std::string repo; // repo name with ":tag" removed
-    std::string ggufFile;
-    std::string mmprojFile;
+// Options for common_download_model
+struct common_download_model_opts {
+    bool download_mmproj = false;
+    bool offline         = false;
 };
 
-/**
- * Allow getting the HF file from the HF repo with tag (like ollama), for example:
- * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
- * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
- * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
- * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
- *
- * Return pair of <repo, file> (with "repo" already having tag removed)
- *
- * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
- */
-common_hf_file_res common_get_hf_file(
-    const std::string & hf_repo_with_tag,
-    const std::string & bearer_token,
-    bool offline,
-    const common_header_list & headers = {}
-);
+// Result of common_download_model
+struct common_download_model_result {
+    std::string model_path;
+    std::string mmproj_path;
+};
 
-// returns true if download succeeded
-bool common_download_model(
+// Download model from HuggingFace repo or URL
+//
+// input (via model struct):
+// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag
+// - model.hf_file: specific file in the repo (requires hf_repo)
+// - model.url: simple download (used if hf_repo is empty)
+// - model.path: local file path
+//
+// tag matching (for HF repos without model.hf_file):
+// - if tag is specified, searches for GGUF matching that quantization
+// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF
+//
+// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically
+// detected and all parts are downloaded
+//
+// caching:
+// - HF repos: uses HuggingFace cache
+// - URLs: uses ETag-based caching
+//
+// when opts.offline=true, no network requests are made
+// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
+// then with the closest quantization bits
+//
+// returns result with model_path and mmproj_path (empty on failure)
+common_download_model_result common_download_model(
     const common_params_model & model,
     const std::string & bearer_token,
-    bool offline,
+    const common_download_model_opts & opts = {},
     const common_header_list & headers = {}
 );
 
@@ -73,11 +79,13 @@ std::vector<common_cached_model_info> common_list_cached_models();
 
 // download single file from url to local path
 // returns status code or -1 on error
+// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
 int common_download_file_single(const std::string & url,
                                 const std::string & path,
                                 const std::string & bearer_token,
                                 bool offline,
-                                const common_header_list & headers = {});
+                                const common_header_list & headers = {},
+                                bool skip_etag = false);
 
 // resolve and download model from Docker registry
 // return local path to downloaded model file
diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp
new file mode 100644 (file)
index 0000000..ad68c55
--- /dev/null
@@ -0,0 +1,629 @@
+#include "hf-cache.h"
+
+#include "common.h"
+#include "log.h"
+#include "http.h"
+
+#define JSON_ASSERT GGML_ASSERT
+#include <nlohmann/json.hpp>
+
+#include <filesystem>
+#include <fstream>
+#include <atomic>
+#include <regex> // migration only
+#include <string>
+#include <string_view>
+#include <stdexcept>
+
+namespace nl = nlohmann;
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#define HOME_DIR "USERPROFILE"
+#include <windows.h>
+#else
+#define HOME_DIR "HOME"
+#endif
+
+namespace hf_cache {
+
+namespace fs = std::filesystem;
+
+static fs::path get_cache_directory() {
+    static const fs::path cache = []() {
+        struct {
+            const char * var;
+            fs::path path;
+        } entries[] = {
+            {"HF_HUB_CACHE",          fs::path()},
+            {"HUGGINGFACE_HUB_CACHE", fs::path()},
+            {"HF_HOME",               fs::path("hub")},
+            {"XDG_CACHE_HOME",        fs::path("huggingface") / "hub"},
+            {HOME_DIR,                fs::path(".cache") / "huggingface" / "hub"}
+        };
+        for (const auto & entry : entries) {
+            if (auto * p = std::getenv(entry.var); p && *p) {
+                fs::path base(p);
+                return entry.path.empty() ? base : base / entry.path;
+            }
+        }
+        throw std::runtime_error("Failed to determine HF cache directory");
+    }();
+
+    return cache;
+}
+
+static std::string folder_name_to_repo(const std::string & folder) {
+    constexpr std::string_view prefix = "models--";
+    if (folder.rfind(prefix, 0)) {
+        return {};
+    }
+    std::string result = folder.substr(prefix.length());
+    string_replace_all(result, "--", "/");
+    return result;
+}
+
+static std::string repo_to_folder_name(const std::string & repo_id) {
+    constexpr std::string_view prefix = "models--";
+    std::string result = std::string(prefix) + repo_id;
+    string_replace_all(result, "/", "--");
+    return result;
+}
+
+static fs::path get_repo_path(const std::string & repo_id) {
+    return get_cache_directory() / repo_to_folder_name(repo_id);
+}
+
+static bool is_hex_char(const char c) {
+    return (c >= 'A' && c <= 'F') ||
+           (c >= 'a' && c <= 'f') ||
+           (c >= '0' && c <= '9');
+}
+
+static bool is_hex_string(const std::string & s, size_t expected_len) {
+    if (s.length() != expected_len) {
+        return false;
+    }
+    for (const char c : s) {
+        if (!is_hex_char(c)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool is_alphanum(const char c) {
+    return (c >= 'A' && c <= 'Z') ||
+           (c >= 'a' && c <= 'z') ||
+           (c >= '0' && c <= '9');
+}
+
+static bool is_special_char(char c) {
+    return c == '/' || c == '.' || c == '-';
+}
+
+// base chars [A-Za-z0-9_] are always valid
+// special chars [/.-] must be surrounded by base chars
+// exactly one '/' required
+static bool is_valid_repo_id(const std::string & repo_id) {
+    if (repo_id.empty() || repo_id.length() > 256) {
+        return false;
+    }
+    int slash = 0;
+    bool special = true;
+
+    for (const char c : repo_id) {
+        if (is_alphanum(c) || c == '_') {
+            special = false;
+        } else if (is_special_char(c)) {
+            if (special) {
+                return false;
+            }
+            slash += (c == '/');
+            special = true;
+        } else {
+            return false;
+        }
+    }
+    return !special && slash == 1;
+}
+
+static bool is_valid_hf_token(const std::string & token) {
+    if (token.length() < 37 || token.length() > 256 ||
+        !string_starts_with(token, "hf_")) {
+        return false;
+    }
+    for (size_t i = 3; i < token.length(); ++i) {
+        if (!is_alphanum(token[i])) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool is_valid_commit(const std::string & hash) {
+    return is_hex_string(hash, 40);
+}
+
+static bool is_valid_oid(const std::string & oid) {
+    return is_hex_string(oid, 40) || is_hex_string(oid, 64);
+}
+
+static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) {
+    if (subpath.is_absolute()) {
+        return false; // never do a / b with b absolute
+    }
+    auto b = fs::absolute(path).lexically_normal();
+    auto t = (b / subpath).lexically_normal();
+    auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end());
+
+    return b_end == b.end();
+}
+
+static void safe_write_file(const fs::path & path, const std::string & data) {
+    fs::path path_tmp = path.string() + ".tmp";
+
+    if (path.has_parent_path()) {
+        fs::create_directories(path.parent_path());
+    }
+
+    std::ofstream file(path_tmp);
+    file << data;
+    file.close();
+
+    std::error_code ec;
+
+    if (!file.fail()) {
+        fs::rename(path_tmp, path, ec);
+    }
+    if (file.fail() || ec) {
+        fs::remove(path_tmp, ec);
+        throw std::runtime_error("failed to write file: " + path.string());
+    }
+}
+
+static nl::json api_get(const std::string & url,
+                        const std::string & token) {
+    auto [cli, parts] = common_http_client(url);
+
+    httplib::Headers headers = {
+        {"User-Agent", "llama-cpp/" + build_info},
+        {"Accept", "application/json"}
+    };
+
+    if (is_valid_hf_token(token)) {
+        headers.emplace("Authorization", "Bearer " + token);
+    } else if (!token.empty()) {
+        LOG_WRN("%s: invalid token, authentication disabled\n", __func__);
+    }
+
+    if (auto res = cli.Get(parts.path, headers)) {
+        auto body = res->body;
+
+        if (res->status == 200) {
+            return nl::json::parse(res->body);
+        }
+        try {
+            body = nl::json::parse(res->body)["error"].get<std::string>();
+        } catch (...) { }
+
+        throw std::runtime_error("GET failed (" + std::to_string(res->status) + "): " + body);
+    } else {
+        throw std::runtime_error("HTTPLIB failed: " + httplib::to_string(res.error()));
+    }
+}
+
+static std::string get_repo_commit(const std::string & repo_id,
+                                   const std::string & token) {
+    try {
+        auto endpoint = get_model_endpoint();
+        auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
+
+        if (!json.is_object() ||
+            !json.contains("branches") || !json["branches"].is_array()) {
+            LOG_WRN("%s: missing 'branches' for '%s'\n", __func__, repo_id.c_str());
+            return {};
+        }
+
+        fs::path refs_path = get_repo_path(repo_id) / "refs";
+        std::string name;
+        std::string commit;
+
+        for (const auto & branch : json["branches"]) {
+            if (!branch.is_object() ||
+                !branch.contains("name") || !branch["name"].is_string() ||
+                !branch.contains("targetCommit") || !branch["targetCommit"].is_string()) {
+                continue;
+            }
+            std::string _name = branch["name"].get<std::string>();
+            std::string _commit = branch["targetCommit"].get<std::string>();
+
+            if (!is_valid_subpath(refs_path, _name)) {
+                LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str());
+                continue;
+            }
+            if (!is_valid_commit(_commit)) {
+                LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str());
+                continue;
+            }
+
+            if (_name == "main") {
+                name = _name;
+                commit = _commit;
+                break;
+            }
+
+            if (name.empty() || commit.empty()) {
+                name = _name;
+                commit = _commit;
+            }
+        }
+
+        if (name.empty() || commit.empty()) {
+            LOG_WRN("%s: no valid branch for '%s'\n", __func__, repo_id.c_str());
+            return {};
+        }
+
+        safe_write_file(refs_path / name, commit);
+        return commit;
+
+    } catch (const nl::json::exception & e) {
+        LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: error: %s\n", __func__, e.what());
+    }
+    return {};
+}
+
+hf_files get_repo_files(const std::string & repo_id,
+                        const std::string & token) {
+    if (!is_valid_repo_id(repo_id)) {
+        LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
+        return {};
+    }
+
+    std::string commit = get_repo_commit(repo_id, token);
+    if (commit.empty()) {
+        LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str());
+        return {};
+    }
+
+    fs::path blobs_path = get_repo_path(repo_id) / "blobs";
+    fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit;
+
+    hf_files files;
+
+    try {
+        auto endpoint = get_model_endpoint();
+        auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
+
+        if (!json.is_array()) {
+            LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str());
+            return {};
+        }
+
+        for (const auto & item : json) {
+            if (!item.is_object() ||
+                !item.contains("type") || !item["type"].is_string() || item["type"] != "file" ||
+                !item.contains("path") || !item["path"].is_string()) {
+                continue;
+            }
+
+            hf_file file;
+            file.repo_id = repo_id;
+            file.path = item["path"].get<std::string>();
+
+            if (!is_valid_subpath(commit_path, file.path)) {
+                LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str());
+                continue;
+            }
+
+            if (item.contains("lfs") && item["lfs"].is_object()) {
+                if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
+                    file.oid = item["lfs"]["oid"].get<std::string>();
+                }
+            } else if (item.contains("oid") && item["oid"].is_string()) {
+                file.oid = item["oid"].get<std::string>();
+            }
+
+            if (!file.oid.empty() && !is_valid_oid(file.oid)) {
+                LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
+                continue;
+            }
+
+            file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path;
+
+            fs::path final_path = commit_path / file.path;
+            file.final_path = final_path.string();
+
+            if (!file.oid.empty() && !fs::exists(final_path)) {
+                fs::path local_path = blobs_path / file.oid;
+                file.local_path = local_path.string();
+            } else {
+                file.local_path = file.final_path;
+            }
+
+            files.push_back(file);
+        }
+    } catch (const nl::json::exception & e) {
+        LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: error: %s\n", __func__, e.what());
+    }
+    return files;
+}
+
+static std::string get_cached_ref(const fs::path & repo_path) {
+    fs::path refs_path = repo_path / "refs";
+    if (!fs::is_directory(refs_path)) {
+        return {};
+    }
+    std::string fallback;
+
+    for (const auto & entry : fs::directory_iterator(refs_path)) {
+        if (!entry.is_regular_file()) {
+            continue;
+        }
+        std::ifstream f(entry.path());
+        std::string commit;
+        if (!f || !std::getline(f, commit) || commit.empty()) {
+            continue;
+        }
+        if (!is_valid_commit(commit)) {
+            LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str());
+            continue;
+        }
+        if (entry.path().filename() == "main") {
+            return commit;
+        }
+        if (fallback.empty()) {
+            fallback = commit;
+        }
+    }
+    return fallback;
+}
+
+hf_files get_cached_files(const std::string & repo_id) {
+    fs::path cache_dir = get_cache_directory();
+    if (!fs::exists(cache_dir)) {
+        return {};
+    }
+
+    if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
+        LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
+        return {};
+    }
+
+    hf_files files;
+
+    for (const auto & repo : fs::directory_iterator(cache_dir)) {
+        if (!repo.is_directory()) {
+            continue;
+        }
+        fs::path snapshots_path = repo.path() / "snapshots";
+
+        if (!fs::exists(snapshots_path)) {
+            continue;
+        }
+        std::string _repo_id = folder_name_to_repo(repo.path().filename().string());
+
+        if (!is_valid_repo_id(_repo_id)) {
+            continue;
+        }
+        if (!repo_id.empty() && _repo_id != repo_id) {
+            continue;
+        }
+        std::string commit = get_cached_ref(repo.path());
+        fs::path commit_path = snapshots_path / commit;
+
+        if (commit.empty() || !fs::is_directory(commit_path)) {
+            continue;
+        }
+        for (const auto & entry : fs::recursive_directory_iterator(commit_path)) {
+            if (!entry.is_regular_file() && !entry.is_symlink()) {
+                continue;
+            }
+            fs::path path = entry.path().lexically_relative(commit_path);
+
+            if (!path.empty()) {
+                hf_file file;
+                file.repo_id = _repo_id;
+                file.path = path.generic_string();
+                file.local_path = entry.path().string();
+                file.final_path = file.local_path;
+                files.push_back(std::move(file));
+            }
+        }
+    }
+
+    return files;
+}
+
+std::string finalize_file(const hf_file & file) {
+    static std::atomic<bool> symlinks_disabled{false};
+
+    std::error_code ec;
+    fs::path local_path(file.local_path);
+    fs::path final_path(file.final_path);
+
+    if (local_path == final_path || fs::exists(final_path, ec)) {
+        return file.final_path;
+    }
+
+    if (!fs::exists(local_path, ec)) {
+        return file.final_path;
+    }
+
+    fs::create_directories(final_path.parent_path(), ec);
+
+    if (!symlinks_disabled) {
+        fs::path target = fs::relative(local_path, final_path.parent_path(), ec);
+        if (!ec) {
+            fs::create_symlink(target, final_path, ec);
+        }
+        if (!ec) {
+            return file.final_path;
+        }
+    }
+
+    if (!symlinks_disabled.exchange(true)) {
+        LOG_WRN("%s: failed to create symlink: %s\n", __func__, ec.message().c_str());
+        LOG_WRN("%s: switching to degraded mode\n", __func__);
+    }
+
+    fs::rename(local_path, final_path, ec);
+    if (ec) {
+        LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str());
+        fs::copy(local_path, final_path, ec);
+        if (ec) {
+            LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str());
+        }
+    }
+    return file.final_path;
+}
+
+// delete everything after this line, one day
+
+static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
+    static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
+    std::smatch match;
+    if (std::regex_match(filename, match, re)) {
+        return {match[1].str(), match[2].str()};
+    }
+    return {};
+}
+
+static std::string make_old_cache_filename(const std::string & owner,
+                                           const std::string & repo,
+                                           const std::string & filename) {
+    auto result = owner + "_" + repo + "_" + filename;
+    string_replace_all(result, "/", "_");
+    return result;
+}
+
+static bool migrate_single_file(const fs::path    & old_cache,
+                                const std::string & owner,
+                                const std::string & repo,
+                                const nl::json    & node,
+                                const hf_files    & files) {
+
+    if (!node.contains("rfilename") ||
+        !node.contains("lfs")       ||
+        !node["lfs"].contains("sha256")) {
+        return false;
+    }
+
+    std::string path = node["rfilename"];
+    std::string sha256 = node["lfs"]["sha256"];
+
+    const hf_file * file_info = nullptr;
+    for (const auto & f : files) {
+        if (f.path == path) {
+            file_info = &f;
+            break;
+        }
+    }
+
+    std::string old_filename = make_old_cache_filename(owner, repo, path);
+    fs::path old_path = old_cache / old_filename;
+    fs::path etag_path = old_path.string() + ".etag";
+
+    if (!fs::exists(old_path)) {
+        if (fs::exists(etag_path)) {
+            LOG_WRN("%s: %s is orphan, deleting...\n", __func__, etag_path.string().c_str());
+            fs::remove(etag_path);
+        }
+        return false;
+    }
+
+    bool delete_old_path = false;
+
+    if (!file_info) {
+        LOG_WRN("%s: %s not found in current repo, deleting...\n", __func__, old_filename.c_str());
+        delete_old_path = true;
+    } else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) {
+        LOG_WRN("%s: %s is not up to date (sha256 mismatch), deleting...\n", __func__, old_filename.c_str());
+        delete_old_path = true;
+    }
+
+    std::error_code ec;
+
+    if (delete_old_path) {
+        fs::remove(old_path, ec);
+        fs::remove(etag_path, ec);
+        return true;
+    }
+
+    fs::path new_path(file_info->local_path);
+    fs::create_directories(new_path.parent_path(), ec);
+
+    if (!fs::exists(new_path, ec)) {
+        fs::rename(old_path, new_path, ec);
+        if (ec) {
+            fs::copy_file(old_path, new_path, ec);
+            if (ec) {
+                LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str());
+                return false;
+            }
+        }
+        fs::remove(old_path, ec);
+    }
+    fs::remove(etag_path, ec);
+
+    std::string filename = finalize_file(*file_info);
+    LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str());
+
+    return true;
+}
+
+void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
+    fs::path old_cache = fs_get_cache_directory();
+    if (!fs::exists(old_cache)) {
+        return;
+    }
+
+    if (offline) {
+        LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
+        return; // -hf is not going to work
+    }
+
+    for (const auto & entry : fs::directory_iterator(old_cache)) {
+        if (!entry.is_regular_file()) {
+            continue;
+        }
+        auto filename = entry.path().filename().string();
+        auto [owner, repo] = parse_manifest_name(filename);
+
+        if (owner.empty() || repo.empty()) {
+            continue;
+        }
+
+        auto repo_id = owner + "/" + repo;
+        auto files = get_repo_files(repo_id, token);
+
+        if (files.empty()) {
+            LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
+            continue;
+        }
+
+        try {
+            std::ifstream manifest(entry.path());
+            auto json = nl::json::parse(manifest);
+
+            for (const char * key : {"ggufFile", "mmprojFile"}) {
+                if (json.contains(key)) {
+                    migrate_single_file(old_cache, owner, repo, json[key], files);
+                }
+            }
+        } catch (const std::exception & e) {
+            LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
+            continue;
+        }
+        fs::remove(entry.path());
+    }
+}
+
+} // namespace hf_cache
diff --git a/common/hf-cache.h b/common/hf-cache.h
new file mode 100644 (file)
index 0000000..ee2e984
--- /dev/null
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <string>
+#include <vector>
+
+// Ref: https://huggingface.co/docs/hub/local-cache.md
+
+namespace hf_cache {
+
+struct hf_file {
+    std::string path;
+    std::string url;
+    std::string local_path;
+    std::string final_path;
+    std::string oid;
+    std::string repo_id;
+};
+
+using hf_files = std::vector<hf_file>;
+
+// Get files from HF API
+hf_files get_repo_files(
+    const std::string & repo_id,
+    const std::string & token
+);
+
+hf_files get_cached_files(const std::string & repo_id = {});
+
+// Create snapshot path (link or move/copy) and return it
+std::string finalize_file(const hf_file & file);
+
+// TODO: Remove later
+void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
+
+} // namespace hf_cache
index 21173576cc7b01cb618b24e9ab4dd2f87f346a73..25beb369e630f6998a0d43644037793ce0e15aad 100644 (file)
@@ -979,37 +979,20 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
         for (size_t i = 0; i < params.hf_repo.size(); i++) {
             common_params_model model;
 
-            // step 1: no `-hff` provided, we auto-detect based on the `-hf` flag
             if (params.hf_file.empty() || params.hf_file[i].empty()) {
-                auto auto_detected = common_get_hf_file(params.hf_repo[i], params.hf_token, false);
-                if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
-                    exit(1);
-                }
-
-                model.name    = params.hf_repo[i];
-                model.hf_repo = auto_detected.repo;
-                model.hf_file = auto_detected.ggufFile;
+                model.hf_repo = params.hf_repo[i];
             } else {
+                model.hf_repo = params.hf_repo[i];
                 model.hf_file = params.hf_file[i];
             }
 
-            // step 2: construct the model cache path
-            std::string clean_fname = model.hf_repo + "_" + model.hf_file;
-            string_replace_all(clean_fname, "\\", "_");
-            string_replace_all(clean_fname, "/", "_");
-            model.path = fs_get_cache_file(clean_fname);
-
-            // step 3: download the model if not exists
-            std::string model_endpoint = get_model_endpoint();
-            model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
-
-            bool ok = common_download_model(model, params.hf_token, false);
-            if (!ok) {
-                fprintf(stderr, "error: failed to download model from %s\n", model.url.c_str());
+            auto download_result = common_download_model(model, params.hf_token);
+            if (download_result.model_path.empty()) {
+                fprintf(stderr, "error: failed to download model from HuggingFace\n");
                 exit(1);
             }
 
-            params.model.push_back(model.path);
+            params.model.push_back(download_result.model_path);
         }
     }
 
index e85f2c33829bd74acda7714bbd68bd3b9984c281..717007a446d2bb01b38c7736f2799707447dfbe9 100644 (file)
@@ -103,8 +103,8 @@ def test_router_models_max_evicts_lru():
 
     candidate_models = [
         "ggml-org/tinygemma3-GGUF:Q8_0",
-        "ggml-org/test-model-stories260K",
-        "ggml-org/test-model-stories260K-infill",
+        "ggml-org/test-model-stories260K:F32",
+        "ggml-org/test-model-stories260K-infill:F32",
     ]
 
     # Load only the first 2 models to fill the cache