return res;
}
+// Look up (or lazily compile) the Metal pipeline used for GGML_OP_DIAG.
+// The kernel variant is selected by the src0 element type; the row length n
+// is baked into the cached pipeline name.
+// NOTE(review): no function constant depends on n, so every distinct n caches
+// a separate-but-identical pipeline — confirm this is intentional.
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int n = op->src[0]->ne[0];
+
+    // use sizeof(buf) so the format bound cannot drift from the buffer size
+    snprintf(base, sizeof(base), "kernel_diag_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s_n=%d", base, n);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        // not cached yet — compile it (no function constants needed)
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.nsg  = 1; // one simdgroup per threadgroup
+    res.smem = 0; // no threadgroup memory required
+
+    return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
char base[256];
char name[256];
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
return has_simdgroup_reduction;
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
- case GGML_OP_SOLVE_TRI:
return true;
+ case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
return false;
};
}
+ case GGML_OP_DIAG:
+ return true;
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
uint64_t nb3;
} ggml_metal_kargs_set_rows;
+// Kernel arguments for the GGML_OP_DIAG Metal kernel (kernel_diag_*).
+// Layout is shared verbatim with the shader-side declaration — do not
+// reorder or resize fields. neXY/nbXY describe src0, neX/nbX describe dst;
+// all nb* strides are in bytes (the kernel applies them to char pointers).
+typedef struct {
+    int32_t  ne00;  // src0 extent, dim 0
+    int32_t  ne01;  // src0 extent, dim 1
+    int32_t  ne02;  // src0 extent, dim 2
+    int32_t  ne03;  // src0 extent, dim 3
+    uint64_t nb00;  // src0 stride, dim 0 (bytes)
+    uint64_t nb01;  // src0 stride, dim 1 (bytes)
+    uint64_t nb02;  // src0 stride, dim 2 (bytes)
+    uint64_t nb03;  // src0 stride, dim 3 (bytes)
+    int32_t  ne0;   // dst extent, dim 0
+    int32_t  ne1;   // dst extent, dim 1
+    int32_t  ne2;   // dst extent, dim 2
+    int32_t  ne3;   // dst extent, dim 3
+    uint64_t nb0;   // dst stride, dim 0 (bytes)
+    uint64_t nb1;   // dst stride, dim 1 (bytes)
+    uint64_t nb2;   // dst stride, dim 2 (bytes)
+    uint64_t nb3;   // dst stride, dim 3 (bytes)
+} ggml_metal_kargs_diag;
+
typedef struct {
int64_t ne00;
int64_t ne01;
{
n_fuse = ggml_metal_op_set_rows(ctx, idx);
} break;
+ case GGML_OP_DIAG:
+ {
+ n_fuse = ggml_metal_op_diag(ctx, idx);
+ } break;
case GGML_OP_L2_NORM:
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
return 1;
}
+// Encode the GGML_OP_DIAG node at graph index idx onto the Metal command
+// encoder: fills a ggml_metal_kargs_diag with src0/dst extents and byte
+// strides, binds the per-type diag pipeline, and dispatches one threadgroup
+// per dst (row, plane) coordinate. Returns 1 (number of nodes consumed).
+int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    // macro-expands to locals ne00..ne03 / nb00..nb03 (src0) and
+    // ne0..ne3 / nb0..nb3 (dst) — the names below come from these macros
+    GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    // must match the shader-side ggml_metal_kargs_diag layout field-for-field
+    ggml_metal_kargs_diag args = {
+        /*.ne00 =*/ne00,
+        /*.ne01 =*/ne01,
+        /*.ne02 =*/ne02,
+        /*.ne03 =*/ne03,
+        /*.nb00 =*/nb00,
+        /*.nb01 =*/nb01,
+        /*.nb02 =*/nb02,
+        /*.nb03 =*/nb03,
+        /*.ne0 =*/ne0,
+        /*.ne1 =*/ne1,
+        /*.ne2 =*/ne2,
+        /*.ne3 =*/ne3,
+        /*.nb0 =*/nb0,
+        /*.nb1 =*/nb1,
+        /*.nb2 =*/nb2,
+        /*.nb3 =*/nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
+
+    // binding order is fixed: args at index 0, src0 at 1, dst at 2
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
+
+    // grid: one threadgroup per dst row (ne1) x plane (ne2, ne3);
+    // 32 threads per group — matches NW (N_SIMDWIDTH) in the kernel loop
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
+
+    return 1;
+}
+
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
}
}
+// GGML_OP_DIAG: expand each src0 row into a diagonal matrix in dst.
+// One threadgroup per dst (i1, i2, i3) row; the 32 threads of the group
+// stride over the row's ne0 elements.
+kernel void kernel_diag_f32(
+        constant ggml_metal_kargs_diag & args,
+        device const char * src0,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig.z;
+    const int32_t i2 = tgpig.y;
+    const int32_t i1 = tgpig.x; // dst row index == diagonal position
+
+    // src0 contributes one row per (i2, i3) plane — index with src0 strides
+    device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
+    // dst must be indexed with dst strides (nb1/nb2/nb3) — previously this
+    // used src0's nb01, which only coincides with nb1 for contiguous src0
+    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
+
+    // diagonal element copies src0; everything else in the row is zeroed
+    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
+        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
+    }
+}
+
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];