return res;
}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
+ GGML_ASSERT(op->type == GGML_TYPE_I64);
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_COUNT_EQUAL);
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
+
+ GGML_ASSERT(op->src[0]->type == op->src[1]->type);
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
+ GGML_ASSERT(op->type == GGML_TYPE_I64);
+
+ // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
+ GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
+
+ char base[256];
+ char name[256];
+
+ int nsg = 1;
+ while (32*nsg < ne00 && nsg < 32) {
+ nsg *= 2;
+ }
+
+ snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+ ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
+ }
+
+ res.smem = 32 * sizeof(int32_t);
+ res.nsg = nsg;
+
+ return res;
+}
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ case GGML_OP_COUNT_EQUAL:
+ return has_simdgroup_reduction &&
+ op->src[0]->type == GGML_TYPE_I32 &&
+ op->src[1]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
#define FC_MUL_MM 700
#define FC_ROPE 800
#define FC_SSM_CONV 900
+#define FC_COUNT_EQUAL 1000
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
float step;
} ggml_metal_kargs_arange;
+typedef struct {
+ int64_t val;
+} ggml_metal_kargs_memset;
+
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+} ggml_metal_kargs_count_equal;
+
typedef struct {
int32_t k0;
int32_t k1;
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
- default:
+ case GGML_OP_COUNT_EQUAL:
+ {
+ n_fuse = ggml_metal_op_count_equal(ctx, idx);
+ } break;
+ default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
GGML_ABORT("fatal error");
return 1;
}
+
+int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+
+ {
+ ggml_metal_kargs_memset args = { /*.val =*/ 0 };
+
+ auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+ }
+
+ ggml_metal_op_concurrency_reset(ctx);
+
+ {
+ ggml_metal_kargs_count_equal args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
+
+ const size_t smem = pipeline.smem;
+
+ const int nth = 32*pipeline.nsg;
+
+ GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+ }
+
+ return 1;
+}
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}
return;
}
+ // TODO: become function constant
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
+
+template<typename T>
+kernel void kernel_memset(
+ constant ggml_metal_kargs_fill & args,
+ device T * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = args.val;
+}
+
+typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
+
+template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
+
+constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
+
+template<typename T>
+kernel void kernel_count_equal(
+ constant ggml_metal_kargs_count_equal & args,
+ device const char * src0,
+ device const char * src1,
+ device atomic_int * dst,
+ threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const short NSG = FC_count_equal_nsg;
+
+ const int i3 = tgpig.z;
+ const int i2 = tgpig.y;
+ const int i1 = tgpig.x;
+
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+ return;
+ }
+
+ int sum = 0;
+
+ device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
+ device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
+
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+ const T v0 = *(device const T *)(base0 + i0*args.nb00);
+ const T v1 = *(device const T *)(base1 + i0*args.nb10);
+ sum += (v0 == v1);
+ }
+
+ sum = simd_sum(sum);
+
+ if (tiisg == 0) {
+ shmem_i32[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (sgitg == 0) {
+ float v = 0.0f;
+ if (tpitg.x < NSG) {
+ v = shmem_i32[tpitg.x];
+ }
+
+ float total = simd_sum(v);
+ if (tpitg.x == 0) {
+ atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
+ }
+ }
+}
+
+typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
+
+template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;