]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama-mtmd-cli: Sigint rework in mtmd vision example (#13080)
authorpl752 <redacted>
Wed, 23 Apr 2025 21:32:35 +0000 (02:32 +0500)
committerGitHub <redacted>
Wed, 23 Apr 2025 21:32:35 +0000 (23:32 +0200)
* Sigint rework in mtmd vision example

* Applied suggestions on mtmd-cli PR

* Forgot to invert one of the conditions

* Update examples/llava/mtmd-cli.cpp

* Removed redundant exit check

---------

Co-authored-by: pl752 <redacted>
Co-authored-by: Xuan-Son Nguyen <redacted>
examples/llava/mtmd-cli.cpp

index e80845a2c546947157e4208693223a060533829a..89af7331a1658b38fef760f9f1d67ec9d229e22b 100644 (file)
@@ -24,7 +24,9 @@
 #include <signal.h>
 #endif
 
-static bool g_is_generating = false;
+// volatile, because of signal being an interrupt
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
 
 /**
  * Please note that this is NOT a production-ready stuff.
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
             g_is_generating = false;
         } else {
             console::cleanup();
-            LOG("\nInterrupted by user\n");
-            _exit(130);
+            if (g_is_interrupted) {
+                _exit(1);
+            }
+            g_is_interrupted = true;
         }
     }
 }
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
-        if (i > n_predict || !g_is_generating) {
+        if (i > n_predict || !g_is_generating || g_is_interrupted) {
             printf("\n");
             break;
         }
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
+        if (g_is_interrupted) {
+            printf("\n");
+            break;
+        }
+
         // eval the token
         common_batch_clear(ctx.batch);
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     text.add_special   = add_bos;
     text.parse_special = true;
     mtmd_input_chunks chunks;
+
+    if (g_is_interrupted) return 0;
+
     int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 #endif
     }
 
+    if (g_is_interrupted) return 130;
+
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find("<__image__>") == std::string::npos) {
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
             return 1;
         }
 
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
         std::vector<std::string> images_fname;
         std::string content;
 
-        while (true) {
+        while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
             console::set_display(console::user_input);
             std::string line;
             console::readline(line, false);
+            if (g_is_interrupted) break;
             console::set_display(console::reset);
             line = string_strip(line);
             if (line.empty()) {
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
             msg.role = "user";
             msg.content = content;
             int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (g_is_interrupted) break;
             if (ret == 2) {
                 // non-fatal error
                 images_fname.clear();
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
+    if (g_is_interrupted) LOG("\nInterrupted by user\n");
     llama_perf_context_print(ctx.lctx);
-    return 0;
+    return g_is_interrupted ? 130 : 0;
 }