]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync with latest whisper.cpp
authorGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 17:53:05 +0000 (19:53 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Jan 2023 17:53:05 +0000 (19:53 +0200)
src/ggml.c

index c88eb22741f3f23b8e6d6005acb89891b14c5c43..f4c96eb42229d494905fb538fd9d0e3e3618a3f9 100644 (file)
@@ -79,8 +79,11 @@ typedef void* thread_ret_t;
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #endif
 
+/*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL 4
 
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
@@ -908,6 +911,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
+
+    for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
 inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
 #if defined(GGML_SIMD)
     const int np = (n & ~(GGML_F32_STEP - 1));
@@ -1039,7 +1097,30 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
-inline static void ggml_vec_sum_f32     (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
+inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += x[i];
+        *s += sum;
+    }
+#else
+    vDSP_sve(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float max = -INFINITY;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    *s = max;
+#else
+    vDSP_maxv(x, 1, s, n);
+#endif
+}
+
 inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
 
 //
@@ -1293,25 +1374,25 @@ size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
 
-bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
-bool ggml_is_vector(const struct ggml_tensor * tensor) {
+static inline bool ggml_is_vector(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
-bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
-bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -1320,7 +1401,7 @@ bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor *
         (t0->ne[3]  == t1->ne[3]);
 }
 
-bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -1330,7 +1411,7 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -1339,7 +1420,7 @@ bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+static inline bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -1350,7 +1431,7 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
 }
 
 // check if t1 can be represented as a repeatition of t0
-bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -1360,14 +1441,20 @@ bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t
         (t1->ne[3]%t0->ne[3] == 0);
 }
 
-int ggml_up32(int n) {
+static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
 
-int ggml_up64(int n) {
+static inline int ggml_up64(int n) {
     return (n + 63) & ~63;
 }
 
+static inline int ggml_up(int n, int m) {
+    // assert m is a power of 2
+    GGML_ASSERT((m & (m - 1)) == 0);
+    return (n + m - 1) & ~(m - 1);
+}
+
 // assert that pointer is aligned to GGML_MEM_ALIGN
 #define ggml_assert_aligned(ptr) \
     assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
@@ -4658,7 +4745,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         // TODO: do not support transposed src1
         assert(nb10/2 == sizeof(ggml_fp16_t));
 
-        // parallelize by src0 rows using ggml_vec_dot_f32
+        // parallelize by src0 rows using ggml_vec_dot_f16
 
         // total rows in src0
         const int nr = ne01*ne02*ne03;
@@ -4686,13 +4773,13 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             const int i3 = i03;
 
             ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00;
+            ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
 
             float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                assert(ne00 % 32 == 0);
+            assert(ne00 % 32 == 0);
 
+            for (int ic = 0; ic < ne11; ++ic) {
                 ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
             }
         }
@@ -5071,21 +5158,19 @@ static void ggml_compute_forward_soft_max_f32(
 #endif
 
         float max = -INFINITY;
-        for (int i = 0; i < nc; i++) {
-            max = MAX(max, p[i]);
-        }
+        ggml_vec_max_f32(nc, &max, p);
 
         ggml_float sum = 0.0;
 
-        uint16_t ss;
+        uint16_t scvt;
         for (int i = 0; i < nc; i++) {
             if (p[i] == -INFINITY) {
-                p[i] = 0.0;
+                p[i] = 0.0f;
             } else {
                 //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
                 ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
-                memcpy(&ss, &s, sizeof(ss));
-                const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
+                memcpy(&scvt, &s, sizeof(scvt));
+                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
                 sum += val;
                 p[i] = val;
             }
@@ -5797,6 +5882,8 @@ static void ggml_compute_forward_flash_attn_f32(
     const int P = nek1 - N;
     const int M = P + N;
 
+    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne1 == N);
     GGML_ASSERT(P >= 0);
@@ -5849,7 +5936,11 @@ static void ggml_compute_forward_flash_attn_f32(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32);
+        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
+
+        for (int i = M; i < Mup; ++i) {
+            S[i] = -INFINITY;
+        }
 
         for (int ic = 0; ic < nek1; ++ic) {
             // k indices
@@ -5880,30 +5971,50 @@ static void ggml_compute_forward_flash_attn_f32(
         // softmax
         {
             float max = -INFINITY;
-            for (int i = 0; i < M; i++) {
-                max = MAX(max, S[i]);
-            }
+            ggml_vec_max_f32(M, &max, S);
 
-            ggml_float sum = 0.0;
+            float sum = 0.0f;
+            {
+#ifndef GGML_USE_ACCELERATE
+                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
-            uint16_t ss;
-            for (int i = 0; i < M; i++) {
-                if (S[i] == -INFINITY) {
-                    S[i] = 0.0;
-                } else {
-                    //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
-                    ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
-                    memcpy(&ss, &s, sizeof(ss));
-                    const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
-                    sum += val;
-                    S[i] = val;
+                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                    float * SS = S + i;
+
+                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                        if (SS[j] == -INFINITY) {
+                            SS[j] = 0.0f;
+                        } else {
+                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
+                            memcpy(&scvt[j], &s, sizeof(uint16_t));
+                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            sump[j] += val;
+                            SS[j] = val;
+                        }
+                    }
                 }
+
+                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                    sum += sump[i];
+                }
+#else
+                vvexpf(S, S, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, S);
+#endif
             }
 
             assert(sum > 0.0f);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
+
+#ifndef NDEBUG
+            for (int i = 0; i < M; ++i) {
+                assert(!isnan(S[i]));
+                assert(!isinf(S[i]));
+            }
+#endif
         }
 
         for (int ic = 0; ic < nev1; ++ic) {
@@ -5978,6 +6089,8 @@ static void ggml_compute_forward_flash_attn_f16(
     const int P = nek1 - N;
     const int M = P + N;
 
+    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne1 == N);
     GGML_ASSERT(P >= 0);
@@ -6030,8 +6143,14 @@ static void ggml_compute_forward_flash_attn_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
+        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
+
+        for (int i = M; i < Mup; ++i) {
+            S[i] = -INFINITY;
+        }
 
+        // looks like unrolling here does not help
+#if 1
         for (int ic = 0; ic < nek1; ++ic) {
             // k indices
             const int ik3 = iq3;
@@ -6046,6 +6165,24 @@ static void ggml_compute_forward_flash_attn_f16(
                     (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
                     (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
         }
+#else
+        GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
+
+        for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
+            // k indices
+            const int ik3 = iq3;
+            const int ik2 = iq2;
+            const int ik1 = ic;
+
+            // S indices
+            const int i1 = ik1;
+
+            ggml_vec_dot_f16_unroll(neq0, nbk1,
+                    S + i1,
+                                    ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+        }
+#endif
 
         // scale
         ggml_vec_scale_f32(nek1, S, scale);
@@ -6061,30 +6198,50 @@ static void ggml_compute_forward_flash_attn_f16(
         // softmax
         {
             float max = -INFINITY;
-            for (int i = 0; i < M; i++) {
-                max = MAX(max, S[i]);
-            }
+            ggml_vec_max_f32(M, &max, S);
+
+            float sum = 0.0f;
+            {
+#ifndef GGML_USE_ACCELERATE
+                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
-            ggml_float sum = 0.0;
+                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                    float * SS = S + i;
 
-            uint16_t ss;
-            for (int i = 0; i < M; i++) {
-                if (S[i] == -INFINITY) {
-                    S[i] = 0.0;
-                } else {
-                    //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
-                    ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
-                    memcpy(&ss, &s, sizeof(ss));
-                    const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
-                    sum += val;
-                    S[i] = val;
+                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                        if (SS[j] == -INFINITY) {
+                            SS[j] = 0.0f;
+                        } else {
+                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
+                            memcpy(&scvt[j], &s, sizeof(uint16_t));
+                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            sump[j] += val;
+                            SS[j] = val;
+                        }
+                    }
                 }
+
+                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                    sum += sump[i];
+                }
+#else
+                vvexpf(S, S, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, S);
+#endif
             }
 
             assert(sum > 0.0f);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
+
+#ifndef NDEBUG
+            for (int i = 0; i < M; ++i) {
+                assert(!isnan(S[i]));
+                assert(!isinf(S[i]));
+            }
+#endif
         }
 
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
@@ -6093,15 +6250,17 @@ static void ggml_compute_forward_flash_attn_f16(
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
-        for (int ic = 0; ic < nev1; ++ic) {
+        GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
+
+        for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
             // dst indices
             const int i1 = iq1;
             const int i2 = iq2;
             const int i3 = iq3;
 
-            ggml_vec_dot_f16(nek1,
-                    (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                    (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+            ggml_vec_dot_f16_unroll(nek1, nbv1,
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                              ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
                     S16);
         }
     }
@@ -6983,7 +7142,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         }
 
         if (state->node) {
-            ggml_compute_forward(&state->params, state->node);
+            if (state->params.ith < state->params.nth) {
+                ggml_compute_forward(&state->params, state->node);
+            }
             state->node = NULL;
         } else {
             break;
@@ -7077,9 +7238,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MUL_MAT:
                     {
-                        // TODO: use different scheduling for different matrix sizes
                         node->n_tasks = n_threads;
 
+                        // TODO: use different scheduling for different matrix sizes
+                        //const int nr0 = ggml_nrows(node->src0);
+                        //const int nr1 = ggml_nrows(node->src1);
+
+                        //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
+                        //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
+
                         size_t cur = 0;
 
                         // TODO: better way to determine if the matrix is transposed
@@ -7090,6 +7257,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                                 node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                    node->n_tasks = 1;
                                     cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
                                 } else {
                                     cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
@@ -7165,14 +7333,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
+                        const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+
                         if (node->src1->type == GGML_TYPE_F32) {
-                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
                         }
 
                         if (node->src1->type == GGML_TYPE_F16) {
-                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
-                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                            cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
                         }
 
                         work_size = MAX(work_size, cur);
@@ -7261,7 +7431,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].params = (struct ggml_compute_params) {
                     .type  = GGML_TASK_COMPUTE,
                     .ith   = j + 1,
-                    .nth   = n_threads,
+                    .nth   = node->n_tasks,
                     .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
                 };
@@ -7316,7 +7486,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 workers[j].params = (struct ggml_compute_params) {
                     .type  = GGML_TASK_FINALIZE,
                     .ith   = j + 1,
-                    .nth   = n_threads,
+                    .nth   = node->n_tasks,
                     .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
                 };