]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : fuse NORM + MUL + ADD, support non-multiples of 4 (llama/16220)
authorGeorgi Gerganov <redacted>
Thu, 25 Sep 2025 08:30:16 +0000 (11:30 +0300)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:10 +0000 (15:18 +0300)
* metal : fuse NORM + MUL + ADD

* metal : support norms of non-multiple of 4

* cont : fix comment [no ci]

ggml/src/ggml-metal/ggml-metal-common.cpp
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.metal

index 3ce7d75aad6c977ad24dee99955ef6f75581b03e..dc7d241c3ae4812167916f3c39daf6a04db89b3a 100644 (file)
@@ -383,6 +383,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
         // fuse only ops that start with these operations
         // can be expanded when needed
         if (node.op() == GGML_OP_ADD ||
+            node.op() == GGML_OP_NORM ||
             node.op() == GGML_OP_RMS_NORM) {
             ops[0] = node.op();
 
@@ -392,6 +393,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
                 // can be expanded when needed
                 if (gf->nodes[f]->op != GGML_OP_ADD &&
                     gf->nodes[f]->op != GGML_OP_MUL &&
+                    gf->nodes[f]->op != GGML_OP_NORM &&
                     gf->nodes[f]->op != GGML_OP_RMS_NORM) {
                     break;
                 }
index 8aeefd2c6830f14eea7d42f51815dd052dc6348d..03be2c01a717d888fdf5b4efc1ed7c62435a1f77 100644 (file)
@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
-    assert(op->op == GGML_OP_RMS_NORM);
-
-    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
-    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
-
-    char base[256];
-    char name[256];
-
-    switch (n_fuse) {
-        case 1: snprintf(base, 256, "kernel_rms_norm_f32");         break;
-        case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32");     break;
-        case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
-        default: GGML_ABORT("fatal error");
-    }
-
-    snprintf(name, 256, "%s", base);
-
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
-
-    return res;
-}
-
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_L2_NORM);
 
@@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
-    assert(op->op == GGML_OP_NORM);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
+    assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
 
-    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
-    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_norm_f32");
+    const char * suffix = "";
+    if (op->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    switch (op->op) {
+        case GGML_OP_NORM:
+            switch (n_fuse) {
+                case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix);         break;
+                case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix);     break;
+                case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
+                default: GGML_ABORT("fatal error");
+            } break;
+        case GGML_OP_RMS_NORM:
+            switch (n_fuse) {
+                case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix);         break;
+                case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix);     break;
+                case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
+                default: GGML_ABORT("fatal error");
+            } break;
+        default: GGML_ABORT("fatal error");
+    }
+
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
index da67bfab758e894e58e8bab6fd7f92930b5f84e6..dda7eca85a7e8df6a39948f99d8a3ce98cfe2b3b 100644 (file)
@@ -123,10 +123,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id         (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm          (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
index 67f71ace2b44499c591f77ac8a6f7c247fe76693..5f744d1a0bd960719eeb359c5aa46d0923f89915 100644 (file)
@@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
-        case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
             return has_simdgroup_reduction;
         case GGML_OP_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+        case GGML_OP_RMS_NORM:
+            return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
         case GGML_OP_ROPE:
             return true;
         case GGML_OP_IM2COL:
index 3a7a4f317c638aa54ddf6e013742650842f4716d..ab51b76c8cd36fc1fe5a158befad0ed8ce58ef94 100644 (file)
@@ -428,16 +428,11 @@ typedef struct {
     uint64_t nb1;
 } ggml_metal_kargs_mul_mv_id;
 
+// NORM
+// RMS_NORM
 typedef struct {
     int32_t  ne00;
-    int32_t  ne00_4;
-    uint64_t nb01;
-    float    eps;
-} ggml_metal_kargs_norm;
-
-typedef struct {
-    int32_t  ne00;
-    int32_t  ne00_4;
+    int32_t  ne00_t;
     uint64_t nb1;
     uint64_t nb2;
     uint64_t nb3;
@@ -448,7 +443,7 @@ typedef struct {
     uint64_t nbf1[3];
     uint64_t nbf2[3];
     uint64_t nbf3[3];
-} ggml_metal_kargs_rms_norm;
+} ggml_metal_kargs_norm;
 
 typedef struct {
     int32_t  ne00;
index 15cea725135365496cd082cd3715b3452bb2d66b..7b11f36adccddde2d1ee927540817824a0228f20 100644 (file)
@@ -266,10 +266,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_set_rows(ctx, idx);
             } break;
-        case GGML_OP_RMS_NORM:
-            {
-                n_fuse = ggml_metal_op_rms_norm(ctx, idx);
-            } break;
         case GGML_OP_L2_NORM:
             {
                 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -279,6 +275,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
                 n_fuse = ggml_metal_op_group_norm(ctx, idx);
             } break;
         case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
             {
                 n_fuse = ggml_metal_op_norm(ctx, idx);
             } break;
@@ -2346,146 +2343,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     return n_fuse;
 }
 
-int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
-    ggml_cgraph * gf = ctx->gf;
-    ggml_tensor * op = ggml_graph_node(gf, idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    const int idx_end = ctx->idx_end;
-
-    const bool use_fusion = ctx->use_fusion;
-
-    const int debug_fusion = ctx->debug_fusion;
-
-    ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
-
-    float eps;
-    memcpy(&eps, op->op_params, sizeof(float));
-
-    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
-    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
-
-    ggml_metal_kargs_rms_norm args = {
-        /*.ne00   =*/ ne00,
-        /*.ne00_4 =*/ ne00/4,
-        /*.nb1    =*/ nb1,
-        /*.nb2    =*/ nb2,
-        /*.nb3    =*/ nb3,
-        /*.eps    =*/ eps,
-        /*.nef1   =*/ { ne01 },
-        /*.nef2   =*/ { ne02 },
-        /*.nef3   =*/ { ne03 },
-        /*.nbf1   =*/ { nb01 },
-        /*.nbf2   =*/ { nb02 },
-        /*.nbf3   =*/ { nb03 },
-    };
-
-    ggml_op fops[8];
-
-    int n_fuse = 1;
-
-    ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
-
-    // d[0] = rms_norm(a)
-    // d[1] = mul(d[0], b)
-    // d[2] = add(d[1], c)
-    if (use_fusion) {
-        fops[0] = GGML_OP_RMS_NORM;
-        fops[1] = GGML_OP_MUL;
-        fops[2] = GGML_OP_ADD;
-
-        for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
-            if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
-                break;
-            }
-
-            if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
-                break;
-            }
-
-            if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
-                break;
-            }
-
-            if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
-                break;
-            }
-
-            if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
-                break;
-            }
-
-            //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
-
-            bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
-
-            args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
-            args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
-            args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
-
-            args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
-            args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
-            args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
-        }
-
-        ++n_fuse;
-
-        if (debug_fusion > 1 && n_fuse > 1) {
-            if (n_fuse == 2) {
-                GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
-            }
-            if (n_fuse == 3) {
-                GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
-            }
-        }
-    }
-
-    if (n_fuse > 1) {
-        bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
-
-        for (int i = 1; i < n_fuse; ++i) {
-            if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
-                ggml_metal_op_concurrency_reset(ctx);
-
-                break;
-            }
-        }
-    }
-
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
-
-    int nth = 32; // SIMD width
-
-    while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
-        nth *= 2;
-    }
-
-    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00/4);
-
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
-    ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
-    ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
-    ggml_metal_encoder_set_buffer  (enc, bid_dst, 4);
-
-    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
-
-    return n_fuse;
-}
-
 int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     ggml_cgraph * gf = ctx->gf;
     ggml_tensor * op = ggml_graph_node(gf, idx);
@@ -2594,6 +2451,14 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
     ggml_metal_library_t lib = ctx->lib;
     ggml_metal_encoder_t enc = ctx->enc;
 
+    const int idx_end = ctx->idx_end;
+
+    const bool use_fusion = ctx->use_fusion;
+
+    const int debug_fusion = ctx->debug_fusion;
+
+    ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
+
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
@@ -2602,37 +2467,121 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
 
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     ggml_metal_kargs_norm args = {
         /*.ne00   =*/ ne00,
-        /*.ne00_4 =*/ ne00/4,
-        /*.nb01   =*/ nb01,
+        /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
+        /*.nb1    =*/ nb1,
+        /*.nb2    =*/ nb2,
+        /*.nb3    =*/ nb3,
         /*.eps    =*/ eps,
+        /*.nef1   =*/ { ne01 },
+        /*.nef2   =*/ { ne02 },
+        /*.nef3   =*/ { ne03 },
+        /*.nbf1   =*/ { nb01 },
+        /*.nbf2   =*/ { nb02 },
+        /*.nbf3   =*/ { nb03 },
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op);
+    ggml_op fops[8];
+
+    int n_fuse = 1;
+
+    ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
+
+    // d[0] = norm(a)
+    // d[1] = mul(d[0], b)
+    // d[2] = add(d[1], c)
+    if (use_fusion) {
+        fops[0] = op->op;
+        fops[1] = GGML_OP_MUL;
+        fops[2] = GGML_OP_ADD;
+
+        for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+            if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
+                break;
+            }
+
+            if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
+                break;
+            }
+
+            if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
+                break;
+            }
+
+            if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
+                break;
+            }
+
+            if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
+                break;
+            }
+
+            //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
+
+            bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
+
+            args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
+            args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
+            args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
+
+            args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
+            args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
+            args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
+        }
+
+        ++n_fuse;
+
+        if (debug_fusion > 1 && n_fuse > 1) {
+            if (n_fuse == 2) {
+                GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
+            }
+            if (n_fuse == 3) {
+                GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
+            }
+        }
+    }
+
+    if (n_fuse > 1) {
+        bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
+
+        for (int i = 1; i < n_fuse; ++i) {
+            if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
+                ggml_metal_op_concurrency_reset(ctx);
+
+                break;
+            }
+        }
+    }
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
 
     int nth = 32; // SIMD width
-    while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+
+    while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00/4);
+    nth = std::min(nth, args.ne00_t);
 
     const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
 
-    const int64_t nrows = ggml_nrows(op->src[0]);
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0,    1);
+    ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
+    ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,     4);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
 
-    return 1;
+    return n_fuse;
 }
 
 int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
index b620de164d755f5fdb5606af62a5e6aa96c99601..a1151f881b372392bba21aaac076614d639e0593 100644 (file)
@@ -60,7 +60,6 @@ int ggml_metal_op_mul_mat_id        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_add_id            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_flash_attn_ext    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_bin               (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_rms_norm          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_l2_norm           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_group_norm        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
index 339cbf91fbe5797ac140fa304fd4f13a6da2c41e..48856c79b7bd2a6f98f24ce2539b8c98b6788a8b 100644 (file)
@@ -66,6 +66,10 @@ static inline float e8m0_to_fp32(uint8_t x) {
     return as_type<float>(bits);
 }
 
+static inline float dot(float x, float y) {
+    return x*y;
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -2493,30 +2497,43 @@ kernel void kernel_argmax_f32(
     dst_i32[tgpig] = arg_val;
 }
 
-kernel void kernel_norm_f32(
+// F == 1 : norm (no fuse)
+// F == 2 : norm + mul
+// F == 3 : norm + mul + add
+template <typename T, short F>
+kernel void kernel_norm_fuse_impl(
         constant ggml_metal_kargs_norm & args,
         device const char * src0,
+        device const char * src1_0,
+        device const char * src1_1,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
 
-    float4 sumf4(0.0f);
+    device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+
+    T sumft(0.0f);
 
     float sumf = 0.0f;
 
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        sumf4 += x[i00];
+    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
+        sumft += x[i00];
     }
-    sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+    sumf = dot(sumft, T(1.0f));
     sumf = simd_sum(sumf);
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2532,10 +2549,10 @@ kernel void kernel_norm_f32(
 
     const float mean = sumf/args.ne00;
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
 
     sumf = 0.0f;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
         y[i00] = x[i00] - mean;
         sumf += dot(y[i00], y[i00]);
     }
@@ -2555,17 +2572,35 @@ kernel void kernel_norm_f32(
     const float variance = sumf/args.ne00;
 
     const float scale = 1.0f/sqrt(variance + args.eps);
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        y[i00] = y[i00] * scale;
+    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
+        if (F == 1) {
+            y[i00] = (y[i00]*scale);
+        }
+        if (F == 2) {
+            y[i00] = (y[i00]*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
+        }
     }
 }
 
+typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
+
+template [[host_name("kernel_norm_f32")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
+template [[host_name("kernel_norm_mul_f32")]]     kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
+template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
+
+template [[host_name("kernel_norm_f32_4")]]         kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
+template [[host_name("kernel_norm_mul_f32_4")]]     kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
+template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
+
 // F == 1 : rms_norm (no fuse)
 // F == 2 : rms_norm + mul
 // F == 3 : rms_norm + mul + add
-template <short F>
+template <typename T, short F>
 kernel void kernel_rms_norm_fuse_impl(
-        constant ggml_metal_kargs_rms_norm & args,
+        constant ggml_metal_kargs_norm & args,
         device const char * src0,
         device const char * src1_0,
         device const char * src1_1,
@@ -2584,15 +2619,15 @@ kernel void kernel_rms_norm_fuse_impl(
     const int i02 = tgpig.y;
     const int i03 = tgpig.z;
 
-    device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+    device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
 
-    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
-    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+    device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2611,8 +2646,8 @@ kernel void kernel_rms_norm_fuse_impl(
     const float mean  = sumf/args.ne00;
     const float scale = 1.0f/sqrt(mean + args.eps);
 
-    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
-    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+    device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
         if (F == 1) {
             y[i00] = (x[i00]*scale);
         }
@@ -2625,11 +2660,15 @@ kernel void kernel_rms_norm_fuse_impl(
     }
 }
 
-typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm_f32")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
+template [[host_name("kernel_rms_norm_mul_f32")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
+template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
 
-template [[host_name("kernel_rms_norm_f32")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
-template [[host_name("kernel_rms_norm_mul_f32")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
-template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
+template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
+template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
 
 kernel void kernel_l2_norm_f32(
         constant ggml_metal_kargs_l2_norm & args,