]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : extend l2_norm support for non-cont src0 (llama/19502)
authorGeorgi Gerganov <redacted>
Wed, 11 Feb 2026 12:53:19 +0000 (14:53 +0200)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal.metal

index 949e344cc8c3f0e420df7c3f31c9894b1ac53573..517559d12a667dc720f8a693dcac8709c0eb7d3c 100644 (file)
@@ -1480,13 +1480,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_L2_NORM);
 
-    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
-    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
-
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_l2_norm_f32");
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
+
+    snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1494,6 +1496,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
         res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
+    res.c4   = is_c4;
     res.smem = 32*sizeof(float);
 
     return res;
index 50a2a3e7f72afa7f514d34c4d3088043b07274a6..c714ef3add9fa12221e162c0328a7af9394a534c 100644 (file)
@@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_L2_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_COUNT_EQUAL:
             return has_simdgroup_reduction &&
                 op->src[0]->type == GGML_TYPE_I32 &&
index 44141f8e3d952bc030d1e30840f5e62e4541cfa7..952e1be076e01e05f42fa6b349f667976a925faf 100644 (file)
@@ -539,8 +539,21 @@ typedef struct {
 
 typedef struct {
     int32_t  ne00;
-    int32_t  ne00_4;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
     uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
 } ggml_metal_kargs_l2_norm;
 
index b159a8e7fd017f12e457ddb90c193ea45ea14acf..7db95d1c84d3a08e1df706b407e67d1aae7bb0e7 100644 (file)
@@ -2979,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
 
-    int nth = 32; // SIMD width
-
     ggml_metal_kargs_l2_norm args = {
-        /*.ne00   =*/ ne00,
-        /*.ne00_4 =*/ ne00/4,
-        /*.nb01   =*/ nb01,
-        /*.eps    =*/ eps,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.eps   =*/ eps,
     };
 
     auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
 
-    while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
+    int nth = 32; // SIMD width
+
+    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00/4);
 
     const size_t smem = pipeline.smem;
 
-    const int64_t nrows = ggml_nrows(op->src[0]);
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
 
     return 1;
 }
index 7d841341a18638bdf4377d1725f09f15821c44b0..a385a50b9424848e0be41cbc898ad3ada611ef99 100644 (file)
@@ -2706,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_f
 template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
 template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
 
-kernel void kernel_l2_norm_f32(
+template <typename T0, typename T>
+kernel void kernel_l2_norm_impl(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2743,12 +2749,16 @@ kernel void kernel_l2_norm_f32(
 
     const float scale = 1.0f/sqrt(max(sumf, args.eps));
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         y[i00] = x[i00] * scale;
     }
 }
 
+typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
+
+template [[host_name("kernel_l2_norm_f32_f32")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl<float,  float>;
+template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
+
 kernel void kernel_group_norm_f32(
         constant ggml_metal_kargs_group_norm & args,
         device const float * src0,