const char * op_str = "undefined";
switch (op->op) {
case GGML_OP_SCALE: op_str = "scale"; break;
+ case GGML_OP_FILL: op_str = "fill"; break;
case GGML_OP_CLAMP: op_str = "clamp"; break;
case GGML_OP_SQR: op_str = "sqr"; break;
case GGML_OP_SQRT: op_str = "sqrt"; break;
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
+ case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
+ case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+ GGML_ASSERT(op->op == GGML_OP_TRI);
+ GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+
+ char base[256];
+ char name[256];
+
+ const char * op_str = "tri";
+ const int ttype = op->op_params[0];
+
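+    // the pipeline name encodes the source type and the tri variant taken from
+    // op_params[0], e.g. "kernel_tri_f32_2"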
+ snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
+
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_SOFTPLUS:
+ case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
+ case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+ case GGML_OP_TRI:
+ return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN:
float bias;
} ggml_metal_kargs_scale;
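+// kernel arguments for GGML_OP_FILL: the constant value written to every element of dst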
+typedef struct {
+ float val;
+} ggml_metal_kargs_fill;
+
typedef struct {
float min;
float max;
float slope;
} ggml_metal_kargs_leaky_relu;
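+// kernel arguments for GGML_OP_TRI: ne00..nb03 describe src0, ne0..nb3 describe dst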
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_tri;
+
typedef struct {
int32_t ne00;
int32_t ne01;
{
n_fuse = ggml_metal_op_scale(ctx, idx);
} break;
+ case GGML_OP_FILL:
+ {
+ n_fuse = ggml_metal_op_fill(ctx, idx);
+ } break;
case GGML_OP_CLAMP:
{
n_fuse = ggml_metal_op_clamp(ctx, idx);
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
} break;
+ case GGML_OP_TRI:
+ {
+ n_fuse = ggml_metal_op_tri(ctx, idx);
+ } break;
case GGML_OP_FLASH_ATTN_EXT:
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
return 1;
}
+int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ const float val = ggml_get_op_params_f32(op, 0);
+
+ ggml_metal_kargs_fill args = {
+ /*.val =*/ val
+ };
+
+ int64_t n = ggml_nelements(op);
+
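+    // the unary pipeline getter is assumed to pick the vectorized (_4) kernel when the
+    // element count is divisible by 4, so dispatch a quarter as many threadgroups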
+ if (n % 4 == 0) {
+ n /= 4;
+ }
+
+ auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
return 1;
}
+int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
+
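+    // grow the threadgroup size in powers of two, capped by the pipeline limit and the row length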
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ nth = std::min(nth, ne00);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
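+    // one threadgroup per row (ne01 x ne02 x ne03), with nth threads striding across the
+    // ne00 elements of each row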
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
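+// write the constant value from the kernel args to every element of dst; src0 is bound but not read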
+kernel void kernel_fill_f32(
+ constant ggml_metal_kargs_fill & args,
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = args.val;
+}
+
+kernel void kernel_fill_f32_4(
+ constant ggml_metal_kargs_fill & args,
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = args.val;
+}
+
kernel void kernel_clamp_f32(
constant ggml_metal_kargs_clamp & args,
device const float * src0,
dst[tpig] = exp(src0[tpig]);
}
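+// softplus(x) = log(1 + exp(x)); for x > 20 the result equals x to within float precision,
+// which also avoids overflow in exp(x)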
+kernel void kernel_softplus_f32(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
+kernel void kernel_softplus_f32_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+}
+
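+// expm1(x) = exp(x) - 1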
+kernel void kernel_expm1_f32(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
+kernel void kernel_expm1_f32_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = exp(src0[tpig]) - 1.0f;
+}
+
kernel void kernel_reglu_f32(
constant ggml_metal_kargs_glu & args,
device const char * src0,
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
+
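+// per-variant predicate: decides whether element i of a row is kept (mask 1) or zeroed
+// (mask 0), based on its position relative to the diagonal index r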
+template<uint32_t ttype>
+bool _ggml_vec_tri_cmp(const int i, const int r);
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
+ return i < r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
+ return i <= r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
+ return i > r;
+}
+
+template<>
+bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
+ return i >= r;
+}
+
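+// keep the elements of each row that fall inside the selected triangular region and zero
+// out the rest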
+template<typename T, int ttype>
+kernel void kernel_tri(
+ constant ggml_metal_kargs_tri & args,
+ device const char * src0,
+        device       char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i3 = tgpig.z;
+ const int i2 = tgpig.y;
+ const int i1 = tgpig.x;
+
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+ return;
+ }
+
+ device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+ device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
+
+    // Each thread handles a single element of the row when ne00 fits within the
+    // threadgroup; otherwise it strides by ntg.x over every index it is
+    // responsible for.
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        // Use the comparison result as a multiplicative mask to stay branchless
+ dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
+ }
+}
+
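+// explicit instantiations; the host_name suffix (type and tri variant) matches the pipeline
+// name built in ggml_metal_library_get_pipeline_tri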
+typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
+
+template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
+template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
+template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
+template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
+template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
+template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
+template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
+template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
+template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
+template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
+template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
+#endif
+
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,