]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : support reading arguments from environment variables (#9105)
authorXuan Son Nguyen <redacted>
Wed, 21 Aug 2024 09:04:34 +0000 (11:04 +0200)
committerGitHub <redacted>
Wed, 21 Aug 2024 09:04:34 +0000 (11:04 +0200)
* server : support reading arguments from environment variables

* add -fa and -dt

* readme : specify non-arg env var

common/common.cpp
common/common.h
examples/server/README.md
examples/server/server.cpp

index 382d585a5e6f900a1eb72fbd3abab2809b73d19f..59e8296604c9c8cda9b781f8a7097cbc7ad69c2e 100644 (file)
 
 using json = nlohmann::ordered_json;
 
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
+
 //
 // CPU utils
 //
@@ -220,12 +255,6 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
-void gpt_params_handle_hf_token(gpt_params & params) {
-    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
-        params.hf_token = std::getenv("HF_TOKEN");
-    }
-}
-
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -273,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
-    gpt_params_handle_hf_token(params);
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -293,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL",            params.model);
+    get_env("LLAMA_ARG_THREADS",          params.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE",         params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL",       params.n_parallel);
+    get_env("LLAMA_ARG_BATCH",            params.n_batch);
+    get_env("LLAMA_ARG_UBATCH",           params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS",     params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP",     params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE",    params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT",        params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS",   params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS",       params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN",       params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD",     params.defrag_thold);
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     const auto params_org = params; // the example can modify the default params
 
index df23460a50fe063b51b9c6eca03b1ce5f5084685..f603ba2be1d35a89dd6d42a80aefed655f0b642e 100644 (file)
@@ -267,7 +267,7 @@ struct gpt_params {
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
-void gpt_params_handle_hf_token(gpt_params & params);
+void gpt_params_parse_from_env(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
index 930ae15f64d8b685cc66aaf5b8160e76f5dd3279..abe245271195b27b0dae7f5e98ebf55ea1b0024d 100644 (file)
@@ -247,6 +247,25 @@ logging:
          --log-append             Don't truncate the old log file.
 ```
 
+Available environment variables (if specified, these variables will override parameters specified in arguments):
+
+- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
+- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
+- `LLAMA_ARG_MODEL`
+- `LLAMA_ARG_THREADS`
+- `LLAMA_ARG_CTX_SIZE`
+- `LLAMA_ARG_N_PARALLEL`
+- `LLAMA_ARG_BATCH`
+- `LLAMA_ARG_UBATCH`
+- `LLAMA_ARG_N_GPU_LAYERS`
+- `LLAMA_ARG_THREADS_HTTP`
+- `LLAMA_ARG_CHAT_TEMPLATE`
+- `LLAMA_ARG_N_PREDICT`
+- `LLAMA_ARG_ENDPOINT_METRICS`
+- `LLAMA_ARG_ENDPOINT_SLOTS`
+- `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
 
 ## Build
 
index ce711eadd29acfb678cac35c96a8a6ff23b962d9..e79e7aa2cb8460a50ee6b4151d6184c3fec38399 100644 (file)
@@ -2507,6 +2507,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // parse arguments from environment variables
+    gpt_params_parse_from_env(params);
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;