]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
run : add --chat-template-file (#11961)
authorMichael Engel <redacted>
Thu, 20 Feb 2025 08:35:11 +0000 (09:35 +0100)
committerGitHub <redacted>
Thu, 20 Feb 2025 08:35:11 +0000 (10:35 +0200)
Relates to: https://github.com/ggml-org/llama.cpp/issues/11178

Added --chat-template-file CLI option to llama-run. If specified, the file
will be read and the content passed for overwriting the chat template of
the model to common_chat_templates_from_model.

Signed-off-by: Michael Engel <redacted>
examples/run/run.cpp

index ed8644ef78d97fb064e665d283db1b61dcdaeda0..4da1e50251600c0357454ab950525960cc449139 100644 (file)
@@ -113,6 +113,7 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params   model_params;
     std::string model_;
+    std::string chat_template_file;
     std::string          user;
     bool                 use_jinja   = false;
     int                  context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
         return 0;
     }
 
+    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = argv[++i];
+
+        return 0;
+    }
+
     int parse(int argc, const char ** argv) {
         bool options_parsing   = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
                 verbose = true;
             } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                 use_jinja = true;
+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+                    return 1;
+                }
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
@@ -207,6 +223,11 @@ class Opt {
             "Options:\n"
             "  -c, --context-size <value>\n"
             "      Context size (default: %d)\n"
+            "  --chat-template-file <path>\n"
+            "      Path to the file containing the chat template to use with the model.\n"
+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
+            "  --jinja\n"
+            "      Use jinja templating for the chat template of the model\n"
             "  -n, -ngl, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
             "  --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
 #endif
 }
 
-#ifdef LLAMA_USE_CURL
 class File {
   public:
     FILE * file = nullptr;
 
     FILE * open(const std::string & filename, const char * mode) {
-        file = fopen(filename.c_str(), mode);
+        file = ggml_fopen(filename.c_str(), mode);
 
         return file;
     }
@@ -303,6 +323,28 @@ class File {
         return 0;
     }
 
+    std::string read_all(const std::string & filename){
+        open(filename, "r");
+        lock();
+        if (!file) {
+            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+
+        fseek(file, 0, SEEK_END);
+        size_t size = ftell(file);
+        fseek(file, 0, SEEK_SET);
+
+        std::string out;
+        out.resize(size);
+        size_t read_size = fread(&out[0], 1, size, file);
+        if (read_size != size) {
+            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+        return out;
+    }
+
     ~File() {
         if (fd >= 0) {
 #    ifdef _WIN32
@@ -327,6 +369,7 @@ class File {
 #    endif
 };
 
+#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
     return 0;
 }
 
+// Reads a chat template file to be used
+static std::string read_chat_template_file(const std::string & chat_template_file) {
+    if(chat_template_file.empty()){
+        return "";
+    }
+
+    File file;
+    std::string chat_template = "";
+    chat_template = file.read_all(chat_template_file);
+    if(chat_template.empty()){
+        printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
+        return "";
+    }
+    return chat_template;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
+
+    std::string chat_template = "";
+    if(!chat_template_file.empty()){
+        chat_template = read_chat_template_file(chat_template_file);
+    }
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
+
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
         return 1;
     }