]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common : fix duplicated file name with hf_repo and hf_file (#10550)
authorXuan Son Nguyen <redacted>
Wed, 27 Nov 2024 21:30:52 +0000 (22:30 +0100)
committerGitHub <redacted>
Wed, 27 Nov 2024 21:30:52 +0000 (22:30 +0100)
common/arg.cpp
common/common.cpp
common/common.h
examples/server/tests/utils.py

index 272492e50df1562f92748a2ca1c167ba5ae4f81d..a6b7a1394f73557777d1e5625188699e82c0b7f9 100644 (file)
@@ -128,7 +128,11 @@ static void common_params_handle_model_default(common_params & params) {
             }
             params.hf_file = params.model;
         } else if (params.model.empty()) {
-            params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
+            // this is to avoid different repo having same file name, or same file name in different subdirs
+            std::string filename = params.hf_repo + "_" + params.hf_file;
+            // to make sure we don't have any slashes in the filename
+            string_replace_all(filename, "/", "_");
+            params.model = fs_get_cache_file(filename);
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
index 09ec9f2388afb61436b454f6d989341f99a25ab0..2b2f0009897f37fc114aab39f451de6f0390e678 100644 (file)
@@ -829,9 +829,9 @@ struct common_init_result common_init_from_params(common_params & params) {
     llama_model * model = nullptr;
 
     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = common_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
+        model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
     } else if (!params.model_url.empty()) {
-        model = common_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
+        model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
@@ -1342,17 +1342,17 @@ static bool common_download_file(const std::string & url, const std::string & pa
 }
 
 struct llama_model * common_load_model_from_url(
-        const char * model_url,
-        const char * path_model,
-        const char * hf_token,
+        const std::string & model_url,
+        const std::string & local_path,
+        const std::string & hf_token,
         const struct llama_model_params & params) {
     // Basic validation of the model_url
-    if (!model_url || strlen(model_url) == 0) {
+    if (model_url.empty()) {
         LOG_ERR("%s: invalid model_url\n", __func__);
         return NULL;
     }
 
-    if (!common_download_file(model_url, path_model, hf_token)) {
+    if (!common_download_file(model_url, local_path, hf_token)) {
         return NULL;
     }
 
@@ -1363,9 +1363,9 @@ struct llama_model * common_load_model_from_url(
             /*.no_alloc = */ true,
             /*.ctx      = */ NULL,
         };
-        auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
+        auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
         if (!ctx_gguf) {
-            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, path_model);
+            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, local_path.c_str());
             return NULL;
         }
 
@@ -1384,13 +1384,13 @@ struct llama_model * common_load_model_from_url(
         // Verify the first split file format
         // and extract split URL and PATH prefixes
         {
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) {
-                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, path_model, n_split);
+            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
                 return NULL;
             }
 
-            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) {
-                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url, n_split);
+            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
                 return NULL;
             }
         }
@@ -1417,14 +1417,14 @@ struct llama_model * common_load_model_from_url(
         }
     }
 
-    return llama_load_model_from_file(path_model, params);
+    return llama_load_model_from_file(local_path.c_str(), params);
 }
 
 struct llama_model * common_load_model_from_hf(
-        const char * repo,
-        const char * model,
-        const char * path_model,
-        const char * hf_token,
+        const std::string & repo,
+        const std::string & remote_path,
+        const std::string & local_path,
+        const std::string & hf_token,
         const struct llama_model_params & params) {
     // construct hugging face model url:
     //
@@ -1438,27 +1438,27 @@ struct llama_model * common_load_model_from_hf(
     std::string model_url = "https://huggingface.co/";
     model_url += repo;
     model_url += "/resolve/main/";
-    model_url += model;
+    model_url += remote_path;
 
-    return common_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
+    return common_load_model_from_url(model_url, local_path, hf_token, params);
 }
 
 #else
 
 struct llama_model * common_load_model_from_url(
-        const char * /*model_url*/,
-        const char * /*path_model*/,
-        const char * /*hf_token*/,
+        const std::string & /*model_url*/,
+        const std::string & /*local_path*/,
+        const std::string & /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
 }
 
 struct llama_model * common_load_model_from_hf(
-        const char * /*repo*/,
-        const char * /*model*/,
-        const char * /*path_model*/,
-        const char * /*hf_token*/,
+        const std::string & /*repo*/,
+        const std::string & /*remote_path*/,
+        const std::string & /*local_path*/,
+        const std::string & /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
     return nullptr;
index 286642db241587a8af8d737ff62ba04be2e1a944..9b1508a15fb43840bca968bc2374dc2e674871e6 100644 (file)
@@ -470,8 +470,17 @@ struct llama_model_params     common_model_params_to_llama  (      common_params
 struct llama_context_params   common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
 
-struct llama_model * common_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
-struct llama_model * common_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * common_load_model_from_url(
+    const std::string & model_url,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
+struct llama_model * common_load_model_from_hf(
+    const std::string & repo,
+    const std::string & remote_path,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
 
 // clear LoRA adapters from context, then apply new list of adapters
 void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
index bc590bcb31547dbf441ac3d4ace97dab2a7120a9..e31743c505d8e02d6f92bbe470fde52332a3e7d3 100644 (file)
@@ -319,7 +319,6 @@ class ServerPreset:
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
         server.model_alias = "jina-reranker"
-        server.model_file = "./tmp/jina-reranker-v1-tiny-en.gguf"
         server.n_ctx = 512
         server.n_batch = 512
         server.n_slots = 1