]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : add solve_tri (llama/19302)
authorGeorgi Gerganov <redacted>
Tue, 3 Feb 2026 21:43:14 +0000 (23:43 +0200)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.h
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
src/ggml-metal/ggml-metal.metal

index 377b0d3eb8f5359c042b5eb8e144088030af585b..4cd3d93d81380d6b3d919b9b4e26b34cd487488c 100644 (file)
@@ -534,6 +534,36 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int nsg = 8;
+    const int n   = op->src[1]->ne[1];
+    const int k   = op->src[1]->ne[0];
+
+    snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
+        ggml_metal_cv_set_int16(cv, n,   FC_SOLVE_TRI + 1);
+        ggml_metal_cv_set_int16(cv, k,   FC_SOLVE_TRI + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg  = nsg;
+    res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
     char base[256];
     char name[256];
index afb091e725579326327f01c7da2138a3b911f19e..d8984327124458c8cbc6fb7f412114400fa4f7f9 100644 (file)
@@ -121,6 +121,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
index 285dd1630e7c6ec28fb612bed1890f9f3d5085db..8a0b85c6e4d36056ff2bdb574219dd2ce9b6173f 100644 (file)
@@ -1152,6 +1152,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return has_simdgroup_reduction;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_SOLVE_TRI:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
index e074f2ef3db157c75bc9ca5b1c36d3e05db96178..640ade8f880e1f6a0fcbecec3094596e9cc7c66e 100644 (file)
@@ -78,7 +78,8 @@
 #define FC_MUL_MM                      700
 #define FC_ROPE                        800
 #define FC_SSM_CONV                    900
-#define FC_COUNT_EQUAL                 1000
+#define FC_SOLVE_TRI                   1000
+#define FC_COUNT_EQUAL                 1100
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -733,6 +734,33 @@ typedef struct {
     uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
 typedef struct {
     int32_t  ne00t;
     int32_t  ne00;
index f97c4435dec9f2d818702bdd01ab9bb27fd010b0..753fcec317538b1b6979c1fa4ac38d752995e43e 100644 (file)
@@ -341,6 +341,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_rwkv(ctx, idx);
             } break;
+        case GGML_OP_SOLVE_TRI:
+            {
+                n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 n_fuse = ggml_metal_op_mul_mat(ctx, idx);
@@ -1557,6 +1561,63 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_kargs_solve_tri args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
+
+    const int nsg = pipeline.nsg;
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 10686a334e0970ab3239011a9a6ab5b7460793e1..2e4c7d3fa117d83c5e2484feb300369282b22f68 100644 (file)
@@ -60,6 +60,7 @@ int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);
index 3259213fd6160a48ebe66a1ef2265518961e77b8..c09a54e66142c9357138f45a9d9a1fedf260ef19 100644 (file)
@@ -2737,6 +2737,83 @@ kernel void kernel_rwkv_wkv7_f32(
     }
 }
 
+constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
+constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
+constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
+
+kernel void kernel_solve_tri_f32(
+        constant ggml_metal_kargs_solve_tri & args,
+        device   const char * src0,
+        device   const char * src1,
+        device         char * dst,
+        threadgroup    char * shmem [[threadgroup(0)]],
+        ushort3 tgpig[[threadgroup_position_in_grid]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const short NSG = FC_solve_tri_nsg;
+    const short N   = FC_solve_tri_n;
+    const short K   = FC_solve_tri_k;
+    const short NP  = PAD2(N, NW);
+
+    const int32_t ne02 = args.ne02;
+    const int32_t ne03 = args.ne03;
+
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x*NSG + sgitg;
+
+    threadgroup float * sh0 = (threadgroup float *) shmem;
+
+    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
+    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
+    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;
+
+    for (short rr = 0; rr < N; rr += NSG) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        {
+            threadgroup float * sh0_cur = sh0 + sgitg*NP;
+
+            for (short t = 0; t*NW < N; ++t) {
+                const short idx = t*NW + tiisg;
+                sh0_cur[idx] = src0_ptr[idx];
+            }
+
+            src0_ptr += NSG*N;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (i01 >= args.ne10) {
+            continue;
+        }
+
+        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
+            const short r = rr + ir;
+
+            threadgroup float * sh0_cur = sh0 + ir*NP;
+
+            float sum = 0.0f;
+
+            for (short t = 0; t*NW < r; ++t) {
+                const short idx = t*NW + tiisg;
+                sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
+            }
+
+            sum = simd_sum(sum);
+
+            if (tiisg == 0) {
+                const float diag = sh0_cur[r];
+
+                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
+            }
+        }
+    }
+}
+
 kernel void kernel_argmax_f32(
         constant ggml_metal_kargs_argmax & args,
         device   const char * src0,