]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix multi-threaded ggml_compute_forward_diag_mask_f32()
authorGeorgi Gerganov <redacted>
Sun, 14 May 2023 11:45:13 +0000 (14:45 +0300)
committerGeorgi Gerganov <redacted>
Sun, 14 May 2023 12:18:34 +0000 (15:18 +0300)
src/ggml.c

index 63da6799b8ff2461fe37ab8e1df244808ec7effe..58727bb35c80f692b67e48768ab81d5a2bb30162 100644 (file)
@@ -10372,22 +10372,34 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+    const int  n_past  =       ((int32_t *) src1->data)[0];
+    const bool inplace = (bool)((int32_t *) src1->data)[1];
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: this hack is not good, need a better way to handle this
+        if (!inplace) {
+            // use the init task to copy src -> dst
+            struct ggml_compute_params params_cpy = *params;
+
+            params_cpy.ith  = 0;
+            params_cpy.nth  = 1;
+            params_cpy.type = GGML_TASK_COMPUTE;
+
+            ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int  n_past  =       ((int32_t *) src1->data)[0];
-    const bool inplace = (bool)((int32_t *) src1->data)[1];
-
     assert(n_past >= 0);
 
-    if (!inplace) {
-        ggml_compute_forward_dup_same_cont(params, src0, dst);
-    }
-
     // TODO: handle transposed/permuted matrices
 
     const int n  = ggml_nrows(src0);
@@ -10474,7 +10486,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {