]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : use dynamic thread scheduling for matrix multiplication (llama/6915)
authorkunnis <redacted>
Wed, 15 May 2024 17:59:12 +0000 (12:59 -0500)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
* Just reordering some structs.

* Adding in the calls to mm_pause

* Passing around the state

* Renaming and moving a bunch of variables around.

* Extracting the logic to it's own function.

* Moving some variable definitions into the chunk function.

* Moving some variables around

* moving src1_cont inside

* Moving row_size

* adding the current_chunk

* Reorg the code.

* Formatting to match the orig patch

* starting to setup the chunking variables

* Starting the buildup of the loop

* The yield shouldn't be necessary.

* adding the looping structure based on the chunk configuration.

* Add in the re-chunking code.

* Making it much more likely to rechunk.

* disable resizing if numa is enabled.

* Updating comments with what we've learned.

* Fix formatting

* Couple more formatting fixes.

* More style fixes.

* Fix Warnings

* Going with unused because there's conditional logic that needs it.

* Update ggml.c

* Update ggml.c

---------

src/ggml.c

index 67e17a2102b166ef6b6a984c815f60922def59a6..684d84235f8d37421cf92960c569114f49df371b 100644 (file)
@@ -112,6 +112,8 @@ typedef void * thread_ret_t;
 
 #endif
 
+typedef pthread_t ggml_thread_t;
+
 #ifdef GGML_USE_CPU_HBM
 #include <hbwmalloc.h>
 #endif
@@ -1539,6 +1541,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
 #endif
 
+//
+// ggml context
+//
+
+struct ggml_context {
+    size_t mem_size;
+    void* mem_buffer;
+    bool   mem_buffer_owned;
+    bool   no_alloc;
+    bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
+
+    int    n_objects;
+
+    struct ggml_object* objects_begin;
+    struct ggml_object* objects_end;
+
+    struct ggml_scratch scratch;
+    struct ggml_scratch scratch_save;
+};
+
+struct ggml_context_container {
+    bool used;
+
+    struct ggml_context context;
+};
+
+struct ggml_compute_state_shared {
+    const struct ggml_cgraph* cgraph;
+    const struct ggml_cplan* cplan;
+
+    int64_t perf_node_start_cycles;
+    int64_t perf_node_start_time_us;
+
+    const int n_threads;
+
+    // synchronization primitives
+    atomic_int n_active;  // num active threads
+    atomic_int node_n;    // active graph node
+    atomic_int node_task; // active graph node task phase
+
+    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
+    void* abort_callback_data;
+
+    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+};
+
+struct ggml_compute_state {
+    ggml_thread_t thrd;
+    int ith;
+    struct ggml_compute_state_shared* shared;
+    enum ggml_status ec;
+};
+
 //
 // fundamental operations
 //
@@ -2385,32 +2440,6 @@ static void ggml_setup_op_has_task_pass(void) {
     }
 }
 
-//
-// ggml context
-//
-
-struct ggml_context {
-    size_t mem_size;
-    void * mem_buffer;
-    bool   mem_buffer_owned;
-    bool   no_alloc;
-    bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
-
-    int    n_objects;
-
-    struct ggml_object * objects_begin;
-    struct ggml_object * objects_end;
-
-    struct ggml_scratch scratch;
-    struct ggml_scratch scratch_save;
-};
-
-struct ggml_context_container {
-    bool used;
-
-    struct ggml_context context;
-};
-
 //
 // NUMA support
 //
@@ -11815,9 +11844,101 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
 }
 #endif
 
+static void ggml_compute_forward_mul_mat_one_chunk(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst,
+    const int64_t num_rows_per_vec_dot,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t    const vec_dot = type_traits[type].vec_dot;
+    enum ggml_type    const vec_dot_type = type_traits[type].vec_dot_type;
+
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+    //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
+
+    // threads with no work simply yield (not sure if it helps)
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
+    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // block-tiling attempt
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+    // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
+                const int64_t i13 = (ir1 / (ne12 * ne1));
+                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const int64_t i03 = i13 / r3;
+                const int64_t i02 = i12 / r2;
+
+                const int64_t i1 = i11;
+                const int64_t i2 = i12;
+                const int64_t i3 = i13;
+
+                const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char*)wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+                float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                //}
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                }
+
+                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
+                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              struct ggml_compute_state * state) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -11832,9 +11953,6 @@ static void ggml_compute_forward_mul_mat(
 
     const enum ggml_type type = src0->type;
 
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
     int64_t           const vec_dot_num_rows      = type_traits[type].nrows;
@@ -11855,8 +11973,10 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb2 <= nb3);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+    UNUSED(r2);
+    UNUSED(r3);
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
@@ -11938,6 +12058,8 @@ static void ggml_compute_forward_mul_mat(
 #endif
 
 #if GGML_USE_LLAMAFILE
+    const bool src1_cont = ggml_is_contiguous(src1);
+
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -11963,6 +12085,8 @@ UseGgmlGemm1:;
         if (ith != 0) {
             return;
         }
+        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
+        atomic_store(&state->shared->current_chunk, nth);
         if (src1->type != vec_dot_type) {
             char * wdata = params->wdata;
             const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11987,11 +12111,11 @@ UseGgmlGemm1:;
         return;
     }
 
-    const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
 #if GGML_USE_LLAMAFILE
     if (src1->type != vec_dot_type) {
+        const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -12012,98 +12136,87 @@ UseGgmlGemm1:;
 UseGgmlGemm2:;
 #endif
 
-    const int64_t nr0 = ne01;          // src0 rows
-    const int64_t nr1 = ne1*ne12*ne13; // src1 rows
-
-    //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
-
-    // distribute the thread work across the inner or outer loop based on which one is larger
-
-    const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-    const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
-    const int64_t ith0 = ith % nth0;
-    const int64_t ith1 = ith / nth0;
-
-    const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-    const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
-
-    const int64_t ir010 = dr0*ith0;
-    const int64_t ir011 = MIN(ir010 + dr0, nr0);
-
-    const int64_t ir110 = dr1*ith1;
-    const int64_t ir111 = MIN(ir110 + dr1, nr1);
-
-    //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
-
-    // threads with no work simply yield (not sure if it helps)
-    if (ir010 >= ir011 || ir110 >= ir111) {
-        sched_yield();
-        return;
-    }
+#ifdef GGML_PERF
+    int chunks_executed = 0;
+    UNUSED(chunks_executed);
+#endif
 
-    assert(ne12 % ne02 == 0);
-    assert(ne13 % ne03 == 0);
+    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
+    const int64_t nr0 = ne0;
 
-    // block-tiling attempt
-    const int64_t blck_0 = 16;
-    const int64_t blck_1 = 16;
+    // This is the size of the rest of the dimensions of the result
+    const int64_t nr1 = ne1 * ne2 * ne3;
 
     // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
-    int64_t nrc = vec_dot_num_rows;
+    int64_t num_rows_per_vec_dot = vec_dot_num_rows;
     // TODO: currently the mmla kernels support only even numbered rows/cols.
     // this check can be removed once they are extended to support odd numbered rows/cols too
     if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
-        nrc = 1;
+        num_rows_per_vec_dot = 1;
     }
 
-    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+    // Now select a reasonable chunk size.
+    int chunk_size = 16;
 
-    // attempt to reduce false-sharing (does not seem to make a difference)
-    // 16 * 2, accounting for mmla kernels
-    float tmp[32];
+    // We need to step up the size if it's small
+    if (nr0 == 1 || nr1 == 1) {
+        chunk_size = 64;
+    }
 
-    for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-        for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
-                const int64_t i13 = (ir1/(ne12*ne1));
-                const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
-                const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
+    // distribute the work across the inner or outer loop based on which one is larger
+    // The number of chunks in the 0/1 dim.
+    // CEIL(nr0/chunk_size)
+    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
-                // broadcast src0 into src1
-                const int64_t i03 = i13/r3;
-                const int64_t i02 = i12/r2;
+    // If the chunking is poor for the number of threads on this setup, scrap the whole plan.  Re-chunk it by thread.
+    //   Also, chunking by thread was measured to have perform better on NUMA systems.  See https://github.com/ggerganov/llama.cpp/pull/6915
+    //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
+    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
+        // distribute the thread work across the inner or outer loop based on which one is larger
+        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+    }
 
-                const int64_t i1 = i11;
-                const int64_t i2 = i12;
-                const int64_t i3 = i13;
+    // The number of elements in each chunk
+    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-                const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
+    //if (ith == 0)
+    //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d.  Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
 
-                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                //       the original src1 data pointer, so we should index using the indices directly
-                // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                const char * src1_col = (const char *) wdata +
-                    (src1_cont || src1->type != vec_dot_type
-                     ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
-                     : (i11*nb11 + i12*nb12 + i13*nb13));
-                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
 
-                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                //}
+    while (current_chunk < nchunk0 * nchunk1) {
+        const int64_t ith0 = current_chunk % nchunk0;
+        const int64_t ith1 = current_chunk / nchunk0;
 
-                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
-                }
+        const int64_t ir0_start = dr0 * ith0;
+        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
 
-                for (int cn = 0; cn < nrc; ++cn) {
-                    memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
-            }
+        const int64_t ir1_start = dr1 * ith1;
+        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+
+#ifdef GGML_PERF
+        chunks_executed++;
+#endif
+
+        if (nth >= nchunk0 * nchunk1) {
+            break;
         }
+
+        current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
     }
+
+#ifdef GGML_PERF
+    // These numbers are useful when trying to measure how well the threading scheduling works.
+    //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
+    //float time = (ggml_perf_time_us() - t0);
+    //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
+#endif
 }
 
 // ggml_compute_forward_mul_mat_id
@@ -17358,7 +17471,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17456,7 +17569,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor);
+                ggml_compute_forward_mul_mat(params, tensor, state);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -19072,8 +19185,6 @@ typedef int ggml_lock_t;
 
 #define GGML_LOCK_INITIALIZER 0
 
-typedef pthread_t ggml_thread_t;
-
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
@@ -19099,8 +19210,6 @@ typedef int ggml_lock_t;
 
 #define GGML_LOCK_INITIALIZER 0
 
-typedef pthread_t ggml_thread_t;
-
 #define ggml_thread_create pthread_create
 #define ggml_thread_join   pthread_join
 
@@ -19180,31 +19289,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }
 static void clear_numa_thread_affinity(void) {}
 #endif
 
-struct ggml_compute_state_shared {
-    const struct ggml_cgraph * cgraph;
-    const struct ggml_cplan  * cplan;
-
-    int64_t perf_node_start_cycles;
-    int64_t perf_node_start_time_us;
-
-    const int n_threads;
-
-    // synchronization primitives
-    atomic_int n_active;  // num active threads
-    atomic_int node_n;    // active graph node
-    atomic_int node_task; // active graph node task phase
-
-    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
-    void * abort_callback_data;
-};
-
-struct ggml_compute_state {
-    ggml_thread_t thrd;
-    int ith;
-    struct ggml_compute_state_shared * shared;
-    enum ggml_status ec;
-};
-
 static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
     int64_t cycles_cur  = ggml_perf_cycles()  - st->perf_node_start_cycles;
     int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -19477,6 +19561,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
 
         * node_n = atomic_load(&state->shared->node_n);
         if (* node_n != last_node_n) break;
+#if defined(__SSE3__)
+        // Tell the processor we're spinning.  It's a processor hint for spinlocks.
+        _mm_pause();
+#endif
     }
 }
 
@@ -19491,6 +19579,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
 
         * task_phase = atomic_load(&state->shared->node_task);
         if (* task_phase != last_task_phase) break;
+#if defined(__SSE3__)
+        // Tell the processor we're spinning.  It's a processor hint for spinlocks.
+        _mm_pause();
+#endif
     }
 }
 
@@ -19530,7 +19622,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 struct ggml_tensor * node = cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
                     params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-                    ggml_compute_forward(&params, node);
+                    ggml_compute_forward(&params, node, state);
                 }
                 ggml_graph_compute_perf_stats_node(node, state->shared);
             }
@@ -19550,17 +19642,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                     /* INIT */
                     if (GGML_OP_HAS_INIT[node->op]) {
                         params.type = GGML_TASK_TYPE_INIT;
-                        ggml_compute_forward(&params, node);
+                        ggml_compute_forward(&params, node, state);
                     }
 
                     // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
                     // they do something more efficient than spinning (?)
                     params.type = GGML_TASK_TYPE_COMPUTE;
-                    ggml_compute_forward(&params, node);
+                    ggml_compute_forward(&params, node, state);
 
                     if (GGML_OP_HAS_FINALIZE[node->op]) {
                         params.type = GGML_TASK_TYPE_FINALIZE;
-                        ggml_compute_forward(&params, node);
+                        ggml_compute_forward(&params, node, state);
                     }
 
                     ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -19599,7 +19691,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         if (state->ith < n_tasks) {
             if (GGML_OP_HAS_INIT[node->op]) {
-                ggml_compute_forward(&params, node);
+                ggml_compute_forward(&params, node, state);
             }
         }
 
@@ -19620,7 +19712,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         if (state->ith < n_tasks) {
             params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node);
+            ggml_compute_forward(&params, node, state);
         }
 
         if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -19871,6 +19963,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         /*.node_task               =*/ GGML_TASK_TYPE_FINALIZE,
         /*.abort_callback          =*/ NULL,
         /*.abort_callback_data     =*/ NULL,
+        /*.current_chunk;          =*/ 0,
     };
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);