]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : remove ggml_task_type and GGML_PERF (#8017)
authorslaren <redacted>
Mon, 24 Jun 2024 01:07:59 +0000 (03:07 +0200)
committerGitHub <redacted>
Mon, 24 Jun 2024 01:07:59 +0000 (03:07 +0200)
* ggml : remove ggml_task_type and GGML_PERF

* check abort_callback on main thread only

* vulkan : remove usage of ggml_compute_params

* remove LLAMA_PERF

CMakeLists.txt
Makefile
ggml-vulkan.cpp
ggml.c
ggml.h
llama.cpp
sgemm.cpp
sgemm.h

index 9cfe08d7b7d598436eefec25023c96b3e6ed6191..49ba45356a78d86f1a4ada182e9604ab0ac0b49e 100644 (file)
@@ -144,9 +144,6 @@ option(LLAMA_BUILD_SERVER                    "llama: build server example"
 option(LLAMA_LASX                            "llama: enable lasx"                               ON)
 option(LLAMA_LSX                             "llama: enable lsx"                                ON)
 
-# add perf arguments
-option(LLAMA_PERF                            "llama: enable perf"                               OFF)
-
 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)
 
@@ -870,10 +867,6 @@ if (LLAMA_CPU_HBM)
     target_link_libraries(ggml PUBLIC memkind)
 endif()
 
-if (LLAMA_PERF)
-    add_compile_definitions(GGML_PERF)
-endif()
-
 function(get_flags CCID CCVER)
     set(C_FLAGS "")
     set(CXX_FLAGS "")
index 4ea59c0b4ef299273b2d74753055770c2fb38b64..3aad77394c5ac86b08fb136e0df5859389f50d64 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -344,9 +344,6 @@ ifdef LLAMA_GPROF
        MK_CFLAGS   += -pg
        MK_CXXFLAGS += -pg
 endif
-ifdef LLAMA_PERF
-       MK_CPPFLAGS += -DGGML_PERF
-endif
 
 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures
index c31877403b0d59af1d899fd3c6f32a801be4d6b3..101781ede4b4fd303c95acd72ed1ebdb9858ec73 100644 (file)
@@ -513,8 +513,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -5644,7 +5644,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){
     ggml_tensor_extra_gpu * extra = nullptr;
 
     switch (tensor->op) {
@@ -5697,17 +5697,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
         return false;
     }
 
-    if (params->ith != 0) {
-        return true;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return true;
-    }
-
     VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
-    ggml_vk_check_results_0(ctx, params, tensor);
+    ggml_vk_check_results_0(ctx, tensor);
 #endif
 
     vk_context& subctx = ctx->gc.contexts[extra->ctx_idx];
@@ -6214,9 +6207,6 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
         ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node);
     }
 
-    ggml_compute_params params = {};
-    params.type = GGML_TASK_TYPE_COMPUTE;
-    params.ith = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -6224,13 +6214,13 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
             continue;
         }
 
-        bool ok = ggml_vk_compute_forward(ctx, &params, node);
+        bool ok = ggml_vk_compute_forward(ctx, node);
         if (!ok) {
             fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
         else {
-            ggml_vk_check_results_1(ctx, &params, node);
+            ggml_vk_check_results_1(ctx, node);
         }
 #endif
         GGML_ASSERT(ok);
@@ -6600,11 +6590,8 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor) {
-    if (params->ith != 0) {
-        return;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+        if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
@@ -6908,11 +6895,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     ggml_free(ggml_ctx);
 }
 
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor) {
-    if (params->ith != 0) {
-        return;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+    if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
diff --git a/ggml.c b/ggml.c
index 778ca3fdf1f8f8b864208640b0dfd146f20c3953..f5502afbe98b36dc0c75b08b1986105057395b81 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -175,7 +175,6 @@ void ggml_print_backtrace(void) {
 }
 #endif
 
-/*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
@@ -293,7 +292,7 @@ inline static void * ggml_calloc(size_t num, size_t size) {
 #define GGML_FREE(ptr) free(ptr)
 
 #define UNUSED GGML_UNUSED
-#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
 
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
@@ -474,18 +473,6 @@ int64_t ggml_cycles_per_ms(void) {
     return CLOCKS_PER_SEC/1000;
 }
 
-#ifdef GGML_PERF
-#define ggml_perf_time_ms()       ggml_time_ms()
-#define ggml_perf_time_us()       ggml_time_us()
-#define ggml_perf_cycles()        ggml_cycles()
-#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
-#else
-#define ggml_perf_time_ms()       0
-#define ggml_perf_time_us()       0
-#define ggml_perf_cycles()        0
-#define ggml_perf_cycles_per_ms() 0
-#endif
-
 //
 // cross-platform UTF-8 file paths
 //
@@ -1730,8 +1717,8 @@ struct ggml_context {
 
     int    n_objects;
 
-    struct ggml_object* objects_begin;
-    struct ggml_object* objects_end;
+    struct ggml_object * objects_begin;
+    struct ggml_object * objects_end;
 
     struct ggml_scratch scratch;
     struct ggml_scratch scratch_save;
@@ -1744,11 +1731,8 @@ struct ggml_context_container {
 };
 
 struct ggml_compute_state_shared {
-    const struct ggml_cgraph* cgraph;
-    const struct ggml_cplan* cplan;
-
-    int64_t perf_node_start_cycles;
-    int64_t perf_node_start_time_us;
+    const struct ggml_cgraph * cgraph;
+    const struct ggml_cplan * cplan;
 
     int n_threads;
 
@@ -1757,16 +1741,28 @@ struct ggml_compute_state_shared {
     atomic_int n_barrier_passed;
 
     ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
-    void* abort_callback_data;
+    void * abort_callback_data;
 
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
+
+    enum ggml_status ec;
 };
 
 struct ggml_compute_state {
     ggml_thread_t thrd;
     int ith;
-    struct ggml_compute_state_shared* shared;
-    enum ggml_status ec;
+    struct ggml_compute_state_shared * shared;
+};
+
+struct ggml_compute_params {
+    // ith = thread index, nth = number of threads
+    int ith, nth;
+
+    // work buffer for all threads
+    size_t wsize;
+    void * wdata;
+
+    struct ggml_compute_state_shared * shared;
 };
 
 //
@@ -2814,42 +2810,6 @@ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
-// WARN:
-// Mis-configuration can lead to problem that's hard to reason about:
-// * At best  it crash or talks nosense.
-// * At worst it talks slightly difference but hard to perceive.
-//
-// An op has to enable INIT or FINALIZE when any of it's branch needs that pass.
-// Take care about compile options (e.g., GGML_USE_xxx).
-static bool GGML_OP_HAS_INIT    [GGML_OP_COUNT] = { 0 };
-static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 };
-
-static void ggml_setup_op_has_task_pass(void) {
-    {   // INIT
-        bool * p = GGML_OP_HAS_INIT;
-
-        p[GGML_OP_ACC                    ] = true;
-        p[GGML_OP_MUL_MAT                ] = true;
-        p[GGML_OP_MUL_MAT_ID             ] = true;
-        p[GGML_OP_OUT_PROD               ] = true;
-        p[GGML_OP_SET                    ] = true;
-        p[GGML_OP_GET_ROWS_BACK          ] = true;
-        p[GGML_OP_DIAG_MASK_INF          ] = true;
-        p[GGML_OP_DIAG_MASK_ZERO         ] = true;
-        p[GGML_OP_CONV_TRANSPOSE_1D      ] = true;
-        p[GGML_OP_CONV_TRANSPOSE_2D      ] = true;
-        p[GGML_OP_FLASH_ATTN_BACK        ] = true;
-        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
-        p[GGML_OP_ADD_REL_POS            ] = true;
-    }
-
-    {   // FINALIZE
-        bool * p = GGML_OP_HAS_FINALIZE;
-
-        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
-    }
-}
-
 //
 // NUMA support
 //
@@ -2888,7 +2848,7 @@ struct ggml_state {
 static struct ggml_state g_state;
 static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
-// barrier via spin lock
+// critical section via spin lock
 inline static void ggml_critical_section_start(void) {
     while (atomic_flag_test_and_set(&g_state_critical)) {
         // spin
@@ -2896,6 +2856,48 @@ inline static void ggml_critical_section_start(void) {
     }
 }
 
+#ifdef GGML_USE_OPENMP
+static void ggml_barrier(struct ggml_compute_state_shared * shared) {
+    if (shared->n_threads == 1) {
+        return;
+    }
+
+    #pragma omp barrier
+}
+#else
+static void ggml_barrier(struct ggml_compute_state_shared * shared) {
+    if (shared->n_threads == 1) {
+        return;
+    }
+
+    atomic_int * n_barrier = &shared->n_barrier;
+    atomic_int * n_barrier_passed = &shared->n_barrier_passed;
+
+    int n_threads = shared->n_threads;
+    int passed_old = atomic_load(n_barrier_passed);
+
+    if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
+        // last thread
+        atomic_store(n_barrier, 0);
+        atomic_fetch_add(n_barrier_passed, 1);
+    } else {
+        // wait for other threads
+        const int n_spin_before_sleep = 100000;
+        while (true) {
+            for (int i = 0; i < n_spin_before_sleep; i++) {
+                if (atomic_load(n_barrier_passed) != passed_old) {
+                    return;
+                }
+            #if defined(__SSE3__)
+                _mm_pause();
+            #endif
+            }
+            sched_yield();
+        }
+    }
+}
+#endif
+
 // TODO: make this somehow automatically executed
 //       some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
@@ -3000,7 +3002,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
         }
     }
 #else
-    GGML_UNUSED(numa_flag);
+    UNUSED(numa_flag);
     // TODO
 #endif
 }
@@ -3106,9 +3108,7 @@ GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t) {
         enum ggml_unary_op uop = ggml_get_unary_op(t);
         return ggml_unary_op_name(uop);
     }
-    else {
-        return ggml_op_name(t->op);
-    }
+    return ggml_op_name(t->op);
 }
 
 GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor) {
@@ -3375,8 +3375,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
-        ggml_setup_op_has_task_pass();
-
         is_first_call = false;
     }
 
@@ -3643,15 +3641,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.flags        =*/ 0,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
         /*.view_src     =*/ view_src,
         /*.view_offs    =*/ view_offs,
         /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        /*.padding      =*/ { 0 },
+        ///*.padding      =*/ { 0 },
     };
 
 #ifdef __clang__
@@ -7829,10 +7824,6 @@ static void ggml_compute_forward_dup_same_cont(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == dst->type);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const size_t nb00 = src0->nb[0];
     const size_t nb0 = dst->nb[0];
 
@@ -7861,10 +7852,6 @@ static void ggml_compute_forward_dup_f16(
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
@@ -8134,10 +8121,6 @@ static void ggml_compute_forward_dup_bf16(
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
@@ -8494,10 +8477,6 @@ static void ggml_compute_forward_dup_f32(
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
@@ -8817,10 +8796,6 @@ static void ggml_compute_forward_dup_bytes(
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
     GGML_ASSERT(src0->type == dst->type);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
         ggml_compute_forward_dup_same_cont(params, dst);
         return;
@@ -9001,10 +8976,6 @@ static void ggml_compute_forward_add_f32(
 
     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9080,10 +9051,6 @@ static void ggml_compute_forward_add_f16_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9159,10 +9126,6 @@ static void ggml_compute_forward_add_bf16_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9238,10 +9201,6 @@ static void ggml_compute_forward_add_f16_f16(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9294,10 +9253,6 @@ static void ggml_compute_forward_add_bf16_bf16(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9350,10 +9305,6 @@ static void ggml_compute_forward_add_q_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int nr  = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
@@ -9503,10 +9454,6 @@ static void ggml_compute_forward_add1_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9557,10 +9504,6 @@ static void ggml_compute_forward_add1_f16_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = *(float *) src1->data;
 
@@ -9609,10 +9552,6 @@ static void ggml_compute_forward_add1_f16_f16(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
 
@@ -9661,10 +9600,6 @@ static void ggml_compute_forward_add1_q_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = *(float *) src1->data;
 
@@ -9730,10 +9665,6 @@ static void ggml_compute_forward_add1_bf16_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = *(float *) src1->data;
 
@@ -9782,10 +9713,6 @@ static void ggml_compute_forward_add1_bf16_bf16(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
 
@@ -9910,20 +9837,16 @@ static void ggml_compute_forward_acc_f32(
     size_t offset  = ((int32_t *) dst->op_params)[3];
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
-    if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) {
-        if (params->ith != 0) {
-            return;
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
         }
-        // memcpy needs to be synchronized across threads to avoid race conditions.
-        // => do it in INIT phase
-        memcpy(
-            ((char *)  dst->data),
-            ((char *) src0->data),
-            ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
 
     const int ith = params->ith;
@@ -10025,13 +9948,12 @@ static void ggml_compute_forward_sub_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
     const int nr  = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
@@ -10109,9 +10031,6 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -10206,10 +10125,6 @@ static void ggml_compute_forward_div_f32(
 
     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -10298,13 +10213,12 @@ static void ggml_compute_forward_sqr_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_are_same_shape(src0, dst));
+
     const int n     = ggml_nrows(src0);
     const int nc    = src0->ne[0];
 
@@ -10344,13 +10258,12 @@ static void ggml_compute_forward_sqrt_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_are_same_shape(src0, dst));
+
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -10390,13 +10303,12 @@ static void ggml_compute_forward_log_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -10436,13 +10348,13 @@ static void ggml_compute_forward_sum_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_is_scalar(dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_is_scalar(dst));
+
+
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
@@ -10471,13 +10383,12 @@ static void ggml_compute_forward_sum_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_is_scalar(dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_is_scalar(dst));
+
     assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
@@ -10505,13 +10416,12 @@ static void ggml_compute_forward_sum_bf16(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_is_scalar(dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_is_scalar(dst));
+
     assert(src0->nb[0] == sizeof(ggml_bf16_t));
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
@@ -10567,9 +10477,7 @@ static void ggml_compute_forward_sum_rows_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -10622,9 +10530,7 @@ static void ggml_compute_forward_mean_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -10681,9 +10587,7 @@ static void ggml_compute_forward_argmax_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -10731,13 +10635,12 @@ static void ggml_compute_forward_repeat_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
     GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
@@ -10776,13 +10679,12 @@ static void ggml_compute_forward_repeat_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
     GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
@@ -10851,13 +10753,12 @@ static void ggml_compute_forward_repeat_back_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(dst, src0));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
     GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
@@ -10931,10 +10832,6 @@ static void ggml_compute_forward_concat_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -11003,15 +10900,14 @@ static void ggml_compute_forward_abs_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11048,15 +10944,14 @@ static void ggml_compute_forward_sgn_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11093,15 +10988,14 @@ static void ggml_compute_forward_neg_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11138,15 +11032,14 @@ static void ggml_compute_forward_step_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11183,15 +11076,14 @@ static void ggml_compute_forward_tanh_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11228,15 +11120,14 @@ static void ggml_compute_forward_elu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11273,15 +11164,14 @@ static void ggml_compute_forward_relu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11318,15 +11208,14 @@ static void ggml_compute_forward_sigmoid_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11367,10 +11256,6 @@ static void ggml_compute_forward_gelu_f32(
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11430,10 +11315,6 @@ static void ggml_compute_forward_gelu_quick_f32(
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11493,10 +11374,6 @@ static void ggml_compute_forward_silu_f32(
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11551,15 +11428,14 @@ static void ggml_compute_forward_leaky_relu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11609,10 +11485,6 @@ static void ggml_compute_forward_silu_back_f32(
     assert(ggml_are_same_shape(src0, dst));
     assert(ggml_are_same_shape(src0, grad));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11668,15 +11540,14 @@ static void ggml_compute_forward_hardswish_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11710,15 +11581,14 @@ static void ggml_compute_forward_hardsigmoid_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11758,10 +11628,6 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -11833,10 +11699,6 @@ static void ggml_compute_forward_rms_norm_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -11904,10 +11766,6 @@ static void ggml_compute_forward_rms_norm_back_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -12082,10 +11940,6 @@ static void ggml_compute_forward_group_norm_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -12190,8 +12044,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
-    ggml_vec_dot_t    const vec_dot = type_traits[type].vec_dot;
-    enum ggml_type    const vec_dot_type = type_traits[type].vec_dot_type;
+    ggml_vec_dot_t const vec_dot      = type_traits[type].vec_dot;
+    enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
 
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
@@ -12265,15 +12119,11 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst,
-              struct ggml_compute_state * state) {
+              struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -12300,16 +12150,14 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-    UNUSED(r2);
-    UNUSED(r3);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if GGML_USE_LLAMAFILE
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const bool src1_cont = ggml_is_contiguous(src1);
 
     if (src1_cont) {
@@ -12323,7 +12171,6 @@ static void ggml_compute_forward_mul_mat(
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     params->type,
                                      src0->type,
                                      src1->type,
                                      dst->type))
@@ -12333,36 +12180,34 @@ static void ggml_compute_forward_mul_mat(
 UseGgmlGemm1:;
 #endif
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
-        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        atomic_store(&state->shared->current_chunk, nth);
-        if (src1->type != vec_dot_type) {
-            char * wdata = params->wdata;
-            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-            assert(params->wsize >= ne11*ne12*ne13*row_size);
-            GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-            for (int64_t i13 = 0; i13 < ne13; ++i13) {
-                for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
-                    }
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                                          (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                           ne10);
                 }
             }
         }
-
-        return;
     }
 
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
+        atomic_store(&params->shared->current_chunk, nth);
     }
 
+    ggml_barrier(params->shared);
+
 #if GGML_USE_LLAMAFILE
     if (src1->type != vec_dot_type) {
         const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
@@ -12378,7 +12223,6 @@ UseGgmlGemm1:;
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     params->type,
                                      src0->type,
                                      vec_dot_type,
                                      dst->type))
@@ -12388,11 +12232,6 @@ UseGgmlGemm1:;
 UseGgmlGemm2:;
 #endif
 
-#ifdef GGML_PERF
-    int chunks_executed = 0;
-    UNUSED(chunks_executed);
-#endif
-
     // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
     const int64_t nr0 = ne0;
 
@@ -12434,9 +12273,6 @@ UseGgmlGemm2:;
     const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
     const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-    //if (ith == 0)
-    //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d.  Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
-
     // The first chunk comes from our thread_id, the rest will get auto-assigned.
     int current_chunk = ith;
 
@@ -12452,23 +12288,12 @@ UseGgmlGemm2:;
 
         ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
-#ifdef GGML_PERF
-        chunks_executed++;
-#endif
-
         if (nth >= nchunk0 * nchunk1) {
             break;
         }
 
-        current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
+        current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
     }
-
-#ifdef GGML_PERF
-    // These numbers are useful when trying to measure how well the threading scheduling works.
-    //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
-    //float time = (ggml_perf_time_us() - t0);
-    //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
-#endif
 }
 
 // ggml_compute_forward_mul_mat_id
@@ -12520,32 +12345,33 @@ static void ggml_compute_forward_mul_mat_id(
     int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
     struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
-   if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
-        if (src1->type != vec_dot_type) {
-            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-            assert(params->wsize >= ne11*ne12*ne13*row_size);
-            assert(src1->type == GGML_TYPE_F32);
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
 
-            for (int64_t i13 = 0; i13 < ne13; ++i13) {
-                for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
-                    }
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                                          (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                           ne10);
                 }
             }
         }
+    }
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
+    if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
         // group rows by src0 matrix
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
             for (int id = 0; id < n_ids; ++id) {
@@ -12557,13 +12383,9 @@ static void ggml_compute_forward_mul_mat_id(
                 matrix_row_counts[i02] += 1;
             }
         }
-
-        return;
     }
 
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    ggml_barrier(params->shared);
 
     // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
@@ -12662,9 +12484,6 @@ static void ggml_compute_forward_out_prod_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    // int64_t t0 = ggml_perf_time_us();
-    // UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -12689,17 +12508,10 @@ static void ggml_compute_forward_out_prod_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     // dst[:,:,:,:] = 0
     // for i2,i3:
@@ -12775,19 +12587,6 @@ static void ggml_compute_forward_out_prod_f32(
             }
         }
     }
-
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
 }
 
 static void ggml_compute_forward_out_prod_q_f32(
@@ -12797,9 +12596,6 @@ static void ggml_compute_forward_out_prod_q_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    // int64_t t0 = ggml_perf_time_us();
-    // UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
@@ -12830,17 +12626,10 @@ static void ggml_compute_forward_out_prod_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     // parallelize by last three dimensions
 
@@ -12887,19 +12676,6 @@ static void ggml_compute_forward_out_prod_q_f32(
             ggml_vec_mad_f32(ne0, d, wdata, *s1);
         }
     }
-
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
 }
 
 static void ggml_compute_forward_out_prod(
@@ -12959,10 +12735,6 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scale factor
     float v;
     memcpy(&v, dst->op_params, sizeof(float));
@@ -13031,20 +12803,16 @@ static void ggml_compute_forward_set_f32(
     size_t offset  = ((int32_t *) dst->op_params)[3];
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
-    if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) {
-        if (params->ith != 0) {
-            return;
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
         }
-        // memcpy needs to be synchronized across threads to avoid race conditions.
-        // => do it in INIT phase
-        memcpy(
-            ((char *)  dst->data),
-            ((char *) src0->data),
-            ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
 
     const int ith = params->ith;
@@ -13193,10 +12961,6 @@ static void ggml_compute_forward_get_rows_q(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13241,10 +13005,6 @@ static void ggml_compute_forward_get_rows_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13286,10 +13046,6 @@ static void ggml_compute_forward_get_rows_bf16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13331,10 +13087,6 @@ static void ggml_compute_forward_get_rows_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13446,21 +13198,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (params->ith != 0) {
-            return;
-        }
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    memset(dst->data, 0, ggml_nbytes(dst));
 
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
@@ -13485,21 +13231,15 @@ static void ggml_compute_forward_get_rows_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (params->ith != 0) {
-            return;
-        }
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    memset(dst->data, 0, ggml_nbytes(dst));
 
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
@@ -13565,9 +13305,7 @@ static void ggml_compute_forward_diag_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -13636,22 +13374,18 @@ static void ggml_compute_forward_diag_mask_f32(
 
     GGML_ASSERT(n_past >= 0);
 
-    if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) {
-        if (ith != 0) {
-            return;
+    if (!inplace) {
+        if (ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
         }
-        // memcpy needs to be synchronized across threads to avoid race conditions.
-        // => do it in INIT phase
-        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-        memcpy(
-            ((char *)  dst->data),
-            ((char *) src0->data),
-            ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
 
     // TODO: handle transposed/permuted matrices
@@ -13723,10 +13457,6 @@ static void ggml_compute_forward_soft_max_f32(
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -13848,10 +13578,6 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -13940,9 +13666,7 @@ static void ggml_compute_forward_clamp_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -14089,10 +13813,6 @@ static void ggml_compute_forward_rope_f32(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * src2 = dst->src[2];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
@@ -14219,10 +13939,6 @@ static void ggml_compute_forward_rope_f16(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * src2 = dst->src[2];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
@@ -14397,9 +14113,6 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -14410,10 +14123,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         memset(params->wdata, 0, params->wsize);
 
         // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
@@ -14446,13 +14156,8 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
 
         // need to zero dst since we are accumulating into it
         memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 
@@ -14496,9 +14201,6 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -14509,10 +14211,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         memset(params->wdata, 0, params->wsize);
 
         // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
@@ -14545,13 +14244,8 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
 
         // need to zero dst since we are accumulating into it
         memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 
@@ -14620,9 +14314,6 @@ static void ggml_compute_forward_im2col_f32(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
@@ -14653,14 +14344,6 @@ static void ggml_compute_forward_im2col_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
         float * const wdata = (float *) dst->data;
@@ -14708,9 +14391,6 @@ static void ggml_compute_forward_im2col_f16(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F16);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
@@ -14741,14 +14421,6 @@ static void ggml_compute_forward_im2col_f16(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
@@ -14814,9 +14486,6 @@ static void ggml_compute_forward_conv_transpose_2d(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -14827,10 +14496,7 @@ static void ggml_compute_forward_conv_transpose_2d(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         memset(params->wdata, 0, params->wsize);
 
         // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
@@ -14865,13 +14531,8 @@ static void ggml_compute_forward_conv_transpose_2d(
         }
 
         memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     const int32_t stride = ggml_get_op_params_i32(dst, 0);
 
@@ -14919,9 +14580,8 @@ static void ggml_compute_forward_pool_1d_sk_p0(
     const struct ggml_tensor * src = dst->src[0];
 
     assert(src->type == GGML_TYPE_F32);
-    assert(params->ith == 0);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -14988,9 +14648,8 @@ static void ggml_compute_forward_pool_2d(
     const struct ggml_tensor * src = dst->src[0];
 
     GGML_ASSERT(src->type == GGML_TYPE_F32);
-    GGML_ASSERT(params->ith == 0);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -15063,10 +14722,6 @@ static void ggml_compute_forward_upscale_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
     const int ith = params->ith;
@@ -15127,10 +14782,6 @@ static void ggml_compute_forward_pad_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT( dst->nb[0] == sizeof(float));
 
@@ -15187,10 +14838,6 @@ static void ggml_compute_forward_arange_f32(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(dst->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -15229,10 +14876,6 @@ static void ggml_compute_forward_timestep_embedding_f32(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const struct ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15288,10 +14931,6 @@ static void ggml_compute_forward_argsort_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(nb0 == sizeof(float));
@@ -15352,8 +14991,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
 
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
@@ -15398,14 +15035,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t rv2 = neq2/nev2;
     const int64_t rv3 = neq3/nev3;
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // parallelize by q rows using ggml_vec_dot_f32
 
     // total rows in q
@@ -15588,9 +15217,6 @@ static void ggml_compute_forward_flash_attn_back_f32(
     const struct ggml_tensor * v = dst->src[2];
     const struct ggml_tensor * d = dst->src[3];
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
@@ -15637,16 +15263,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith == 0) {
-            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
-        }
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+    if (ith == 0) {
+        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
     }
+    ggml_barrier(params->shared);
 
     const int64_t elem_q = ggml_nelements(q);
     const int64_t elem_k = ggml_nelements(k);
@@ -15926,10 +15546,6 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const struct ggml_tensor * src0 = dst->src[0]; // conv_state
     const struct ggml_tensor * src1 = dst->src[1]; // x
     const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
@@ -16052,10 +15668,6 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const struct ggml_tensor * src0 = dst->src[0]; // s
     const struct ggml_tensor * src1 = dst->src[1]; // x
     const struct ggml_tensor * src2 = dst->src[2]; // dt
@@ -16177,13 +15789,10 @@ static void ggml_compute_forward_ssm_scan(
 static void ggml_compute_forward_win_part_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 
@@ -16243,13 +15852,10 @@ static void ggml_compute_forward_win_part(
 static void ggml_compute_forward_win_unpart_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 
@@ -16375,13 +15981,10 @@ static void ggml_compute_forward_unary(
 static void ggml_compute_forward_get_rel_pos_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
 
     GGML_TENSOR_UNARY_OP_LOCALS
@@ -16431,20 +16034,12 @@ static void ggml_compute_forward_add_rel_pos_f32(
     const struct ggml_tensor * src2 = dst->src[2];
 
     const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
-    if (!inplace && params->type == GGML_TASK_TYPE_INIT) {
-        if (params->ith != 0) {
-            return;
+    if (!inplace) {
+        if (params->ith == 0) {
+            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
         }
-        memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
-        return;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
 
     float * src1_data = (float *) src1->data;
@@ -16518,15 +16113,14 @@ static void ggml_compute_forward_map_unary_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -16566,16 +16160,15 @@ static void ggml_compute_forward_map_binary_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -16615,9 +16208,7 @@ static void ggml_compute_forward_map_custom1_f32(
 
     const struct ggml_tensor * a = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -16634,9 +16225,7 @@ static void ggml_compute_forward_map_custom2_f32(
     const struct ggml_tensor * a = dst->src[0];
     const struct ggml_tensor * b = dst->src[1];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -16654,9 +16243,7 @@ static void ggml_compute_forward_map_custom3_f32(
     const struct ggml_tensor * b = dst->src[1];
     const struct ggml_tensor * c = dst->src[1];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -16671,10 +16258,6 @@ static void ggml_compute_forward_map_custom1(
 
     const struct ggml_tensor * a = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     struct ggml_map_custom1_op_params p;
     memcpy(&p, dst->op_params, sizeof(p));
 
@@ -16690,10 +16273,6 @@ static void ggml_compute_forward_map_custom2(
     const struct ggml_tensor * a = dst->src[0];
     const struct ggml_tensor * b = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     struct ggml_map_custom2_op_params p;
     memcpy(&p, dst->op_params, sizeof(p));
 
@@ -16710,10 +16289,6 @@ static void ggml_compute_forward_map_custom3(
     const struct ggml_tensor * b = dst->src[1];
     const struct ggml_tensor * c = dst->src[2];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     struct ggml_map_custom3_op_params p;
     memcpy(&p, dst->op_params, sizeof(p));
 
@@ -16745,21 +16320,10 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 
     GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith == 0) {
-            memset(sums, 0, sizeof(float) * (nth + nth * nc));
-        }
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        if (ith == 0) {
-            float * dp = (float *) dst->data;
-            ggml_vec_sum_f32(nth, dp, sums);
-            dp[0] *= -1.0f / (float) nr;
-        }
-        return;
+    if (ith == 0) {
+        memset(sums, 0, sizeof(float) * (nth + nth * nc));
     }
+    ggml_barrier(params->shared);
 
     const double eps = 1e-9;
 
@@ -16807,7 +16371,13 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
         }
 #endif
     }
+    ggml_barrier(params->shared);
 
+    if (ith == 0) {
+        float * dp = (float *) dst->data;
+        ggml_vec_sum_f32(nth, dp, sums);
+        dp[0] *= -1.0f / (float) nr;
+    }
 }
 
 static void ggml_compute_forward_cross_entropy_loss(
@@ -16847,10 +16417,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const double eps = 1e-9;
 
     // TODO: handle transposed/permuted matrices
@@ -16921,7 +16487,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17019,7 +16585,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor, state);
+                ggml_compute_forward_mul_mat(params, tensor);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -18497,9 +18063,6 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.leafs        =*/ leafs_ptr,
         /*.hash_table   =*/ { hash_size, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
     };
 
     return cgraph;
@@ -18519,9 +18082,6 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.leafs        =*/ NULL,
         /*.hash_table   =*/ { 0, NULL },
         /*.order        =*/ cgraph0->order,
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
     };
 
     return cgraph;
@@ -18715,16 +18275,7 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }
 static void clear_numa_thread_affinity(void) {}
 #endif
 
-static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
-    int64_t cycles_cur  = ggml_perf_cycles()  - st->perf_node_start_cycles;
-    int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
-
-    node->perf_runs++;
-    node->perf_cycles  += cycles_cur;
-    node->perf_time_us += time_us_cur;
-}
-
-static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
     int n_tasks = 0;
 
     if (ggml_is_empty(node)) {
@@ -18767,8 +18318,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
-                case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
                     {
                         n_tasks = 1;
                     } break;
@@ -18791,33 +18342,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_MUL_MAT:
-            {
-                n_tasks = n_threads;
-
-                // TODO: use different scheduling for different matrix sizes
-                //const int nr0 = ggml_nrows(node->src[0]);
-                //const int nr1 = ggml_nrows(node->src[1]);
-
-                //n_tasks = MIN(n_threads, MAX(1, nr0/128));
-                //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
-            } break;
         case GGML_OP_MUL_MAT_ID:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
             {
-                // FIXME: the cost of launching additional threads decreases performance with GPU offloading
-                //n_tasks = MIN(n_threads, ggml_nelements(node->src[1]));
-                n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1]));
+                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
+                // decreases performance with GPU offloading
+                //n_tasks = n_threads;
+                n_tasks = 1;
             } break;
         case GGML_OP_SCALE:
         case GGML_OP_SET:
@@ -18847,14 +18383,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
             } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_IM2COL:
-            {
-                n_tasks = n_threads;
-            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 n_tasks = n_threads;
@@ -18865,33 +18395,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 n_tasks = 1;
             } break;
         case GGML_OP_UPSCALE:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_PAD:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_ARANGE:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_ARGSORT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             {
@@ -18939,9 +18448,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 n_tasks = n_threads;
@@ -18971,110 +18477,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
     return n_tasks;
 }
 
-#ifdef GGML_USE_OPENMP
-static void ggml_barrier(struct ggml_compute_state * state) {
-    if (state->shared->n_threads == 1) {
-        return;
-    }
-
-    #pragma omp barrier
-}
-#else
-static void ggml_barrier(struct ggml_compute_state * state) {
-    if (state->shared->n_threads == 1) {
-        return;
-    }
-
-    atomic_int * n_barrier = &state->shared->n_barrier;
-    atomic_int * n_barrier_passed = &state->shared->n_barrier_passed;
-
-    int n_threads = state->shared->n_threads;
-    int passed_old = atomic_load(n_barrier_passed);
-
-    if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
-        // last thread
-        atomic_store(n_barrier, 0);
-        atomic_fetch_add(n_barrier_passed, 1);
-    } else {
-        // wait for other threads
-        //while (atomic_load(n_barrier_passed) == passed_old) {
-        //}
-        const int n_spin_before_sleep = 100000;
-        while (true) {
-            for (int i = 0; i < n_spin_before_sleep; i++) {
-                if (atomic_load(n_barrier_passed) != passed_old) {
-                    return;
-                }
-            #if defined(__SSE3__)
-                _mm_pause();
-            #endif
-            }
-            sched_yield();
-        }
-    }
-}
-#endif
-
-static thread_ret_t ggml_graph_compute_thread(void * data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-
-    const struct ggml_cgraph * cgraph = state->shared->cgraph;
-    const struct ggml_cplan  * cplan  = state->shared->cplan;
-
-    const int ith       = state->ith;
-    const int n_threads = state->shared->n_threads;
-
-    set_numa_thread_affinity(ith);
-
-    struct ggml_compute_params params = {
-        /*.type  =*/ GGML_TASK_TYPE_INIT,
-        /*.ith   =*/ ith,
-        /*.nth   =*/ state->shared->n_threads,
-        /*.wsize =*/ cplan->work_size,
-        /*.wdata =*/ cplan->work_data,
-    };
-
-    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
-        if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-            state->ec = GGML_STATUS_ABORTED;
-            return 0;
-        }
-
-        struct ggml_tensor * node = cgraph->nodes[node_n];
-        const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-
-        params.nth = n_tasks;
-
-        /* INIT */
-        if (GGML_OP_HAS_INIT[node->op]) {
-            if (ith < n_tasks) {
-                params.type = GGML_TASK_TYPE_INIT;
-                ggml_compute_forward(&params, node, state);
-            }
-            ggml_barrier(state);
-        }
-
-        /* COMPUTE */
-        if (ith < n_tasks) {
-            params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node, state);
-        }
-
-        ggml_barrier(state);
-
-        /* FINALIZE */
-        if (GGML_OP_HAS_FINALIZE[node->op]) {
-            if (params.ith == 0) {
-                params.type = GGML_TASK_TYPE_FINALIZE;
-                ggml_compute_forward(&params, node, state);
-            }
-            ggml_barrier(state);
-        }
-    }
-
-    return 0;
-}
-
 struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
     if (n_threads <= 0) {
         n_threads = GGML_DEFAULT_N_THREADS;
@@ -19091,7 +18493,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        const int n_tasks = ggml_get_n_tasks(node, n_threads, 1);
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
 
         max_tasks = MAX(max_tasks, n_tasks);
 
@@ -19243,119 +18645,121 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     return cplan;
 }
 
-static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
-    enum ggml_status compute_status = GGML_STATUS_SUCCESS;
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
-#ifdef GGML_USE_OPENMP
-    if (n_threads > 1) {
-        #pragma omp parallel num_threads(n_threads)
-        {
-            #pragma omp single
-            {
-                // update the number of threads from the actual number of threads that we got from OpenMP
-                n_threads = omp_get_num_threads();
-                workers[0].shared->n_threads = n_threads;
-            }
-            ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
-        }
-    } else {
-        ggml_graph_compute_thread(&workers[0]);
-    }
-#else
-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
+    const struct ggml_cgraph * cgraph = state->shared->cgraph;
+    const struct ggml_cplan  * cplan  = state->shared->cplan;
 
-    // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
+    set_numa_thread_affinity(state->ith);
 
-    // join or kill thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; j++) {
-            const int rc = ggml_thread_join(workers[j].thrd, NULL);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
+    struct ggml_compute_params params = {
+        /*.ith   =*/ state->ith,
+        /*.nth   =*/ state->shared->n_threads,
+        /*.wsize =*/ cplan->work_size,
+        /*.wdata =*/ cplan->work_data,
+        /*.shared=*/ state->shared,
+    };
+
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
+
+        ggml_compute_forward(&params, node);
+
+        if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+            state->shared->ec = GGML_STATUS_ABORTED;
         }
-    }
-#endif
-    // don't leave affinity set on the main thread
-    clear_numa_thread_affinity();
 
-    for (int j = 0; j < n_threads; j++) {
-        if (workers[j].ec != GGML_STATUS_SUCCESS) {
-            compute_status = workers[j].ec;
+        ggml_barrier(state->shared);
+
+        if (state->shared->ec != GGML_STATUS_SUCCESS) {
             break;
         }
     }
-    return compute_status;
+
+    return 0;
 }
 
 enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
-    {
-        GGML_ASSERT(cplan);
-        GGML_ASSERT(cplan->n_threads > 0);
-
-        if (cplan->work_size > 0) {
-            GGML_ASSERT(cplan->work_data);
-        }
-    }
+    GGML_ASSERT(cplan);
+    GGML_ASSERT(cplan->n_threads > 0);
+    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
 
     int n_threads = cplan->n_threads;
 
-#if defined(GGML_USE_OPENMP)
-    n_threads = MIN(n_threads, omp_get_max_threads());
-#endif
-
     struct ggml_compute_state_shared state_shared = {
         /*.cgraph                  =*/ cgraph,
         /*.cgraph_plan             =*/ cplan,
-        /*.perf_node_start_cycles  =*/ 0,
-        /*.perf_node_start_time_us =*/ 0,
         /*.n_threads               =*/ n_threads,
         /*.n_barrier               =*/ 0,
         /*.n_barrier_passed        =*/ 0,
         /*.abort_callback          =*/ NULL,
         /*.abort_callback_data     =*/ NULL,
-        /*.current_chunk;          =*/ 0,
+        /*.current_chunk           =*/ 0,
+        /*.ec                      =*/ GGML_STATUS_SUCCESS,
     };
+
+#ifdef GGML_USE_OPENMP
+    if (n_threads > 1) {
+        #pragma omp parallel num_threads(n_threads)
+        {
+            #pragma omp single
+            {
+                // update the number of threads from the actual number of threads that we got from OpenMP
+                n_threads = omp_get_num_threads();
+                state_shared.n_threads = n_threads;
+            }
+
+            struct ggml_compute_state worker = {
+                .thrd   = 0,
+                .ith    = omp_get_thread_num(),
+                .shared = &state_shared,
+            };
+            ggml_graph_compute_thread(&worker);
+        }
+    } else {
+        struct ggml_compute_state worker = {
+            .thrd   = 0,
+            .ith    = 0,
+            .shared = &state_shared,
+        };
+        ggml_graph_compute_thread(&worker);
+    }
+#else
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
-    const int64_t perf_start_cycles  = ggml_perf_cycles();
-    const int64_t perf_start_time_us = ggml_perf_time_us();
 
     for (int j = 0; j < n_threads; ++j) {
         workers[j] = (struct ggml_compute_state) {
             .thrd   = 0,
             .ith    = j,
             .shared = &state_shared,
-            .ec     = GGML_STATUS_SUCCESS,
         };
     }
 
-    enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
-
-    // performance stats (graph)
-    {
-        int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_start_cycles;
-        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+    // create thread pool
+    for (int j = 1; j < n_threads; ++j) {
+        const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
 
-        cgraph->perf_runs++;
-        cgraph->perf_cycles  += perf_cycles_cur;
-        cgraph->perf_time_us += perf_time_us_cur;
+    // this is a work thread too
+    ggml_graph_compute_thread(&workers[0]);
 
-        GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
-                __func__, cgraph->perf_runs,
-                (double) perf_cycles_cur      / (double) ggml_cycles_per_ms(),
-                (double) cgraph->perf_cycles  / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
-                (double) perf_time_us_cur     / 1000.0,
-                (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+    // join or kill thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; j++) {
+            const int rc = ggml_thread_join(workers[j].thrd, NULL);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
     }
+#endif
+
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
 
-    return compute_status;
+    return state_shared.ec;
 }
 
 enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
@@ -19854,24 +19258,16 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
 }
 
 void ggml_graph_print(const struct ggml_cgraph * cgraph) {
-    int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
-
     GGML_PRINT("=== GRAPH ===\n");
 
     GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
-
-        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs,
-                (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
-                (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
-                (double) node->perf_time_us / 1000.0,
-                (double) node->perf_time_us / 1000.0 / node->perf_runs);
+                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
     }
 
     GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
@@ -19885,14 +19281,6 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
                 ggml_get_name(node));
     }
 
-    for (int i = 0; i < GGML_OP_COUNT; i++) {
-        if (perf_total_per_op_us[i] == 0) {
-            continue;
-        }
-
-        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0);
-    }
-
     GGML_PRINT("========================================\n");
 }
 
diff --git a/ggml.h b/ggml.h
index 2e8fd0dbc2e315409d71b0320ff5d16d1d288b7d..d895c9acdb59631b985adf3aaf8f3e8d217e8d65 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -591,11 +591,7 @@ extern "C" {
         struct ggml_tensor * grad;
         struct ggml_tensor * src[GGML_MAX_SRC];
 
-        // performance
-        int     perf_runs;
-        int64_t perf_cycles;
-        int64_t perf_time_us;
-
+        // source tensor and offset for views
         struct ggml_tensor * view_src;
         size_t               view_offs;
 
@@ -605,7 +601,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[8];
+        // char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -652,11 +648,6 @@ extern "C" {
         struct ggml_hash_set visited_hash_table;
 
         enum ggml_cgraph_eval_order order;
-
-        // performance
-        int     perf_runs;
-        int64_t perf_cycles;
-        int64_t perf_time_us;
     };
 
     // scratch buffer
@@ -673,28 +664,6 @@ extern "C" {
         bool   no_alloc;   // don't allocate memory for the tensor data
     };
 
-
-    // compute types
-
-    // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
-    // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
-    enum ggml_task_type {
-        GGML_TASK_TYPE_INIT = 0,
-        GGML_TASK_TYPE_COMPUTE,
-        GGML_TASK_TYPE_FINALIZE,
-    };
-
-    struct ggml_compute_params {
-        enum ggml_task_type type;
-
-        // ith = thread index, nth = number of threads
-        int ith, nth;
-
-        // work buffer for all threads
-        size_t wsize;
-        void * wdata;
-    };
-
     // numa strategies
     enum ggml_numa_strategy {
         GGML_NUMA_STRATEGY_DISABLED   = 0,
index c710ef82b746e0c749df67de33e8a0816f9c08a8..49bc93c028a2a66ef34498a0c80ee4400a9bd934 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -12785,12 +12785,6 @@ static int llama_decode_internal(
             }
         }
 
-#ifdef GGML_PERF
-        // print timing information per ggml operation (for debugging purposes)
-        // requires GGML_PERF to be defined
-        ggml_graph_print(gf);
-#endif
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
index bbe263ddd2bb4070f2b311c178ca861d2d44d22d..6626ceb26213f230558cfaae9d26656c6e00b5dc 100644 (file)
--- a/sgemm.cpp
+++ b/sgemm.cpp
@@ -249,9 +249,8 @@ class tinyBLAS {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -458,9 +457,8 @@ class tinyBLAS_Q0_ARM {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -596,9 +594,8 @@ class tinyBLAS_Q0_AVX {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -829,7 +826,7 @@ class tinyBLAS_Q0_AVX {
  * For example, for single-threaded single-precision GEMM you can say
  *
  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- *                     0, 1, GGML_TASK_TYPE_COMPUTE,
+ *                     0, 1,
  *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
  *
  * @param m is rows in `A` and `C`
@@ -843,14 +840,13 @@ class tinyBLAS_Q0_AVX {
  * @param ldc is row stride of `C`
  * @param ith is thread id (must be less than `nth`)
  * @param nth is number of threads (must be greater than zero)
- * @param task is GGML task type
  * @param Atype is GGML data type of `A`
  * @param Btype is GGML data type of `B`
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
 bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -877,7 +873,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__AVX__) || defined(__AVX2__)
         if (k % 8)
@@ -887,7 +883,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_NEON)
         if (n < 4)
@@ -899,7 +895,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -917,7 +913,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (k % 8)
@@ -929,7 +925,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
@@ -943,7 +939,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const ggml_fp16_t *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (k % 4)
@@ -955,7 +951,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -971,7 +967,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM<block_q8_0> tb{
@@ -979,7 +975,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -995,7 +991,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM<block_q4_0> tb{
@@ -1003,7 +999,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -1025,7 +1021,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldc;
     (void)ith;
     (void)nth;
-    (void)task;
     (void)Atype;
     (void)Btype;
     (void)Ctype;
diff --git a/sgemm.h b/sgemm.h
index f29747d0a477af4ef5fb24f3877ac30ffd0ccae8..caf6dd5567b3ad107d414be516a8a11cc7745a6b 100644 (file)
--- a/sgemm.h
+++ b/sgemm.h
@@ -7,7 +7,7 @@ extern "C" {
 
 bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
                      const void *, int64_t, void *, int64_t, int, int,
-                     int, int, int, int);
+                     int, int, int);
 
 #ifdef __cplusplus
 }