]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Enhance user input handling for llama-run (#11138)
authorEric Curtin <redacted>
Wed, 8 Jan 2025 18:47:05 +0000 (18:47 +0000)
committerGitHub <redacted>
Wed, 8 Jan 2025 18:47:05 +0000 (18:47 +0000)
The main motivation for this change is it was not handing
ctrl-c/ctrl-d correctly. Modify `read_user_input` to handle EOF,
"/bye" command, and empty input cases. Introduce `get_user_input`
function to manage user input loop and handle different return
cases.

Signed-off-by: Eric Curtin <redacted>
examples/run/run.cpp

index 2888fcfed1e15576d2f9f0557900a6e82e819884..61420e441e0f9f6297d50b1e444cc6089d297fb2 100644 (file)
@@ -11,6 +11,8 @@
 #    include <curl/curl.h>
 #endif
 
+#include <signal.h>
+
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include "json.hpp"
 #include "llama-cpp.h"
 
+#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
+[[noreturn]] static void sigint_handler(int) {
+    printf("\n");
+    exit(0);  // not ideal, but it's the only way to guarantee exit in all cases
+}
+#endif
+
 GGML_ATTRIBUTE_FORMAT(1, 2)
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
@@ -801,7 +810,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
 
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
-    return user.empty();  // Should have data in happy path
+    if (std::cin.eof()) {
+        printf("\n");
+        return 1;
+    }
+
+    if (user == "/bye") {
+        return 1;
+    }
+
+    if (user.empty()) {
+        return 2;
+    }
+
+    return 0;  // Should have data in happy path
 }
 
 // Function to generate a response based on the prompt
@@ -868,7 +890,25 @@ static bool is_stdout_a_terminal() {
 #endif
 }
 
-// Function to tokenize the prompt
+// Function to handle user input
+static int get_user_input(std::string & user_input, const std::string & user) {
+    while (true) {
+        const int ret = handle_user_input(user_input, user);
+        if (ret == 1) {
+            return 1;
+        }
+
+        if (ret == 2) {
+            continue;
+        }
+
+        break;
+    }
+
+    return 0;
+}
+
+// Main chat loop function
 static int chat_loop(LlamaData & llama_data, const std::string & user) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
@@ -876,7 +916,8 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
     while (true) {
         // Get user input
         std::string user_input;
-        while (handle_user_input(user_input, user)) {
+        if (get_user_input(user_input, user) == 1) {
+            return 0;
         }
 
         add_message("user", user.empty() ? user_input : user, llama_data);
@@ -917,7 +958,23 @@ static std::string read_pipe_data() {
     return result.str();
 }
 
+static void ctrl_c_handling() {
+#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
+    struct sigaction sigint_action;
+    sigint_action.sa_handler = sigint_handler;
+    sigemptyset(&sigint_action.sa_mask);
+    sigint_action.sa_flags = 0;
+    sigaction(SIGINT, &sigint_action, NULL);
+#elif defined(_WIN32)
+    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+        return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
+    };
+    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+}
+
 int main(int argc, const char ** argv) {
+    ctrl_c_handling();
     Opt       opt;
     const int ret = opt.init(argc, argv);
     if (ret == 2) {