return res;
}
+// Look up (or lazily compile) the Metal pipeline for the SOLVE_TRI op.
+//
+// The kernel is specialized at compile time through function constants on:
+//   nsg — simdgroups per threadgroup (fixed to 8 here),
+//   n   — the triangular-system dimension (src1->ne[1]),
+//   k   — the number of right-hand-side columns (src1->ne[0]).
+// Compiled variants are cached under a name that encodes all three values,
+// so repeated calls with the same shapes hit the cache.
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+ char base[256];
+ char name[256];
+
+ const int nsg = 8;
+ const int n = op->src[1]->ne[1];
+ const int k = op->src[1]->ne[0];
+
+ // base selects the kernel by src0 element type; name adds the specialization
+ snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
+ snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ // cache miss: bind the function constants and compile this variant
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+ ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
+ ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1);
+ ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2);
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
+ }
+
+ // threadgroup memory: one row of n floats, padded to the simd width (32),
+ // per simdgroup; the total is padded to a 16-byte boundary
+ res.nsg = nsg;
+ res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
char base[256];
char name[256];
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
return has_simdgroup_reduction;
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
+ case GGML_OP_SOLVE_TRI:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
#define FC_MUL_MM 700
#define FC_ROPE 800
#define FC_SSM_CONV 900
-#define FC_COUNT_EQUAL 1000
+#define FC_SOLVE_TRI 1000
+#define FC_COUNT_EQUAL 1100
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
+// Kernel arguments for SOLVE_TRI: element counts (ne*) and byte strides (nb*)
+// in the usual ggml 4-D layout, for the three tensors involved:
+//   ne0x/nb0x — src0, the triangular matrix A
+//   ne1x/nb1x — src1, the right-hand sides B
+//   ne/nb     — dst, the solution X
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ int32_t ne12;
+ int32_t ne13;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
typedef struct {
int32_t ne00t;
int32_t ne00;
{
n_fuse = ggml_metal_op_rwkv(ctx, idx);
} break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+ } break;
case GGML_OP_MUL_MAT:
{
n_fuse = ggml_metal_op_mul_mat(ctx, idx);
return 1;
}
+// Encode the SOLVE_TRI op at graph node `idx` onto the current command encoder.
+//
+// Packs the shapes/strides of src0 (A), src1 (B) and dst (X) into the kargs
+// struct, binds the specialized pipeline plus the three buffers, reserves the
+// threadgroup memory the pipeline requested, and dispatches one threadgroup
+// of 32 x nsg threads per group of nsg RHS columns, for each (ne02, ne03)
+// batch slice. Returns 1 (the number of fused ops consumed).
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ // expand neXX/nbXX locals for src0, src1 and dst
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_solve_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ // pipeline carries nsg and the required threadgroup memory size (smem)
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ const int nsg = pipeline.nsg;
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
+
+ // grid: ceil(ne10/nsg) threadgroups over the RHS columns, times the batch
+ // dims; each threadgroup runs nsg simdgroups of 32 threads
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
+
+ return 1;
+}
+
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
}
}
+// Compile-time specialization constants set by the host when the pipeline is
+// built: simdgroups per threadgroup, system dimension N, and RHS column count K.
+constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
+constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
+constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
+
+// Lower-triangular solve (forward substitution), f32: X = A^-1 * B.
+//
+// Work mapping: one simdgroup per RHS column i01; tgpig.y/.z walk the batch
+// dims. The outer loop processes the rows of A in blocks of NSG: each
+// simdgroup stages one row of the block into its NP-float slice of shared
+// memory, then every simdgroup serially applies all NSG staged rows to its
+// own column — a warp-wide partial dot product over the already-solved
+// prefix of x (masked by `idx < r`), reduced with simd_sum, after which
+// lane 0 writes x[r] = (b[r] - sum) / A[r][r]. dst rows are strided by K
+// because x is stored column i01 of a row-major N x K output.
+//
+// NOTE(review): row addressing uses `sgitg*N` floats rather than args.nb01,
+// so A's rows are presumably contiguous (nb01 == N*sizeof(float)) — confirm
+// against the host-side op constraints.
+// NOTE(review): the staging loads at `rr + sgitg` and the NW-wide strips up
+// to NP appear to read past A's last row when N is not a multiple of
+// NSG/NW — verify the buffer has slack or that shapes are restricted.
+kernel void kernel_solve_tri_f32(
+ constant ggml_metal_kargs_solve_tri & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ ushort3 tgpig[[threadgroup_position_in_grid]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ constexpr short NW = N_SIMDWIDTH;
+
+ const short NSG = FC_solve_tri_nsg;
+ const short N = FC_solve_tri_n;
+ const short K = FC_solve_tri_k;
+ // NP: row length padded to the simd width, the per-simdgroup shared stride
+ const short NP = PAD2(N, NW);
+
+ // NOTE(review): ne02/ne03 are unused below — candidates for removal
+ const int32_t ne02 = args.ne02;
+ const int32_t ne03 = args.ne03;
+
+ const int32_t i03 = tgpig.z;
+ const int32_t i02 = tgpig.y;
+ const int32_t i01 = tgpig.x*NSG + sgitg;
+
+ threadgroup float * sh0 = (threadgroup float *) shmem;
+
+ // src0: this simdgroup's first staging row; src1/dst: column i01 of B/X
+ device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
+ device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
+ device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
+
+ for (short rr = 0; rr < N; rr += NSG) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ // stage row (rr + sgitg) of A into this simdgroup's shared slice,
+ // NW elements per pass
+ threadgroup float * sh0_cur = sh0 + sgitg*NP;
+
+ for (short t = 0; t*NW < N; ++t) {
+ const short idx = t*NW + tiisg;
+ sh0_cur[idx] = src0_ptr[idx];
+ }
+
+ src0_ptr += NSG*N;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // columns past the RHS width only help with staging; they skip the
+ // solve (uniform per simdgroup, and the barriers sit at loop top, so
+ // barrier execution stays threadgroup-uniform)
+ if (i01 >= args.ne10) {
+ continue;
+ }
+
+ // apply the NSG staged rows in order; later rows in this block see the
+ // x values written by earlier ones through device memory
+ for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
+ const short r = rr + ir;
+
+ threadgroup float * sh0_cur = sh0 + ir*NP;
+
+ float sum = 0.0f;
+
+ // warp-wide partial dot of A[r][0..r) with x[0..r); the (idx < r)
+ // mask zeroes lanes past the strict lower-triangular prefix
+ for (short t = 0; t*NW < r; ++t) {
+ const short idx = t*NW + tiisg;
+ sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
+ }
+
+ sum = simd_sum(sum);
+
+ if (tiisg == 0) {
+ const float diag = sh0_cur[r];
+
+ dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
+ }
+ }
+ }
+}
+
kernel void kernel_argmax_f32(
constant ggml_metal_kargs_argmax & args,
device const char * src0,