]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : add diag (llama/19330)
authorGeorgi Gerganov <redacted>
Thu, 5 Feb 2026 08:08:45 +0000 (10:08 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.h
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
src/ggml-metal/ggml-metal.metal

index 4cd3d93d81380d6b3d919b9b4e26b34cd487488c..6af0dd88d55ef1571a3bdb406c552e626751fecf 100644 (file)
@@ -176,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int n = op->src[0]->ne[0];
+
+    snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_n=%d", base, n);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.nsg  = 1;
+    res.smem = 0;
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
     char base[256];
     char name[256];
index d8984327124458c8cbc6fb7f412114400fa4f7f9..84dcec308302e6ae9bccd16c9ab9f27f7110f892 100644 (file)
@@ -108,6 +108,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
index 8a0b85c6e4d36056ff2bdb574219dd2ce9b6173f..c8e737d418714b17814c60fc786e2e95f54778c3 100644 (file)
@@ -1152,8 +1152,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return has_simdgroup_reduction;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_SOLVE_TRI:
             return true;
+        case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return has_simdgroup_reduction;
@@ -1235,6 +1235,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                         return false;
                 };
             }
+        case GGML_OP_DIAG:
+            return true;
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return has_simdgroup_reduction;
index 640ade8f880e1f6a0fcbecec3094596e9cc7c66e..7f73cb97bbb48df3f6e761e47710e560d2bc9136 100644 (file)
@@ -792,6 +792,25 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_set_rows;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_diag;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
index 753fcec317538b1b6979c1fa4ac38d752995e43e..e0ed6c7805cea280cc6c39b37d36bac1a0f81fb6 100644 (file)
@@ -361,6 +361,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_set_rows(ctx, idx);
             } break;
+        case GGML_OP_DIAG:
+            {
+                n_fuse = ggml_metal_op_diag(ctx, idx);
+            } break;
         case GGML_OP_L2_NORM:
             {
                 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -1259,6 +1263,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(int32_t,  ne, op, ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    ggml_metal_kargs_diag args = {
+        /*.ne00 =*/ne00,
+        /*.ne01 =*/ne01,
+        /*.ne02 =*/ne02,
+        /*.ne03 =*/ne03,
+        /*.nb00 =*/nb00,
+        /*.nb01 =*/nb01,
+        /*.nb02 =*/nb02,
+        /*.nb03 =*/nb03,
+        /*.ne0  =*/ne0,
+        /*.ne1  =*/ne1,
+        /*.ne2  =*/ne2,
+        /*.ne3  =*/ne3,
+        /*.nb0  =*/nb0,
+        /*.nb1  =*/nb1,
+        /*.nb2  =*/nb2,
+        /*.nb3  =*/nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 2e4c7d3fa117d83c5e2484feb300369282b22f68..3c64e4f6007990b97bcab17368f3104cfe0c911d 100644 (file)
@@ -56,6 +56,7 @@ int ggml_metal_op_sum_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_diag              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
index c09a54e66142c9357138f45a9d9a1fedf260ef19..e54cdab39ddcc4476b20d3c6bbf0ba7dd00191ae 100644 (file)
@@ -8815,6 +8815,26 @@ kernel void kernel_set_rows_f(
     }
 }
 
+kernel void kernel_diag_f32(
+        constant ggml_metal_kargs_diag & args,
+        device   const char * src0,
+        device         char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig.z;
+    const int32_t i2 = tgpig.y;
+    const int32_t i1 = tgpig.x;
+
+    device const float * src0_ptr = (device const float *)(src0 +                i2*args.nb02 + i3*args.nb03);
+    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb01 + i2*args.nb2  + i3*args.nb3);
+
+    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
+        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
+    }
+}
+
 constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
 constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];