]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Some llama-run cleanups (#11973)
authorEric Curtin <redacted>
Sun, 23 Feb 2025 13:14:32 +0000 (13:14 +0000)
committerGitHub <redacted>
Sun, 23 Feb 2025 13:14:32 +0000 (13:14 +0000)
Use consolidated open function call from File class. Change
read_all to to_string(). Remove exclusive locking, the intent for
that lock is to avoid multiple processes writing to the same file,
it's not an issue for readers, although we may want to consider
adding a shared lock. Remove passing nullptr as reference,
references are never supposed to be null. clang-format the code
for consistent styling.

Signed-off-by: Eric Curtin <redacted>
examples/run/run.cpp

index 4da1e50251600c0357454ab950525960cc449139..de736c7d5a3d98df760f84ecd5f6596dc867b322 100644 (file)
@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
-
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
+
         return out;
     }
 
@@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if(chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if(chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
+}
+
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+    int new_len;
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+    std::string response;
+    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
-    std::string chat_template = "";
-    if(!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
 
+    common_chat_templates_ptr chat_templates    = common_chat_templates_init(llama_data.model.get(), chat_template);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
             return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
-        int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
-            return 1;
-        }
-
-        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
-        std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
             return 1;
-        }
-
-        if (!user.empty()) {
+        } else if (ret == 2) {
             break;
         }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
     }
 
     return 0;
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }