]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : consolidate bin kernels (#19390)
authorGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:35:56 +0000 (10:35 +0200)
committerGitHub <redacted>
Sat, 7 Feb 2026 08:35:56 +0000 (10:35 +0200)
* metal : refactor bin kernels

* cont

* cont : fix cv

ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 6af0dd88d55ef1571a3bdb406c552e626751fecf..4c4c3ce36c4605bf8319264112b0484a8a192cd4 100644 (file)
@@ -1392,34 +1392,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
     GGML_UNUSED(op);
 }
 
-ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
-        ggml_metal_library_t lib,
-        ggml_op op,
-        int32_t n_fuse,
-        bool row) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
-    switch (op) {
-        case GGML_OP_ADD:   op_str = "add";   break;
-        case GGML_OP_SUB:   op_str = "sub";   break;
-        case GGML_OP_MUL:   op_str = "mul";   break;
-        case GGML_OP_DIV:   op_str = "div";   break;
+    int op_num = -1;
+
+    switch (op->op) {
+        case GGML_OP_ADD: op_num = 0; break;
+        case GGML_OP_SUB: op_num = 1; break;
+        case GGML_OP_MUL: op_num = 2; break;
+        case GGML_OP_DIV: op_num = 3; break;
         default: GGML_ABORT("fatal error");
     };
 
-    if (row) {
-        snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
-    } else {
-        snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t1_str = ggml_type_name(op->src[1]->type);
+    const char * t_str  = ggml_type_name(op->type);
+
+    const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
+
+    const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
+
+    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+        ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
+        ggml_metal_cv_set_bool (cv, is_rb,  FC_BIN + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
-    snprintf(name, 256, "%s", base);
+    res.c4  = is_c4;
+    res.cnt = is_rb;
+
+    return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
+    char base[256];
+    char name[256];
+
+    int op_num = -1;
+
+    switch (op) {
+        case GGML_OP_ADD: op_num = 0; break;
+        case GGML_OP_SUB: op_num = 1; break;
+        case GGML_OP_MUL: op_num = 2; break;
+        case GGML_OP_DIV: op_num = 3; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
+    snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+        ggml_metal_cv_set_int16(cv, 1,      FC_BIN + 1);
+        ggml_metal_cv_set_bool (cv, false,  FC_BIN + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     return res;
index 84dcec308302e6ae9bccd16c9ab9f27f7110f892..93d7f6a216f976bf7f251ffde0add5dbf338e2a2 100644 (file)
@@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
     int nr1;
 
     size_t smem;
+
+    bool c4;
+    bool cnt;
 };
 
 int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
@@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one           (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
index c8e737d418714b17814c60fc786e2e95f54778c3..891d70c85a452e0092dde44518ec4b0d48b01b4c 100644 (file)
@@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
 
     struct ggml_metal_pipeline_with_params res = {
         /*.pipeline =*/ nil,
+        /*.nsg      =*/ 0,
         /*.nr0      =*/ 0,
         /*.nr1      =*/ 0,
-        /*.nsg      =*/ 0,
         /*.smem     =*/ 0,
+        /*.c4       =*/ false,
+        /*.cnt      =*/ false,
     };
 
     res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
@@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
 struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
     struct ggml_metal_pipeline_with_params res = {
         /*.pipeline =*/ nil,
+        /*.nsg      =*/ 0,
         /*.nr0      =*/ 0,
         /*.nr1      =*/ 0,
-        /*.nsg      =*/ 0,
         /*.smem     =*/ 0,
+        /*.c4       =*/ false,
+        /*.cnt      =*/ false,
     };
 
     [lib->lock lock];
@@ -1054,7 +1058,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_ADD_ID:
-            return op->src[0]->type == GGML_TYPE_F32;
+            return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
index 7f73cb97bbb48df3f6e761e47710e560d2bc9136..77bb403c15d2fce0579265b06820be613272cab7 100644 (file)
@@ -80,6 +80,7 @@
 #define FC_SSM_CONV                    900
 #define FC_SOLVE_TRI                   1000
 #define FC_COUNT_EQUAL                 1100
+#define FC_BIN                         1200
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
index e0ed6c7805cea280cc6c39b37d36bac1a0f81fb6..dbf25433c259922d4530426b6097f8e335628eb3 100644 (file)
@@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.o1   =*/ { 0 },
     };
 
-    auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
+    auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
 
-    bool bcast_row = false;
-
     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
     ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
@@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
 
     struct ggml_metal_pipeline_with_params pipeline;
 
-    if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-        GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
-        // src1 is a row
-        GGML_ASSERT(ne11 == 1);
-
-        pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
-
-        bcast_row = true;
-    } else {
-        pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
-    }
+    pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
 
     if (n_fuse > 1) {
         bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@@ -3015,20 +3002,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
         }
     }
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne10 = ne10/4;
+        args.ne0  = ne0/4;
+    }
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
     ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
     ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);
 
-    if (bcast_row) {
-        const int64_t n = ggml_nelements(op)/4;
+    if (pipeline.cnt) {
+        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
 
         ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
     } else {
-        int nth = 32;
+        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+        int nth = 1;
 
-        while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        while (2*nth < args.ne0 && nth < nth_max) {
             nth *= 2;
         }
 
index 612a42a1ea86b69c495d8693209b9e32a596c29c..35cc3bbdfdf1dfabff61affed44d2424301d55c9 100644 (file)
@@ -895,11 +895,13 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-template <int F>
-kernel void kernel_add_fuse_impl(
+// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
+constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
+constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
+constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];
+
+template <typename T0, typename T1, typename T>
+kernel void kernel_bin_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -907,139 +909,153 @@ kernel void kernel_add_fuse_impl(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+#define FC_OP FC_bin_op
+#define FC_F  FC_bin_f
+#define FC_RB FC_bin_rb
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+    if (FC_RB) {
+        // row broadcast
+        const uint i0 = tgpig.x;
+        const uint i1 = i0%args.ne10;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
-    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+        device const T0 * src0_row = (device const T0 *) (src0);
+        device       T  * dst_row  = (device       T  *) (dst);
 
-    device const float * src1_ptr[F];
-    for (short j = 0; j < F; ++j) {
-        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
-    }
+        if (FC_F == 1) {
+            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
+            if (FC_OP == 0) {
+                dst_row[i0] = src0_row[i0] + src1_row[i1];
+            }
+
+            if (FC_OP == 1) {
+                dst_row[i0] = src0_row[i0] - src1_row[i1];
+            }
+
+            if (FC_OP == 2) {
+                dst_row[i0] = src0_row[i0] * src1_row[i1];
+            }
+
+            if (FC_OP == 3) {
+                dst_row[i0] = src0_row[i0] / src1_row[i1];
+            }
+        } else {
+            T0 res = src0_row[i0];
+
+            if (FC_OP == 0) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
+
+            if (FC_OP == 1) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-        float res = src0_ptr[i0];
+            if (FC_OP == 2) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
+
+            if (FC_OP == 3) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-#pragma unroll
-        for (short j = 0; j < F; ++j) {
-            res += src1_ptr[j][i10];
+            dst_row[i0] = res;
         }
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int i01 = tgpig.x;
 
-        dst_ptr[i0] = res;
-    }
-}
+        if (i01 >= args.ne01) {
+            return;
+        }
 
-typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+        const int i13 = i03%args.ne13;
+        const int i12 = i02%args.ne12;
+        const int i11 = i01%args.ne11;
 
-template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
-template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
-template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
-template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
-template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
-template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
-template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
-template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+        device       T  * dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
 
-kernel void kernel_sub_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+        if (FC_F == 1) {
+            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = i0%args.ne10;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+                if (FC_OP == 0) {
+                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
+                }
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
-    }
-}
+                if (FC_OP == 1) {
+                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
+                }
 
-kernel void kernel_mul_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+                if (FC_OP == 2) {
+                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
+                }
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+                if (FC_OP == 3) {
+                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
+                }
+            }
+        } else {
+            device const T1 * src1_ptr[8];
+            FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+            }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = i0%args.ne10;
 
-    if (args.ne10 == 1) {
-        const float x = *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
-        }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
-        }
-    }
-}
+                T res = src0_ptr[i0];
 
-kernel void kernel_div_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+                if (FC_OP == 0) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res += src1_ptr[j][i10];
+                    }
+                }
+
+                if (FC_OP == 1) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res -= src1_ptr[j][i10];
+                    }
+                }
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+                if (FC_OP == 2) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res *= src1_ptr[j][i10];
+                    }
+                }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+                if (FC_OP == 3) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res /= src1_ptr[j][i10];
+                    }
+                }
 
-    if (args.ne10 == 1) {
-        const float x = 1.0f / *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
-        }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+                dst_ptr[i0] = res;
+            }
         }
     }
+
+#undef FC_OP
+#undef FC_F
+#undef FC_RB
 }
 
+typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
+
+template [[host_name("kernel_bin_fuse_f32_f32_f32")]]   kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float,  float,  float>;
+template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
+
 kernel void kernel_add_id(
         constant ggml_metal_kargs_add_id & args,
         device const char * src0,
@@ -1057,7 +1073,7 @@ kernel void kernel_add_id(
     const size_t nb1 = args.ne0 * sizeof(float);
     const size_t nb2 = args.ne1 * nb1;
 
-    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
+    device       float * dst_row  = (device       float *)((device char *)dst  +  i1*nb1       + i2*nb2);
     device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
     device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
 
@@ -1098,141 +1114,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
 template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
 template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
 
-// assumption: src1 is a row
-// broadcast src1 into src0
-template <short F>
-kernel void kernel_add_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res += ((device const float4 *) (src1 + args.o1[j]))[i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
-
-template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
-template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
-template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
-template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
-template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
-template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
-template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
-template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
-
-template <short F>
-kernel void kernel_sub_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res -= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
-
-template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
-
-template <short F>
-kernel void kernel_mul_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res *= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
-
-template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
-
-template <short F>
-kernel void kernel_div_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res /= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
-
-template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
-
 kernel void kernel_scale_f32(
         constant ggml_metal_kargs_scale & args,
         device const float * src0,