]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add callback to abort ggml_graph_compute() (#328)
authorArjun <redacted>
Tue, 11 Jul 2023 19:11:45 +0000 (00:41 +0530)
committerGitHub <redacted>
Tue, 11 Jul 2023 19:11:45 +0000 (22:11 +0300)
* mechanism to abort ggml_graph_compute

* use pthread_cancel

* forgot to commit ggml.h

* static always_false()

Co-authored-by: Georgi Gerganov <redacted>
* accept callback data

* proper function prototype

* return exit status

* remove pthread_cancel and join every thread

* put abort_callback onto cplan

* cplan abort_callback in ggml.c

* make sure all threads abort

---------

Co-authored-by: Georgi Gerganov <redacted>
include/ggml/ggml.h
src/ggml.c

index d7c9e0f0e06815058ed61ac0128c61fecde5eb8a..8fe05d3a595b7b9368711cb453d1888ef65faa20 100644 (file)
 #define GGML_MAX_NAME          48
 #define GGML_DEFAULT_N_THREADS 4
 
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
 #define GGML_UNUSED(x) (void)(x)
 
+
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
@@ -442,6 +447,10 @@ extern "C" {
 
         // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
         int n_tasks[GGML_MAX_NODES];
+
+        // abort ggml_graph_compute when true
+        bool (*abort_callback)(void * data);
+        void * abort_callback_data;
     };
 
     // computation graph
@@ -1303,7 +1312,7 @@ extern "C" {
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
-    GGML_API              void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    GGML_API               int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
     GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
     // same as ggml_graph_compute() but the work data is allocated as a part of the context
index d84afbb6d423160fa45dfe0a74db2108e48b58dc..28f26762a8931232a096d2168e033be9fedf3e14 100644 (file)
@@ -25,6 +25,7 @@
 #include <float.h>
 #include <limits.h>
 #include <stdarg.h>
+#include <signal.h>
 
 #ifdef GGML_USE_METAL
 #include <unistd.h>
@@ -15955,6 +15956,9 @@ struct ggml_compute_state_shared {
     // synchronization primitives
     atomic_int n_active; // num active threads
     atomic_int node_n;   // active graph node
+
+    bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
+    void * abort_callback_data;
 };
 
 struct ggml_compute_state {
@@ -15986,6 +15990,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     int node_n = -1;
 
     while (true) {
+        if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+            state->shared->node_n += 1;
+            return GGML_EXIT_ABORTED;
+        }
         if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
             // all other threads are finished and spinning
             // do finalize and init here so we don't have synchronize again
@@ -16039,6 +16047,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 } else {
                     break;
                 }
+
+                if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+                    break;
+                }
             }
 
             atomic_store(&state->shared->n_active, n_threads);
@@ -16072,9 +16084,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         }
     }
 
-    return 0;
+    return GGML_EXIT_SUCCESS;
 }
 
+static bool always_false(void * data) { return false; }
 struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
     if (n_threads <= 0) {
         n_threads = GGML_DEFAULT_N_THREADS;
@@ -16412,7 +16425,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
     return cplan;
 }
 
-void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     {
         GGML_ASSERT(cplan);
         GGML_ASSERT(cplan->n_threads > 0);
@@ -16461,12 +16474,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
     const int64_t perf_start_time_us = ggml_perf_time_us();
 
     // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
+    int compute_status = ggml_graph_compute_thread(&workers[0]);
 
     // don't leave affinity set on the main thread
     clear_numa_thread_affinity();
 
-    // join thread pool
+    // join or kill thread pool
     if (n_threads > 1) {
         for (int j = 1; j < n_threads; j++) {
             const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16490,6 +16503,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
                 (double) perf_time_us_cur     / 1000.0,
                 (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
     }
+
+    return compute_status;
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {