]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add mean kernel (#14267)
authorGeorgi Gerganov <redacted>
Thu, 19 Jun 2025 05:05:21 +0000 (08:05 +0300)
committerGitHub <redacted>
Thu, 19 Jun 2025 05:05:21 +0000 (08:05 +0300)
* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci

ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index bc93bc633a49b224d38d3485053c205ba2bbdec3..4e7f373cb435a6cf6cc07b137c89c67691bfd502 100644 (file)
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
 
+                nth = MIN(nth, ne00);
 
                 ggml_metal_kargs_sum_rows args = {
                    /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
index 5d7760217f82644602fdb635f52943a8b9f89d56..3da19879b4b364b993390cf748c60748e9f14964 100644 (file)
@@ -993,31 +993,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device       float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;
 
     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float row_sum = 0;
+    float sumf = 0;
 
-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }
 
-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }
 
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,