]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add abort_callback for cpu backend (ggml/725)
authorMichael Podvitskiy <redacted>
Fri, 9 Feb 2024 09:42:27 +0000 (10:42 +0100)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:46 +0000 (09:55 +0200)
* a way to use abort_callback with the cpu backend

* whisper update

ggml-backend.c
ggml-backend.h
ggml.c
ggml.h
whisper.cpp
whisper.h

index 0764dfebca673647babc92e9f58abae085b28be9..532da8edadced7a087ebd3835973db2b3283cec3 100644 (file)
@@ -653,6 +653,9 @@ struct ggml_backend_cpu_context {
     int n_threads;
     void * work_data;
     size_t work_size;
+
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
 };
 
 GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
@@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
         cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
     }
 
+    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
     return cpu_plan;
 }
 
@@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
         cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
         cpu_ctx->work_size = cplan.work_size;
     }
-
     cplan.work_data = cpu_ctx->work_data;
 
+    cplan.abort_callback      = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
     ggml_graph_compute(cgraph, &cplan);
     return true;
 }
@@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = {
 ggml_backend_t ggml_backend_cpu_init(void) {
     struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
 
-    ctx->n_threads = GGML_DEFAULT_N_THREADS;
-    ctx->work_data = NULL;
-    ctx->work_size = 0;
+    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
+    ctx->work_data           = NULL;
+    ctx->work_size           = 0;
+    ctx->abort_callback      = NULL;
+    ctx->abort_callback_data = NULL;
 
     ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
 
@@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }
 
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
 GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
     return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
 }
index 8b8160fcf66586831d92b6bc7f054e0a29eef091..282b3a9b79bad9ec56f3aa221f4535cc0fd2ae45 100644 (file)
@@ -83,8 +83,9 @@ extern "C" {
 
     GGML_API ggml_backend_t ggml_backend_cpu_init(void);
 
-    GGML_API GGML_CALL bool ggml_backend_is_cpu           (ggml_backend_t backend);
-    GGML_API           void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
+    GGML_API GGML_CALL bool ggml_backend_is_cpu                (ggml_backend_t backend);
+    GGML_API           void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
+    GGML_API           void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
     // Create a backend buffer from an existing pointer
     GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
diff --git a/ggml.c b/ggml.c
index a7a9ea319c5f09dd711d783838200d66fb2f10d7..3499b737dd06f114ecd4bcab2327fb4552ea307f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -16560,7 +16560,7 @@ struct ggml_compute_state_shared {
     atomic_int node_n;    // active graph node
     atomic_int node_task; // active graph node task phase
 
-    bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
+    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
     void * abort_callback_data;
 };
 
diff --git a/ggml.h b/ggml.h
index bf782e6ad12793931e36f828bf5e01345774d6f2..e20b14faa08c81c6823d558e97801fe2a11ee15d 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -567,6 +567,11 @@ extern "C" {
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*ggml_abort_callback)(void * data);
+
     // the compute plan that needs to be prepared for ggml_graph_compute()
     // since https://github.com/ggerganov/ggml/issues/287
     struct ggml_cplan {
@@ -576,8 +581,8 @@ extern "C" {
         int n_threads;
 
         // abort ggml_graph_compute when true
-        bool (*abort_callback)(void * data);
-        void * abort_callback_data;
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
     };
 
     enum ggml_cgraph_eval_order {
index 59d5cff1df51393bbd78f1e750989b599e0cb28f..28e3804f68fc270cd8c1bce7c2a874869649dcc4 100644 (file)
@@ -156,11 +156,11 @@ static bool ggml_graph_compute_helper(
           struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
                          int   n_threads,
-      whisper_abort_callback   abort_callback,
+         ggml_abort_callback   abort_callback,
                         void * abort_callback_data) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
-    plan.abort_callback = abort_callback;
+    plan.abort_callback      = abort_callback;
     plan.abort_callback_data = abort_callback_data;
 
     if (plan.work_size > 0) {
@@ -2130,7 +2130,7 @@ static bool whisper_encode_internal(
           whisper_state & wstate,
               const int   mel_offset,
               const int   n_threads,
whisper_abort_callback   abort_callback,
   ggml_abort_callback   abort_callback,
                    void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
@@ -2561,7 +2561,7 @@ static bool whisper_decode_internal(
           whisper_state & wstate,
     const whisper_batch & batch,
               const int   n_threads,
whisper_abort_callback   abort_callback,
   ggml_abort_callback   abort_callback,
                    void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
index d571a125db3c48d208c6e1c33f29f6b8367cfbab..a5371eb3b9331a49f7479f8376df5a7d82c42c25 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -412,11 +412,6 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
-    // Abort callback
-    // If not NULL, called before ggml computation
-    // If it returns true, the computation is aborted
-    typedef bool (*whisper_abort_callback)(void * user_data);
-
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -513,7 +508,7 @@ extern "C" {
         void * encoder_begin_callback_user_data;
 
         // called each time before ggml computation starts
-        whisper_abort_callback abort_callback;
+        ggml_abort_callback abort_callback;
         void * abort_callback_user_data;
 
         // called by each decoder to filter obtained logits