{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
} break;
+ case GGML_OP_SET:
+ {
+ n_fuse = ggml_metal_op_set(ctx, idx);
+ } break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
return 1;
}
+int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
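+ // op_params of GGML_OP_SET: dst view strides nb1..nb3 (in bytes), the byte offset into dst, and the inplace flag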
+ const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+ const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+ const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+ const size_t offs = ((const int32_t *) op->op_params)[3];
+
+ const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+ if (!inplace) {
+ // run a separate kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+ ggml_metal_kargs_cpy args = {
+ /*.nk0 =*/ ne00,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
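+ // the next dispatch writes into the data just copied, so make sure it waits for this copy to finish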
+ ggml_metal_op_concurrency_reset(ctx);
+ }
+
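+ // copy src1 into the view of dst described by op_params (strides pnb1..pnb3, byte offset offs)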
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
+
+ GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
+
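+ // work items per row: for quantized types one thread handles a whole block/group of elements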
+ int64_t nk0 = ne10;
+ if (ggml_is_quantized(op->src[1]->type)) {
+ nk0 = ne10/16;
+ } else if (ggml_is_quantized(op->type)) {
+ nk0 = ne10/ggml_blck_size(op->type);
+ }
+
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+ // when rows are small, we can batch them together in a single threadgroup
+ int nrptg = 1;
+
+ // TODO: relax this constraint in the future
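+ // batch rows only when neither src1 nor dst is quantized (block size 1)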
+ if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
+ if (nth > nk0) {
+ nrptg = (nth + nk0 - 1)/nk0;
+ nth = nk0;
+
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nrptg--;
+ }
+ }
+ }
+
+ nth = std::min<int>(nth, nk0);
+
+ ggml_metal_kargs_cpy args = {
+ /*.nk0 =*/ nk0,
+ /*.ne00 =*/ ne10,
+ /*.ne01 =*/ ne11,
+ /*.ne02 =*/ ne12,
+ /*.ne03 =*/ ne13,
+ /*.nb00 =*/ nb10,
+ /*.nb01 =*/ nb11,
+ /*.nb02 =*/ nb12,
+ /*.nb03 =*/ nb13,
+ /*.ne0 =*/ ne10,
+ /*.ne1 =*/ ne11,
+ /*.ne2 =*/ ne12,
+ /*.ne3 =*/ ne13,
+ /*.nb0 =*/ ggml_element_size(op),
+ /*.nb1 =*/ pnb1,
+ /*.nb2 =*/ pnb2,
+ /*.nb3 =*/ pnb3,
+ };
+
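+ // threadgroups along the row: ceil(nk0/nth) when each threadgroup processes a single row, 1 when rows are batched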
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
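+ // shift the dst buffer by the view offset so the kernel writes into the selected region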
+ bid_dst.offs += offs;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
+
+ return 1;
+}
+
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);