]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common : add -hfd option for the draft model (#11318)
authorGeorgi Gerganov <redacted>
Mon, 20 Jan 2025 20:29:43 +0000 (22:29 +0200)
committerGitHub <redacted>
Mon, 20 Jan 2025 20:29:43 +0000 (22:29 +0200)
* common : add -hfd option for the draft model

* cont : fix env var

* cont : more fixes

common/arg.cpp
common/common.h
examples/server/server.cpp

index dede335fbc3fdb5e28c96dfd2089f6913487ab88..126970950f7e6a56816f71767a89365eeaf10287 100644 (file)
@@ -133,7 +133,8 @@ static void common_params_handle_model_default(
         const std::string & model_url,
         std::string & hf_repo,
         std::string & hf_file,
-        const std::string & hf_token) {
+        const std::string & hf_token,
+        const std::string & model_default) {
     if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
         if (hf_file.empty()) {
@@ -163,7 +164,7 @@ static void common_params_handle_model_default(
             model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
     } else if (model.empty()) {
-        model = DEFAULT_MODEL_PATH;
+        model = model_default;
     }
 }
 
@@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // TODO: refactor model params in a common struct
-    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file,         params.hf_token);
-    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
+    common_params_handle_model_default(params.model,             params.model_url,             params.hf_repo,             params.hf_file,             params.hf_token, DEFAULT_MODEL_PATH);
+    common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
+    common_params_handle_model_default(params.vocoder.model,     params.vocoder.model_url,     params.vocoder.hf_repo,     params.vocoder.hf_file,     params.hf_token, "");
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -1629,6 +1631,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HF_REPO"));
+    add_opt(common_arg(
+        {"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
+        "Same as --hf-repo, but for the draft model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.speculative.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HFD_REPO"));
     add_opt(common_arg(
         {"-hff", "--hf-file"}, "FILE",
         "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
index 3bcc637cc800ae5fb6029403a158f452e0490e4f..b2709c044b1aed24fce607559c08cf24429f1eb2 100644 (file)
@@ -175,7 +175,11 @@ struct common_params_speculative {
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
 
-    std::string model = ""; // draft model for speculative decoding                          // NOLINT
+    std::string hf_repo = ""; // HF repo                                                     // NOLINT
+    std::string hf_file = ""; // HF file                                                     // NOLINT
+
+    std::string model = "";     // draft model for speculative decoding                      // NOLINT
+    std::string model_url = ""; // model url to download                                     // NOLINT
 };
 
 struct common_params_vocoder {
@@ -508,12 +512,14 @@ struct llama_model * common_load_model_from_url(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+
 struct llama_model * common_load_model_from_hf(
     const std::string & repo,
     const std::string & remote_path,
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+
 std::pair<std::string, std::string> common_get_hf_file(
     const std::string & hf_repo_with_tag,
     const std::string & hf_token);
index d1e8ee829105c2e81a5aa39e1db41400588b15d7..f35206d7b2f9a27d67e47dadf5831ebe014e6422 100644 (file)
@@ -1728,13 +1728,16 @@ struct server_context {
         add_bos_token = llama_vocab_get_add_bos(vocab);
         has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
-        if (!params_base.speculative.model.empty()) {
+        if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
 
             auto params_dft = params_base;
 
             params_dft.devices      = params_base.speculative.devices;
+            params_dft.hf_file      = params_base.speculative.hf_file;
+            params_dft.hf_repo      = params_base.speculative.hf_repo;
             params_dft.model        = params_base.speculative.model;
+            params_dft.model_url    = params_base.speculative.model_url;
             params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;