]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : support GGML_OP_SET (llama/19548)
authorGeorgi Gerganov <redacted>
Fri, 13 Feb 2026 05:34:52 +0000 (07:34 +0200)
committerGeorgi Gerganov <redacted>
Sun, 15 Feb 2026 19:44:37 +0000 (21:44 +0200)
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h

index 4ea0bfb94da53420be6774a06ec6b94e71d1826f..b4ca9c5dd6fb6a062cfe28239cedbc92ee9883ea 100644 (file)
@@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return has_simdgroup_reduction;
+        case GGML_OP_SET:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
index 20880d9551e2043e31e7340ffb93635dc42e09ef..c04e9fc7ffa61d2e38ea4280315eee052da691f0 100644 (file)
@@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
             } break;
+        case GGML_OP_SET:
+            {
+                n_fuse = ggml_metal_op_set(ctx, idx);
+            } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -1609,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+    const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+    const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+    const size_t offs = ((const int32_t *) op->op_params)[3];
+
+    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+    if (!inplace) {
+        // run a separete kernel to cpy src->dst
+        // not sure how to avoid this
+        // TODO: make a simpler cpy_bytes kernel
+
+        //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
+        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+        ggml_metal_kargs_cpy args = {
+            /*.nk0  =*/ ne00,
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.nb0  =*/ nb0,
+            /*.nb1  =*/ nb1,
+            /*.nb2  =*/ nb2,
+            /*.nb3  =*/ nb3,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+        ggml_metal_op_concurrency_reset(ctx);
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
+
+    GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
+
+    int64_t nk0 = ne10;
+    if (ggml_is_quantized(op->src[1]->type)) {
+        nk0 = ne10/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne10/ggml_blck_size(op->type);
+    }
+
+    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    // when rows are small, we can batch them together in a single threadgroup
+    int nrptg = 1;
+
+    // TODO: relax this constraint in the future
+    if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
+
+            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+                nrptg--;
+            }
+        }
+    }
+
+    nth = std::min<int>(nth, nk0);
+
+    ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
+        /*.nb00 =*/ nb10,
+        /*.nb01 =*/ nb11,
+        /*.nb02 =*/ nb12,
+        /*.nb03 =*/ nb13,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
+        /*.nb0  =*/ ggml_element_size(op),
+        /*.nb1  =*/ pnb1,
+        /*.nb2  =*/ pnb2,
+        /*.nb3  =*/ pnb3,
+    };
+
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
+    bid_dst.offs += offs;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 29456d70d5ef16205c104066258b920811ae0ae4..f3e38c7aa9de287316a7c429bb17a718e212e11c 100644 (file)
@@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);