]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : alternative fix for race condition bug in non-inplace ggml_compute_forward_dia...
authorxaedes <redacted>
Sun, 14 May 2023 15:55:02 +0000 (17:55 +0200)
committerGitHub <redacted>
Sun, 14 May 2023 15:55:02 +0000 (18:55 +0300)
* fix race condition bug in non-inplace ggml_compute_forward_diag_mask_f32

memcpy needs to be synchronized across threads to avoid race conditions.
=> do it in INIT phase

* remove trailing whitespace

* Update ggml.c

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml.c

diff --git a/ggml.c b/ggml.c
index da3d914e4ef47837fb5455676e61e95a3aef2579..4311ce7cf9dbe67cb4a8f54c276b231ad7354e6b 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
+    const int ith = params->ith;
+    const int nth = params->nth;
+
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
+    assert(n_past >= 0);
 
-    if (params->type == GGML_TASK_INIT) {
-        // TODO: this hack is not good, need a better way to handle this
-        if (!inplace) {
-            // use the init task to copy src -> dst
-            struct ggml_compute_params params_cpy = *params;
-
-            params_cpy.ith  = 0;
-            params_cpy.nth  = 1;
-            params_cpy.type = GGML_TASK_COMPUTE;
-
-            ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
-        }
-
-        return;
+    if (!inplace && (params->type == GGML_TASK_INIT)) {
+        // memcpy needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+        memcpy(
+            ((char *)  dst->data),
+            ((char *) src0->data),
+            ggml_nbytes(dst));
     }
 
-    if (params->type == GGML_TASK_FINALIZE) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    assert(n_past >= 0);
-
     // TODO: handle transposed/permuted matrices
 
     const int n  = ggml_nrows(src0);