]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common : support tag-based --hf-repo like on ollama (#11195)
authorXuan Son Nguyen <redacted>
Mon, 13 Jan 2025 12:56:23 +0000 (13:56 +0100)
committerGitHub <redacted>
Mon, 13 Jan 2025 12:56:23 +0000 (13:56 +0100)
* common : support tag-based hf_repo like on ollama

* fix build

* various fixes

* small fixes

* fix style

* fix windows build?

* move common_get_hf_file to common.cpp

* fix complain with noreturn

common/arg.cpp
common/common.cpp
common/common.h

index 27886b84e862c4049dd208fd7f3908e447f67847..1457a360faab25f6e2234f1d4722f6ab34002d96 100644 (file)
@@ -130,17 +130,26 @@ std::string common_arg::to_string() {
 
 static void common_params_handle_model_default(
         std::string & model,
-        std::string & model_url,
+        const std::string & model_url,
         std::string & hf_repo,
-        std::string & hf_file) {
+        std::string & hf_file,
+        const std::string & hf_token) {
     if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
         if (hf_file.empty()) {
             if (model.empty()) {
-                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
+                auto auto_detected = common_get_hf_file(hf_repo, hf_token);
+                if (auto_detected.first.empty() || auto_detected.second.empty()) {
+                    exit(1); // built without CURL, error message already printed
+                }
+                hf_repo = auto_detected.first;
+                hf_file = auto_detected.second;
+            } else {
+                hf_file = model;
             }
-            hf_file = model;
-        } else if (model.empty()) {
+        }
+        // make sure model path is present (for caching purposes)
+        if (model.empty()) {
             // this is to avoid different repo having same file name, or same file name in different subdirs
             std::string filename = hf_repo + "_" + hf_file;
             // to make sure we don't have any slashes in the filename
@@ -290,8 +299,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // TODO: refactor model params in a common struct
-    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file);
-    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
+    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file,         params.hf_token);
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -1583,21 +1592,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_MODEL_URL"));
     add_opt(common_arg(
-        {"-hfr", "--hf-repo"}, "REPO",
-        "Hugging Face model repository (default: unused)",
+        {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
+        "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
+        "example: unsloth/phi-4-GGUF:q4_k_m\n"
+        "(default: unused)",
         [](common_params & params, const std::string & value) {
             params.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HF_REPO"));
     add_opt(common_arg(
         {"-hff", "--hf-file"}, "FILE",
-        "Hugging Face model file (default: unused)",
+        "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
         [](common_params & params, const std::string & value) {
             params.hf_file = value;
         }
     ).set_env("LLAMA_ARG_HF_FILE"));
     add_opt(common_arg(
-        {"-hfrv", "--hf-repo-v"}, "REPO",
+        {"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
         "Hugging Face model repository for the vocoder model (default: unused)",
         [](common_params & params, const std::string & value) {
             params.vocoder.hf_repo = value;
index 1a2e1524799d3e540ecb0fb0e9ee52a5a9904d13..a6f9252b27a9fb37388535d5d9a29de4e0af5699 100644 (file)
 #include <sys/syslimits.h>
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
+
+//
+// CURL utils
+//
+
+using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
+
+// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
+struct curl_slist_ptr {
+    struct curl_slist * ptr = nullptr;
+    ~curl_slist_ptr() {
+        if (ptr) {
+            curl_slist_free_all(ptr);
+        }
+    }
+};
 #endif // LLAMA_USE_CURL
 
 using json = nlohmann::ordered_json;
@@ -1130,7 +1146,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
 
 static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
     // Initialize libcurl
-    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
     if (!curl) {
         LOG_ERR("%s: error initializing libcurl\n", __func__);
         return false;
@@ -1144,11 +1161,9 @@ static bool common_download_file(const std::string & url, const std::string & pa
 
     // Check if hf-token or bearer-token was specified
     if (!hf_token.empty()) {
-      std::string auth_header = "Authorization: Bearer ";
-      auth_header += hf_token.c_str();
-      struct curl_slist *http_headers = NULL;
-      http_headers = curl_slist_append(http_headers, auth_header.c_str());
-      curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+        std::string auth_header = "Authorization: Bearer " + hf_token;
+        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
     }
 
 #if defined(_WIN32)
@@ -1444,6 +1459,80 @@ struct llama_model * common_load_model_from_hf(
     return common_load_model_from_url(model_url, local_path, hf_token, params);
 }
 
+/**
+ * Allow getting the HF file from the HF repo with tag (like ollama), for example:
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
+ * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
+ *
+ * Return pair of <repo, file> (with "repo" already having tag removed)
+ *
+ * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
+ */
+std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
+    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
+    std::string tag = parts.size() > 1 ? parts.back() : "latest";
+    std::string hf_repo = parts[0];
+    if (string_split<std::string>(hf_repo, '/').size() != 2) {
+        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
+    }
+
+    // fetch model info from Hugging Face Hub API
+    json model_info;
+    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    std::string res_str;
+    std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+        static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
+        return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+    if (!hf_token.empty()) {
+        std::string auth_header = "Authorization: Bearer " + hf_token;
+        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+    }
+    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
+    CURLcode res = curl_easy_perform(curl.get());
+
+    if (res != CURLE_OK) {
+        throw std::runtime_error("error: cannot make GET request to HF API");
+    }
+
+    long res_code;
+    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+    if (res_code == 200) {
+        model_info = json::parse(res_str);
+    } else if (res_code == 401) {
+        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
+    } else {
+        throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
+    }
+
+    // check response
+    if (!model_info.contains("ggufFile")) {
+        throw std::runtime_error("error: model does not have ggufFile");
+    }
+    json & gguf_file = model_info.at("ggufFile");
+    if (!gguf_file.contains("rfilename")) {
+        throw std::runtime_error("error: ggufFile does not have rfilename");
+    }
+
+    return std::make_pair(hf_repo, gguf_file.at("rfilename"));
+}
+
 #else
 
 struct llama_model * common_load_model_from_url(
@@ -1465,6 +1554,11 @@ struct llama_model * common_load_model_from_hf(
     return nullptr;
 }
 
+std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
+    LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return std::make_pair("", "");
+}
+
 #endif // LLAMA_USE_CURL
 
 //
index d523948b03e30994546444a252150943b7ce23d0..c86a4ef39212bdab3eb107a258c9b3dc0b7762c3 100644 (file)
@@ -454,6 +454,11 @@ static bool string_starts_with(const std::string & str,
     return str.rfind(prefix, 0) == 0;
 }
 
+static bool string_ends_with(const std::string & str,
+                               const std::string & suffix) {  // While we wait for C++20's std::string::ends_with...
+    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
@@ -501,6 +506,9 @@ struct llama_model * common_load_model_from_hf(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+std::pair<std::string, std::string> common_get_hf_file(
+    const std::string & hf_repo_with_tag,
+    const std::string & hf_token);
 
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);