]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
added support for Authorization Bearer tokens when downloading model (#8307)
authorDerrick T. Woolworth <redacted>
Sat, 6 Jul 2024 20:32:04 +0000 (15:32 -0500)
committerGitHub <redacted>
Sat, 6 Jul 2024 20:32:04 +0000 (22:32 +0200)
* added support for Authorization Bearer tokens

* removed auth_token, removed set_ function, other small fixes

* Update common/common.cpp

---------

Co-authored-by: Xuan Son Nguyen <redacted>
common/common.cpp
common/common.h

index c548bcb2857a8ad1e127c279059834a69ae5b3ce..fc0f3b350ec6e7d9731ccdf6cf4e6ada89a5cfff 100644 (file)
@@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
+void gpt_params_handle_hf_token(gpt_params & params) {
+    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
+        params.hf_token = std::getenv("HF_TOKEN");
+    }
+}
+
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
+    gpt_params_handle_hf_token(params);
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
@@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.model_url = argv[i];
         return true;
     }
+    if (arg == "-hft" || arg == "--hf-token") {
+        if (++i >= argc) {
+          invalid_param = true;
+          return true;
+        }
+        params.hf_token = argv[i];
+        return true;
+    }
     if (arg == "-hfr" || arg == "--hf-repo") {
         CHECK_ARG
         params.hf_repo = argv[i];
@@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "-mu,   --model-url MODEL_URL",  "model download url (default: unused)" });
     options.push_back({ "*",           "-hfr,  --hf-repo REPO",         "Hugging Face model repository (default: unused)" });
     options.push_back({ "*",           "-hff,  --hf-file FILE",         "Hugging Face model file (default: unused)" });
+    options.push_back({ "*",           "-hft,  --hf-token TOKEN",       "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
 
     options.push_back({ "retrieval" });
     options.push_back({ "retrieval",   "       --context-file FNAME",   "file to load context from (repeat to specify multiple files)" });
@@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     llama_model * model = nullptr;
 
     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else if (!params.model_url.empty()) {
-        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
@@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }
 
-static bool llama_download_file(const std::string & url, const std::string & path) {
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
 
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
 
+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+      std::string auth_header = "Authorization: Bearer ";
+      auth_header += hf_token.c_str();
+      struct curl_slist *http_headers = NULL;
+      http_headers = curl_slist_append(http_headers, auth_header.c_str());
+      curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
 #if defined(_WIN32)
     // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
     //   operating system. Currently implemented under MS-Windows.
@@ -2415,6 +2441,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat
 struct llama_model * llama_load_model_from_url(
         const char * model_url,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
@@ -2422,7 +2449,7 @@ struct llama_model * llama_load_model_from_url(
         return NULL;
     }
 
-    if (!llama_download_file(model_url, path_model)) {
+    if (!llama_download_file(model_url, path_model, hf_token)) {
         return NULL;
     }
 
@@ -2470,14 +2497,14 @@ struct llama_model * llama_load_model_from_url(
         // Prepare download in parallel
         std::vector<std::future<bool>> futures_download;
         for (int idx = 1; idx < n_split; idx++) {
-            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
+            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
                 char split_path[PATH_MAX] = {0};
                 llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
 
                 char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
                 llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
 
-                return llama_download_file(split_url, split_path);
+                return llama_download_file(split_url, split_path, hf_token);
             }, idx));
         }
 
@@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * repo,
         const char * model,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // construct hugging face model url:
     //
@@ -2511,7 +2539,7 @@ struct llama_model * llama_load_model_from_hf(
     model_url += "/resolve/main/";
     model_url += model;
 
-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
 }
 
 #else
@@ -2519,6 +2547,7 @@ struct llama_model * llama_load_model_from_hf(
 struct llama_model * llama_load_model_from_url(
         const char * /*model_url*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
@@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * /*repo*/,
         const char * /*model*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
     return nullptr;
index dabf02b598b3b0519ad55165d548e38c896cb10b..184a53dc090646206ff980ee4e74b6b8bf16042b 100644 (file)
@@ -108,6 +108,7 @@ struct gpt_params {
     std::string model_draft          = ""; // draft model for speculative decoding
     std::string model_alias          = "unknown"; // model alias
     std::string model_url            = ""; // model url to download
+    std::string hf_token             = ""; // HF token
     std::string hf_repo              = ""; // HF repo
     std::string hf_file              = ""; // HF file
     std::string prompt               = "";
@@ -256,6 +257,7 @@ struct gpt_params {
     bool spm_infill = false; // suffix/prefix/middle pattern for infill
 };
 
+void gpt_params_handle_hf_token(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
@@ -311,8 +313,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 
 // Batch utils