]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add count_equal op (llama/18314)
authorgatbontonpc <redacted>
Wed, 31 Dec 2025 08:39:48 +0000 (00:39 -0800)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 15:52:09 +0000 (17:52 +0200)
* add count equal for metal

* remove trailing whitespace

* updated doc ops table

* changed shmem to i32

* added multi tg and templating

* removed BLAS support from Metal docs

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* add memset to set dst to 0

* metal : cleanup

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.metal

index 680904d132d996f6d9802203ca80d710db8d0db9..b0734797f197bf8bed223eb10d94deae6060f9f4 100644 (file)
@@ -1684,3 +1684,60 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
 
     return res;
 }
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor *  op) {
+    GGML_ASSERT(op->type == GGML_TYPE_I64);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor *  op) {
+    assert(op->op == GGML_OP_COUNT_EQUAL);
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
+
+    GGML_ASSERT(op->src[0]->type == op->src[1]->type);
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
+    GGML_ASSERT(op->type == GGML_TYPE_I64);
+
+    // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
+    GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
+
+    char base[256];
+    char name[256];
+
+    int nsg = 1;
+    while (32*nsg < ne00 && nsg < 32) {
+        nsg *= 2;
+    }
+
+    snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.smem = 32 * sizeof(int32_t);
+    res.nsg  = nsg;
+
+    return res;
+}
index 0a8b9211a769f21890cedab4f81f285efc1c5116..d983b666ca2ec52786e8b98fd9b9cd8d1d9a3e50 100644 (file)
@@ -147,6 +147,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
index f24270bb1c583262c3e1b11b09b212d18f0b15fd..59badd00431a0dbef3187155caebfd1977605163 100644 (file)
@@ -1023,6 +1023,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+        case GGML_OP_COUNT_EQUAL:
+            return has_simdgroup_reduction &&
+                op->src[0]->type == GGML_TYPE_I32 &&
+                op->src[1]->type == GGML_TYPE_I32 &&
+                op->type == GGML_TYPE_I64;
         case GGML_OP_ARGMAX:
             return has_simdgroup_reduction;
         case GGML_OP_NORM:
index 8944b07e9076b075e0d71d54e06a7ea1732256a5..d3b0e732ec46338d6ad5cbed64fe45e25cb7c170 100644 (file)
@@ -78,6 +78,7 @@
 #define FC_MUL_MM                      700
 #define FC_ROPE                        800
 #define FC_SSM_CONV                    900
+#define FC_COUNT_EQUAL                 1000
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPTG 8
@@ -894,6 +895,25 @@ typedef struct {
     float    step;
 } ggml_metal_kargs_arange;
 
+typedef struct {
+    int64_t val;
+} ggml_metal_kargs_memset;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+} ggml_metal_kargs_count_equal;
+
 typedef struct {
     int32_t  k0;
     int32_t  k1;
index e99c1763f631950305904a310407e5e85c1b95c1..acf2aa918479152e2884a8caa3ee41ac2f9ac1bd 100644 (file)
@@ -448,7 +448,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
             } break;
-       default:
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_fuse = ggml_metal_op_count_equal(ctx, idx);
+            } break;
+        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
                 GGML_ABORT("fatal error");
@@ -4090,3 +4094,64 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
 
     return 1;
 }
+
+int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+
+    {
+        ggml_metal_kargs_memset args = { /*.val =*/ 0 };
+
+        auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    }
+
+    ggml_metal_op_concurrency_reset(ctx);
+
+    {
+        ggml_metal_kargs_count_equal args = {
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.nb10 =*/ nb10,
+            /*.nb11 =*/ nb11,
+            /*.nb12 =*/ nb12,
+            /*.nb13 =*/ nb13,
+        };
+
+        auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
+
+        const size_t smem = pipeline.smem;
+
+        const int nth = 32*pipeline.nsg;
+
+        GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+    }
+
+    return 1;
+}
index 902b5445232a64d3739d4b92509c495c4fa84e0b..c1025d35677f95db62fc27cff10bce17ce0fccb7 100644 (file)
@@ -87,6 +87,7 @@ int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_count_equal       (ggml_metal_op_t ctx, int idx);
 
 #ifdef __cplusplus
 }
index 3154beff9bb3e9a9b771b81197b3ec4f4b26fb90..67b30e0d93c819a9ed535ae5be1b2cfd211f9509 100644 (file)
@@ -1790,6 +1790,7 @@ kernel void kernel_op_sum_f32(
         return;
     }
 
+    // TODO: become function constant
     const uint nsg = (ntg.x + 31) / 32;
 
     float sumf = 0;
@@ -9914,3 +9915,75 @@ kernel void kernel_opt_step_sgd_f32(
 
     x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
 }
+
+template<typename T>
+kernel void kernel_memset(
+        constant ggml_metal_kargs_fill & args,
+        device T * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
+typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
+
+template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
+
+constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
+
+template<typename T>
+kernel void kernel_count_equal(
+        constant ggml_metal_kargs_count_equal & args,
+        device   const char * src0,
+        device   const char * src1,
+        device   atomic_int * dst,
+        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const short NSG = FC_count_equal_nsg;
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+        return;
+    }
+
+    int sum = 0;
+
+    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
+    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
+
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        const T v0 = *(device const T *)(base0 + i0*args.nb00);
+        const T v1 = *(device const T *)(base1 + i0*args.nb10);
+        sum += (v0 == v1);
+    }
+
+    sum = simd_sum(sum);
+
+    if (tiisg == 0) {
+        shmem_i32[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (sgitg == 0) {
+        float v = 0.0f;
+        if (tpitg.x < NSG) {
+            v = shmem_i32[tpitg.x];
+        }
+
+        float total = simd_sum(v);
+        if (tpitg.x == 0) {
+            atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
+        }
+    }
+}
+
+typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
+
+template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;