]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add abort_callback to interrupt computation (#5409)
authorMichael Podvitskiy <redacted>
Sat, 2 Mar 2024 19:52:25 +0000 (20:52 +0100)
committerGitHub <redacted>
Sat, 2 Mar 2024 19:52:25 +0000 (21:52 +0200)
* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <redacted>
llama.cpp
llama.h

index 697e85e89e19c50b14c620c9897792ab2dd1bda0..d4c7a965bf377fd1b5b96135333cf8ed7688c602 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1987,6 +1987,9 @@ struct llama_context {
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
 
+    ggml_abort_callback abort_callback      = nullptr;
+    void *              abort_callback_data = nullptr;
+
     // input tensors
     ggml_backend_buffer_t buf_input = nullptr;
     ggml_context * ctx_input = nullptr;
@@ -8071,6 +8074,7 @@ static void llama_graph_compute(
 
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
+        ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
 
     ggml_backend_sched_graph_compute(lctx.sched, gf);
@@ -11856,6 +11860,8 @@ struct llama_context_params llama_context_default_params() {
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.do_pooling                  =*/ true,
+        /*.abort_callback              =*/ nullptr,
+        /*.abort_callback_data         =*/ nullptr,
     };
 
     return result;
@@ -12038,8 +12044,11 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
 
-    ctx->rng = std::mt19937(params.seed);
-    ctx->logits_all = params.logits_all;
+    ctx->abort_callback      = params.abort_callback;
+    ctx->abort_callback_data = params.abort_callback_data;
+
+    ctx->rng                 = std::mt19937(params.seed);
+    ctx->logits_all          = params.logits_all;
 
     const ggml_type type_k = params.type_k;
     const ggml_type type_v = params.type_v;
@@ -12989,6 +12998,11 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
 
+void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens,
diff --git a/llama.h b/llama.h
index ed51f478a7b2142fe001a71d4a9cce067cb17ff6..6406b52705e7d9d63861eaaca08c9a4b6ba70505 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -255,10 +255,16 @@ extern "C" {
         enum ggml_type type_v; // data type for V cache
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+
+        // Abort callback
+        // if it returns true, execution of llama_decode() will be aborted
+        // currently works only with CPU execution
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
     };
 
     // model quantization parameters
@@ -632,7 +638,10 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
-    // Token logits obtained from the last call to llama_eval()
+    // Set abort callback
+    LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+    // Token logits obtained from the last call to llama_decode()
     // The logits for the last token are stored in the last row
     // Logits for which llama_batch.logits[i] == 0 are undefined
     // Rows: n_tokens provided with llama_batch