]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fuse add, mul + add tests (#14596)
authorGeorgi Gerganov <redacted>
Fri, 18 Jul 2025 17:37:26 +0000 (20:37 +0300)
committerGitHub <redacted>
Fri, 18 Jul 2025 17:37:26 +0000 (20:37 +0300)
ggml-ci

ggml/src/ggml-alloc.c
ggml/src/ggml-backend.cpp
ggml/src/ggml-impl.h
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
src/llama-graph.cpp
tests/test-backend-ops.cpp

index 5fd379f6a9461ec0b1a6913d1fef5f52b683efc9..fcc552da519b135ab465dfcdd9a0aebe6a365e19 100644 (file)
@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
index 788861a365fab65e234d469f3730b7bf970d6b7e..b7498b8d40238f44ed6bbbfc1134b5f5663b10ca 100644 (file)
@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
 
 // backend copy
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 
index 4972558c98b81b46865cf2d708469a1243181863..a2e30994c4669fbfd60e1230df983dc43b63f913 100644 (file)
@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
index 752d55c216604a9788275a5d9e20def12be9e15b..b7b3fc49af35db9e9d57ed9d27838d197f3cfb82 100644 (file)
@@ -126,6 +126,7 @@ typedef struct {
     uint64_t nb2;
     uint64_t nb3;
     uint64_t offs;
+    uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
 typedef struct {
@@ -240,7 +241,7 @@ typedef struct {
     float    max_bias;
     float    m0;
     float    m1;
-    uint16_t n_head_log2;
+    int32_t  n_head_log2;
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
@@ -377,8 +378,16 @@ typedef struct {
 typedef struct {
     int32_t  ne00;
     int32_t  ne00_4;
-    uint64_t nb01;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
+    int32_t  nef1[3];
+    int32_t  nef2[3];
+    int32_t  nef3[3];
+    uint64_t nbf1[3];
+    uint64_t nbf2[3];
+    uint64_t nbf3[3];
 } ggml_metal_kargs_rms_norm;
 
 typedef struct {
@@ -484,7 +493,7 @@ typedef struct {
     float    max_bias;
     float    m0;
     float    m1;
-    uint32_t n_head_log2;
+    int32_t  n_head_log2;
 } ggml_metal_kargs_soft_max;
 
 typedef struct {
index 44ddc69d08f1ce752e2662cdfa351cac72c7c118..dc391a0d4d5499fcc53ac9204ecb55da535bf44c 100644 (file)
@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
+    bool use_fusion;
+
+    int debug_fusion;
+
+    // how many times a given op was fused
+    uint64_t fuse_cnt[GGML_OP_COUNT];
 
     size_t max_size;
 
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
     /*.has_residency_sets      =*/ false,
     /*.has_bfloat              =*/ false,
     /*.use_bfloat              =*/ false,
+    /*.use_fusion              =*/ true,
+    /*.debug_fusion            =*/ 0,
+    /*.fuse_cnt                =*/ { 0 },
     /*.max_size                =*/ 0,
     /*.name                    =*/ "",
 };
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
-    }
 
-    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
 #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
-        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
 #endif
 
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
         ctx->use_bfloat = false;
 #endif
+        ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+        {
+            const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+            ctx->debug_fusion = val ? atoi(val) : 0;
+        }
+
+        memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
         ctx->max_size = ctx->mtl_device.maxBufferLength;
 
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->debug_fusion > 0) {
+            fprintf(stderr, "%s: fusion stats:\n", __func__);
+            for (int i = 0; i < GGML_OP_COUNT; i++) {
+                if (ctx->fuse_cnt[i] == 0) {
+                    continue;
+                }
+
+                // note: cannot use ggml_log here
+                fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+            }
+        }
+
         if (ctx->mtl_lock) {
             [ctx->mtl_lock release];
             ctx->mtl_lock = nil;
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
     GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -218,6 +259,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
@@ -1135,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                             add,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                         add_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,                      add_fuse_2,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,                      add_fuse_3,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,                      add_fuse_4,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,                      add_fuse_5,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,                      add_fuse_6,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,                      add_fuse_7,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,                      add_fuse_8,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,                      add_row_c4,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,               add_row_c4_fuse_2,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,               add_row_c4_fuse_3,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,               add_row_c4_fuse_4,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,               add_row_c4_fuse_5,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,               add_row_c4_fuse_6,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,               add_row_c4_fuse_7,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,               add_row_c4_fuse_8,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                             sub,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                         sub_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,                      sub_row_c4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                             mul,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                         mul_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,                      mul_row_c4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                             div,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                         div_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,                      div_row_c4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                      repeat_f32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                      repeat_f16,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                      repeat_i32,                      true);
@@ -1206,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,                    rms_norm_mul,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,                rms_norm_mul_add,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
@@ -1893,7 +1952,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
                         ggml_backend_t   backend,
                                    int   idx,
           id<MTLComputeCommandEncoder>   encoder,
@@ -1903,7 +1962,10 @@ static bool ggml_metal_encode_node(
 
     struct ggml_cgraph * gf = ctx->gf;
 
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
@@ -1913,7 +1975,7 @@ static bool ggml_metal_encode_node(
     struct ggml_tensor * dst  = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {
@@ -1924,7 +1986,7 @@ static bool ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
@@ -1991,6 +2053,8 @@ static bool ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2062,37 +2126,15 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(src0t == GGML_TYPE_F32);
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
                 const size_t offs = 0;
 
                 bool bcast_row = false;
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                 ggml_metal_kargs_bin args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
@@ -2119,12 +2161,117 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ nb2,
                     /*.nb3  =*/ nb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
+                // c[0] = add(a,    b[0])
+                // c[1] = add(c[0], b[1])
+                // c[2] = add(c[1], b[2])
+                // ...
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id<MTLBuffer> id_fuse;
+
+                    for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        // b[0] === b[1] === ...
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                    }
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0         atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 if (bcast_row) {
@@ -2132,7 +2279,11 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } else {
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+                    int nth = 32;
+
+                    while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nth *= 2;
+                    }
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
@@ -2257,12 +2408,13 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ pnb2,
                     /*.nb3  =*/ pnb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1},
                 };
 
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0         atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2764,7 +2916,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                 if (!h_src0) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                 }
 
                 offs_src0 = 0;
@@ -3640,7 +3792,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                     if (!h_src1) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                        return false;
+                        return 0;
                     }
 
                     const int64_t neh0 = ne0;
@@ -3656,7 +3808,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                     if (!h_dst) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                        return false;
+                        return 0;
                     }
 
                     // tokens per expert
@@ -3664,7 +3816,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                     if (!h_tpe) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                        return false;
+                        return 0;
                     }
 
                     // id map
@@ -3673,7 +3825,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                     if (!h_ids) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                        return false;
+                        return 0;
                     }
 
                     {
@@ -4105,12 +4257,95 @@ static bool ggml_metal_encode_node(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb1    =*/ nb1,
+                    /*.nb2    =*/ nb2,
+                    /*.nb3    =*/ nb3,
+                    /*.eps    =*/ eps,
+                    /*.nef1   =*/ { ne01 },
+                    /*.nef2   =*/ { ne02 },
+                    /*.nef3   =*/ { ne03 },
+                    /*.nbf1   =*/ { nb01 },
+                    /*.nbf2   =*/ { nb02 },
+                    /*.nbf3   =*/ { nb03 },
+                };
+
+                size_t offs_fuse[2] = { 0, 0 };
+                id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+                // d[0] = rms_norm(a)
+                // d[1] = mul(d[0], b)
+                // d[2] = add(d[1], c)
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_RMS_NORM;
+                    ops[1] = GGML_OP_MUL;
+                    ops[2] = GGML_OP_ADD;
+
+                    for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+                            break;
+                        }
+
+                        if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+                        args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+                        args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+                        args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+                        args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+                        args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+                        args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        if (n_fuse == 2) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+                        }
+                        if (n_fuse == 3) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+                        }
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (n_fuse) {
+                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break;
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+                    default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+                }
 
                 int nth = 32; // SIMD width
 
@@ -4121,23 +4356,16 @@ static bool ggml_metal_encode_node(
                 nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
-                ggml_metal_kargs_rms_norm args = {
-                    /*.ne00   =*/ ne00,
-                    /*.ne00_4 =*/ ne00/4,
-                    /*.nb01   =*/ nb01,
-                    /*.eps    =*/ eps,
-                };
-
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)       atIndex:0];
+                [encoder setBuffer:id_src0    offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+                [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+                [encoder setBuffer:id_dst     offset:offs_dst     atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_L2_NORM:
             {
@@ -5532,7 +5760,7 @@ static bool ggml_metal_encode_node(
             }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -6038,20 +6266,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
         ggml_metal_mem_pool_reset(mem_pool);
 
-        for (int idx = node_start; idx < node_end; ++idx) {
+        for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
 
-            if (!res) {
+            if (res == 0) {
                 break;
             }
+
+            idx += res;
         }
 
         [encoder endEncoding];
index 13235e28852414436929a79611b78933a0e0daaf..f62b9ad548e695ce00a8c8aa0681f59e59efc37f 100644 (file)
@@ -832,7 +832,8 @@ enum ggml_sort_order {
 // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
-kernel void kernel_add(
+template <int F>
+kernel void kernel_add_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -848,16 +849,39 @@ kernel void kernel_add(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+
+    device const float * src1_ptr[F];
+    for (short j = 0; j < F; ++j) {
+        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+    }
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+
+        float res = src0_ptr[i0];
+
+#pragma unroll
+        for (short j = 0; j < F; ++j) {
+            res += src1_ptr[j][i10];
+        }
+
+        dst_ptr[i0] = res;
     }
 }
 
+typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+
+template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
+template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
+template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
+template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
+template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
+template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
+template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+
 kernel void kernel_sub(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
@@ -875,7 +899,7 @@ kernel void kernel_sub(
     const int i11 = i01%args.ne11;
 
     device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -900,9 +924,9 @@ kernel void kernel_mul(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -926,9 +950,9 @@ kernel void kernel_div(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -970,46 +994,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
 
 // assumption: src1 is a row
 // broadcast src1 into src0
-kernel void kernel_add_row(
+template <short F>
+kernel void kernel_add_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res += src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_sub_row(
+typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+
+template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
+template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
+template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
+template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
+template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
+template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
+template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+
+template <short F>
+kernel void kernel_sub_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res -= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_mul_row(
+typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+
+template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_mul_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res *= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_div_row(
+typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+
+template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_div_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res /= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
+typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
+
+template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+
 kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
@@ -2116,26 +2239,39 @@ kernel void kernel_norm(
     }
 }
 
-kernel void kernel_rms_norm(
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_fuse_impl(
         constant ggml_metal_kargs_rms_norm & args,
         device const char * src0,
+        device const char * src1_0,
+        device const char * src1_1,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2154,12 +2290,26 @@ kernel void kernel_rms_norm(
     const float mean  = sumf/args.ne00;
     const float scale = 1.0f/sqrt(mean + args.eps);
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        y[i00] = x[i00] * scale;
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        if (F == 1) {
+            y[i00] = (x[i00]*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x[i00]*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+        }
     }
 }
 
+typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_mul")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+
 kernel void kernel_l2_norm(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
index eff0d563c831b4a60b735e94f8847659ac1039cf..b63a41053b488b23025b73df495b1f2f2e8c43d3 100644 (file)
@@ -907,20 +907,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
 
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
     // aggregate experts
     // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
     //       to avoid potentially a large number of add nodes during warmup
     //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
-    ggml_tensor * moe_out = nullptr;
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
+    ggml_tensor * moe_out = cur_experts[0];
 
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }
 
     if (hparams.n_expert_used == 1) {
index a3d68fba046cf5e207e45e12662f518c61be5a5d..bc732df5bb454c417ac0c79e8e32b28f40fe7ddf 100644 (file)
@@ -2353,9 +2353,12 @@ struct test_bin_bcast : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int, 4> nr;
+    int nf; // number of fused ops, nf == 1 -> single op (no fusion)
+
+    bool run_whole_graph() override { return true; }
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, nr);
+        return VARS_TO_STR4(type, ne, nr, nf);
     }
 
     size_t op_size(ggml_tensor * t) override {
@@ -2364,24 +2367,35 @@ struct test_bin_bcast : public test_case {
 
     test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 1, 1},
-            std::array<int, 4> nr = {1, 2, 1, 1})
-        : op(op), type(type), ne(ne), nr(nr) {}
+            std::array<int, 4> nr = {1, 2, 1, 1},
+            int nf = 1)
+        : op(op), type(type), ne(ne), nr(nr), nf(nf) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        GGML_ASSERT(nf <= 8);
+
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
         ggml_set_name(a, "a");
 
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_name(b, "b");
+        ggml_tensor * b[8];
+        for (int i = 0; i < nf; ++i) {
+            b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
+        }
 
         // The backward pass supports broadcasting only for GGML_ADD:
-        const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
         if (grad_supported) {
             ggml_set_param(a);
-            ggml_set_param(b);
+            ggml_set_param(b[0]);
+        }
+
+        ggml_tensor * out = a;
+
+        for (int i = 0; i < nf; ++i) {
+            out = op(ctx, out, b[i]);
         }
 
-        ggml_tensor * out = op(ctx, a, b);
         ggml_set_name(out, "out");
 
         return out;
@@ -2622,15 +2636,15 @@ struct test_rms_norm_back : public test_case {
     }
 };
 
-// GGML_OP_RMS_NORM + GGML_OP_MUL
-struct test_rms_norm_mul : public test_case {
+// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ADD
+struct test_rms_norm_mul_add : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const float eps;
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
-        return "RMS_NORM_MUL";
+        return "RMS_NORM_MUL_ADD";
     }
 
     bool run_whole_graph() override { return true; }
@@ -2639,7 +2653,7 @@ struct test_rms_norm_mul : public test_case {
         return VARS_TO_STR3(type, ne, eps);
     }
 
-    test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
+    test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
             float eps = 1e-6f)
         : type(type), ne(ne), eps(eps) {}
@@ -2647,14 +2661,17 @@ struct test_rms_norm_mul : public test_case {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
         ggml_set_param(b);
         ggml_set_name(b, "b");
+        ggml_set_param(c);
+        ggml_set_name(c, "c");
 
-        // Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
-        a = ggml_add(ctx, a, b);
-        ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+        // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
+        a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+        ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
         ggml_set_name(out, "out");
 
         return out;
@@ -5151,6 +5168,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
     }
 
+    // fusion
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}, 4));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 2}, 5));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}, 6));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}, 7));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));
+
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
@@ -5165,7 +5191,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
-        test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));