]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add new hf protocol for ollama (#11449)
authorEric Curtin <redacted>
Mon, 27 Jan 2025 18:36:10 +0000 (19:36 +0100)
committerGitHub <redacted>
Mon, 27 Jan 2025 18:36:10 +0000 (19:36 +0100)
https://huggingface.co/docs/hub/en/ollama

Signed-off-by: Eric Curtin <redacted>
examples/run/run.cpp

index 92a49eb744fda39969bf9a5c01e496b5c7eed793..8a0db74b62d05e046cbd195a74896047aca98cf0 100644 (file)
@@ -319,6 +319,10 @@ class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
              const bool progress, std::string * response_str = nullptr) {
+        if (std::filesystem::exists(output_file)) {
+            return 0;
+        }
+
         std::string output_file_partial;
         curl = curl_easy_init();
         if (!curl) {
@@ -558,13 +562,14 @@ class LlamaData {
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
   private:
 #ifdef LLAMA_USE_CURL
-    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
-                 const bool progress, std::string * response_str = nullptr) {
+    int download(const std::string & url, const std::string & output_file, const bool progress,
+                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
         if (http.init(url, headers, output_file, progress, response_str)) {
             return 1;
@@ -573,48 +578,85 @@ class LlamaData {
         return 0;
     }
 #else
-    int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
+    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
                  std::string * = nullptr) {
         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+
         return 1;
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    // Helper function to handle model tag extraction and URL construction
+    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
+        std::string  model_tag = "latest";
+        const size_t colon_pos = model.find(':');
+        if (colon_pos != std::string::npos) {
+            model_tag = model.substr(colon_pos + 1);
+            model     = model.substr(0, colon_pos);
+        }
+
+        std::string url = base_url + model + "/manifests/" + model_tag;
+
+        return { model, url };
+    }
+
+    // Helper function to download and parse the manifest
+    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
+                                    nlohmann::json & manifest) {
+        std::string manifest_str;
+        int         ret = download(url, "", false, headers, &manifest_str);
+        if (ret) {
+            return ret;
+        }
+
+        manifest = nlohmann::json::parse(manifest_str);
+
+        return 0;
+    }
+
+    int huggingface_dl(std::string & model, const std::string & bn) {
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos        = model.find('/', pos + 1);
+        std::string              hfr, hff;
+        std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
+        std::string              url;
+
         if (pos == std::string::npos) {
-            return 1;
+            auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
+            hfr                             = model_name;
+
+            nlohmann::json manifest;
+            int            ret = download_and_parse_manifest(manifest_url, headers, manifest);
+            if (ret) {
+                return ret;
+            }
+
+            hff = manifest["ggufFile"]["rfilename"];
+        } else {
+            hfr = model.substr(0, pos);
+            hff = model.substr(pos + 1);
         }
 
-        const std::string hfr = model.substr(0, pos);
-        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+        url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+
+        return download(url, bn, true, headers);
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::string & bn) {
+        const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
 
-        std::string model_tag = "latest";
-        size_t      colon_pos = model.find(':');
-        if (colon_pos != std::string::npos) {
-            model_tag = model.substr(colon_pos + 1);
-            model     = model.substr(0, colon_pos);
-        }
-
-        std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
-        std::string manifest_str;
-        const int   ret = download(manifest_url, headers, "", false, &manifest_str);
+        auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
+        nlohmann::json manifest;
+        int            ret = download_and_parse_manifest(manifest_url, {}, manifest);
         if (ret) {
             return ret;
         }
 
-        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string    layer;
+        std::string layer;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
                 layer = l["digest"];
@@ -622,8 +664,9 @@ class LlamaData {
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
+
+        return download(blob_url, bn, true, headers);
     }
 
     std::string basename(const std::string & path) {
@@ -653,22 +696,18 @@ class LlamaData {
             return ret;
         }
 
-        const std::string              bn      = basename(model_);
-        const std::vector<std::string> headers = { "--header",
-                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
+        const std::string bn = basename(model_);
         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
             rm_until_substring(model_, "://");
-            ret = huggingface_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
         } else if (string_starts_with(model_, "hf.co/")) {
             rm_until_substring(model_, "hf.co/");
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            rm_until_substring(model_, "://");
-            ret = ollama_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
         } else if (string_starts_with(model_, "https://")) {
-            ret = download(model_, headers, bn, true);
-        } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = download(model_, bn, true);
+        } else {  // ollama:// or nothing
+            rm_until_substring(model_, "://");
+            ret = ollama_dl(model_, bn);
         }
 
         model_ = bn;