]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix thread-safety of ggml_init and ggml_free
authorGeorgi Gerganov <redacted>
Sat, 29 Oct 2022 08:23:44 +0000 (11:23 +0300)
committerGeorgi Gerganov <redacted>
Sat, 29 Oct 2022 16:37:19 +0000 (19:37 +0300)
ggml.c

diff --git a/ggml.c b/ggml.c
index e8384ed778b22aa495e4f737988b1c38bebcf5a2..c0354f45cab5d3f70ae29612a6b29ac1ced88e74 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1136,6 +1136,7 @@ struct ggml_state {
 
 // global state
 struct ggml_state g_state;
+atomic_bool g_state_barrier = 0;
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -1265,6 +1266,17 @@ int ggml_up64(int n) {
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_context * ggml_init(struct ggml_init_params params) {
+    // make this function thread safe
+    {
+        int processing = atomic_fetch_add(&g_state_barrier, 1);
+        while (processing > 0) {
+            // wait for other threads to finish
+            atomic_fetch_sub(&g_state_barrier, 1);
+            sched_yield();
+            processing = atomic_fetch_add(&g_state_barrier, 1);
+        }
+    }
+
     static bool is_first_call = true;
     if (is_first_call) {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
@@ -1308,6 +1320,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     if (ctx == NULL) {
         GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
+
+        atomic_fetch_sub(&g_state_barrier, 1);
+
         return NULL;
     }
 
@@ -1322,10 +1337,25 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     ggml_assert_aligned(ctx->mem_buffer);
 
+    GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
+
+    atomic_fetch_sub(&g_state_barrier, 1);
+
     return ctx;
 }
 
 void ggml_free(struct ggml_context * ctx) {
+    // make this function thread safe
+    {
+        int processing = atomic_fetch_add(&g_state_barrier, 1);
+        while (processing > 0) {
+            // wait for other threads to finish
+            atomic_fetch_sub(&g_state_barrier, 1);
+            sched_yield();
+            processing = atomic_fetch_add(&g_state_barrier, 1);
+        }
+    }
+
     for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
         if (&g_state.contexts[i].context == ctx) {
             g_state.contexts[i].used = false;
@@ -1337,11 +1367,15 @@ void ggml_free(struct ggml_context * ctx) {
                 free(ctx->mem_buffer);
             }
 
+            atomic_fetch_sub(&g_state_barrier, 1);
+
             return;
         }
     }
 
     GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+
+    atomic_fetch_sub(&g_state_barrier, 1);
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {