// CLI argument parsing
//
+void gpt_params_handle_hf_token(gpt_params & params) {
+ if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
+ params.hf_token = std::getenv("HF_TOKEN");
+ }
+}
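// A minimal usage sketch (hypothetical, not part of this patch) of the
// precedence this helper implements: an explicit --hf-token always wins, and
// HF_TOKEN from the environment is only consulted when the flag was absent.
//
//     gpt_params params;                    // params.hf_token starts empty
//     setenv("HF_TOKEN", "hf_abc123", 1);   // token supplied via environment
//     gpt_params_handle_hf_token(params);   // params.hf_token == "hf_abc123"
//
//     params.hf_token = "hf_from_cli";      // simulates --hf-token on the CLI
//     gpt_params_handle_hf_token(params);   // unchanged: the flag takes precedence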
+
void gpt_params_handle_model_default(gpt_params & params) {
if (!params.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
gpt_params_handle_model_default(params);
+ gpt_params_handle_hf_token(params);
+
if (params.escape) {
string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix);
params.model_url = argv[i];
return true;
}
+ if (arg == "-hft" || arg == "--hf-token") {
+ if (++i >= argc) {
+ invalid_param = true;
+ return true;
+ }
+ params.hf_token = argv[i];
+ return true;
+ }
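// For reference, CHECK_ARG is assumed to expand to the parser's usual
// missing-argument guard (a sketch of the presumed macro, not a verbatim copy):
//
//     #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }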
if (arg == "-hfr" || arg == "--hf-repo") {
CHECK_ARG
params.hf_repo = argv[i];
options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" });
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
+ options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
options.push_back({ "retrieval" });
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
llama_model * model = nullptr;
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
- model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+ model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
} else if (!params.model_url.empty()) {
- model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+ model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
}
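// Selection order, as implemented above: an explicit --hf-repo/--hf-file pair
// takes priority, then a raw --model-url, and finally the local --model path.
// The token is threaded through both remote paths so gated Hugging Face
// repositories work with either form of download.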
return str.rfind(prefix, 0) == 0;
}
-static bool llama_download_file(const std::string & url, const std::string & path) {
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
+ // If an HF token was provided, send it as a bearer token in the Authorization header
+ if (!hf_token.empty()) {
+ std::string auth_header = "Authorization: Bearer " + hf_token;
+ struct curl_slist * http_headers = nullptr;
+ http_headers = curl_slist_append(http_headers, auth_header.c_str());
+ curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+ }
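// Note: libcurl does not copy the list set via CURLOPT_HTTPHEADER; it must stay
// alive until the transfer completes and should then be released with
// curl_slist_free_all(). A small RAII wrapper (a sketch, not part of this
// patch) would tie the two lifetimes together:
//
//     struct curl_slist_ptr {
//         struct curl_slist * ptr = nullptr;
//         ~curl_slist_ptr() { curl_slist_free_all(ptr); }  // no-op on nullptr
//     };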
+
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
struct llama_model * llama_load_model_from_url(
const char * model_url,
const char * path_model,
+ const char * hf_token,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
return NULL;
}
- if (!llama_download_file(model_url, path_model)) {
+ if (!llama_download_file(model_url, path_model, hf_token)) {
return NULL;
}
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) {
- futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
+ futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
- return llama_download_file(split_url, split_path);
+ return llama_download_file(split_url, split_path, hf_token);
}, idx));
}
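// hf_token is captured by value: each worker copies the pointer itself, which
// refers to the caller's argument and so stays valid for the whole download;
// futures returned by std::async block in their destructor, so no worker can
// outlive this scope.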
const char * repo,
const char * model,
const char * path_model,
+ const char * hf_token,
const struct llama_model_params & params) {
// construct hugging face model url:
//
model_url += "/resolve/main/";
model_url += model;
- return llama_load_model_from_url(model_url.c_str(), path_model, params);
+ return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
}
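// Illustrative example (URL base assumed from the Hugging Face download
// convention, not shown in this hunk): --hf-repo ggml-org/models with
// --hf-file tinyllama.gguf would resolve to
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama.gguf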
#else
struct llama_model * llama_load_model_from_url(
const char * /*model_url*/,
const char * /*path_model*/,
+ const char * /*hf_token*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
const char * /*repo*/,
const char * /*model*/,
const char * /*path_model*/,
+ const char * /*hf_token*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias
std::string model_url = ""; // model url to download
+ std::string hf_token = ""; // HF access token (defaults to the HF_TOKEN environment variable)
std::string hf_repo = ""; // HF repo
std::string hf_file = ""; // HF file
std::string prompt = "";
bool spm_infill = false; // suffix/prefix/middle pattern for infill
};
+void gpt_params_handle_hf_token(gpt_params & params);
void gpt_params_handle_model_default(gpt_params & params);
bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
// Batch utils