]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama-run : include temperature option (#10899)
authorEric Curtin <redacted>
Mon, 23 Dec 2024 00:21:40 +0000 (00:21 +0000)
committerGitHub <redacted>
Mon, 23 Dec 2024 00:21:40 +0000 (01:21 +0100)
This commit updates the `examples/run/README.md` file to include a new
option for setting the temperature and updates the `run.cpp` file to
parse this option.

Signed-off-by: Eric Curtin <redacted>
examples/run/README.md
examples/run/run.cpp

index 874293516f4b6a606e6a44bb0c7e529cfdca3bcd..a0680544120b944f73f8462a12f991af80a3e7f8 100644 (file)
@@ -19,6 +19,8 @@ Options:
       Context size (default: 2048)
   -n, --ngl <value>
       Number of GPU layers (default: 0)
+  --temp <value>
+      Temperature (default: 0.8)
   -v, --verbose, --log-verbose
       Set verbosity level to infinity (i.e. log all messages, useful for debugging)
   -h, --help
index 03da54ca3b2ef60bec898fd319e5252eae182b5e..f89d041c44ac10371e5958c68c7a74a755462a25 100644 (file)
@@ -55,29 +55,51 @@ static int printe(const char * fmt, ...) {
 class Opt {
   public:
     int init(int argc, const char ** argv) {
+        ctx_params           = llama_context_default_params();
+        model_params         = llama_model_default_params();
+        context_size_default = ctx_params.n_batch;
+        ngl_default          = model_params.n_gpu_layers;
+        common_params_sampling sampling;
+        temperature_default = sampling.temp;
+
+        if (argc < 2) {
+            printe("Error: No arguments provided.\n");
+            print_help();
+            return 1;
+        }
+
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
-            help();
+            print_help();
             return 1;
         }
 
         // If help is requested, show help and exit
-        if (help_) {
-            help();
+        if (help) {
+            print_help();
             return 2;
         }
 
+        ctx_params.n_batch        = context_size >= 0 ? context_size : context_size_default;
+        model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
+        temperature               = temperature >= 0 ? temperature : temperature_default;
+
         return 0;  // Success
     }
 
+    llama_context_params ctx_params;
+    llama_model_params   model_params;
     std::string model_;
-    std::string user_;
-    int         context_size_ = -1, ngl_ = -1;
-    bool        verbose_ = false;
+    std::string          user;
+    int                  context_size = -1, ngl = -1;
+    float                temperature = -1;
+    bool                 verbose     = false;
 
   private:
-    bool        help_ = false;
+    int   context_size_default = -1, ngl_default = -1;
+    float temperature_default = -1;
+    bool  help                = false;
 
     bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
         return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
@@ -89,6 +111,17 @@ class Opt {
         }
 
         option_value = std::atoi(argv[++i]);
+
+        return 0;
+    }
+
+    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = std::atof(argv[++i]);
+
         return 0;
     }
 
@@ -96,18 +129,22 @@ class Opt {
         bool options_parsing   = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
-                if (handle_option_with_value(argc, argv, i, context_size_) == 1) {
+                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                     return 1;
                 }
             } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
-                if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
+                if (handle_option_with_value(argc, argv, i, ngl) == 1) {
+                    return 1;
+                }
+            } else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
+                if (handle_option_with_value(argc, argv, i, temperature) == 1) {
                     return 1;
                 }
             } else if (options_parsing &&
                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
-                verbose_ = true;
+                verbose = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
-                help_ = true;
+                help = true;
                 return 0;
             } else if (options_parsing && strcmp(argv[i], "--") == 0) {
                 options_parsing = false;
@@ -120,16 +157,16 @@ class Opt {
                 model_ = argv[i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
-                user_ = argv[i];
+                user = argv[i];
             } else {
-                user_ += " " + std::string(argv[i]);
+                user += " " + std::string(argv[i]);
             }
         }
 
         return 0;
     }
 
-    void help() const {
+    void print_help() const {
         printf(
             "Description:\n"
             "  Runs a llm\n"
@@ -142,6 +179,8 @@ class Opt {
             "      Context size (default: %d)\n"
             "  -n, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
+            "  --temp <value>\n"
+            "      Temperature (default: %.1f)\n"
             "  -v, --verbose, --log-verbose\n"
             "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
             "  -h, --help\n"
@@ -170,7 +209,7 @@ class Opt {
             "  llama-run file://some-file3.gguf\n"
             "  llama-run --ngl 999 some-file4.gguf\n"
             "  llama-run --ngl 999 some-file5.gguf Hello World\n",
-            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
+            context_size_default, ngl_default, temperature_default);
     }
 };
 
@@ -495,12 +534,12 @@ class LlamaData {
             return 1;
         }
 
-        context = initialize_context(model, opt.context_size_);
+        context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
-        sampler = initialize_sampler();
+        sampler = initialize_sampler(opt);
         return 0;
     }
 
@@ -619,14 +658,12 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        llama_model_params model_params = llama_model_default_params();
-        model_params.n_gpu_layers       = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
         resolve_model(opt.model_);
         printe(
             "\r%*s"
             "\rLoading model",
             get_terminal_width(), " ");
-        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
+        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params));
         if (!model) {
             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
         }
@@ -636,10 +673,8 @@ class LlamaData {
     }
 
     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
-        llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
-        llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
+    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
+        llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
         }
@@ -648,10 +683,10 @@ class LlamaData {
     }
 
     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler() {
+    llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
-        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
         return sampler;
@@ -798,9 +833,9 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const
 }
 
 // Helper function to handle user input
-static int handle_user_input(std::string & user_input, const std::string & user_) {
-    if (!user_.empty()) {
-        user_input = user_;
+static int handle_user_input(std::string & user_input, const std::string & user) {
+    if (!user.empty()) {
+        user_input = user;
         return 0;  // No need for interactive input
     }
 
@@ -832,17 +867,17 @@ static bool is_stdout_a_terminal() {
 }
 
 // Function to tokenize the prompt
-static int chat_loop(LlamaData & llama_data, const std::string & user_) {
+static int chat_loop(LlamaData & llama_data, const std::string & user) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        while (handle_user_input(user_input, user_)) {
+        while (handle_user_input(user_input, user)) {
         }
 
-        add_message("user", user_.empty() ? user_input : user_, llama_data);
+        add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
         if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
             return 1;
@@ -854,7 +889,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
             return 1;
         }
 
-        if (!user_.empty()) {
+        if (!user.empty()) {
             break;
         }
 
@@ -869,7 +904,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
 
 static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
     const Opt * opt = static_cast<Opt *>(p);
-    if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
+    if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
         printe("%s", text);
     }
 }
@@ -890,11 +925,11 @@ int main(int argc, const char ** argv) {
     }
 
     if (!is_stdin_a_terminal()) {
-        if (!opt.user_.empty()) {
-            opt.user_ += "\n\n";
+        if (!opt.user.empty()) {
+            opt.user += "\n\n";
         }
 
-        opt.user_ += read_pipe_data();
+        opt.user += read_pipe_data();
     }
 
     llama_log_set(log_callback, &opt);
@@ -903,7 +938,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user_)) {
+    if (chat_loop(llama_data, opt.user)) {
         return 1;
     }