]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : use atomic_flag for critical section (llama/7598)
authorslaren <redacted>
Wed, 29 May 2024 11:36:39 +0000 (13:36 +0200)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
* ggml : use atomic_flag for critical section

* add windows shims

src/ggml.c

index 5025ec23b37645982a41bc3eaae868b67ddf19ee..d8f74f3ceaf5da026988c807dbf0aaff129d7022 100644 (file)
@@ -60,6 +60,9 @@
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
 
 static void atomic_store(atomic_int * ptr, LONG val) {
     InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
 static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
     return atomic_fetch_add(ptr, -(dec));
 }
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}
 
 typedef HANDLE pthread_t;
 
@@ -2883,24 +2892,20 @@ struct ggml_state {
 
 // global state
 static struct ggml_state g_state;
-static atomic_int g_state_barrier = 0;
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
 // barrier via spin lock
 inline static void ggml_critical_section_start(void) {
-    int processing = atomic_fetch_add(&g_state_barrier, 1);
-
-    while (processing > 0) {
-        // wait for other threads to finish
-        atomic_fetch_sub(&g_state_barrier, 1);
-        sched_yield(); // TODO: reconsider this
-        processing = atomic_fetch_add(&g_state_barrier, 1);
+    while (atomic_flag_test_and_set(&g_state_critical)) {
+        // spin
+        sched_yield();
     }
 }
 
 // TODO: make this somehow automatically executed
 //       some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
-    atomic_fetch_sub(&g_state_barrier, 1);
+    atomic_flag_clear(&g_state_critical);
 }
 
 #if defined(__gnu_linux__)