]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : support GGML_OP_SET (llama/19548)
authorGeorgi Gerganov <redacted>
Fri, 13 Feb 2026 05:34:52 +0000 (07:34 +0200)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
tests/test-backend-ops.cpp

index 4ea0bfb94da53420be6774a06ec6b94e71d1826f..b4ca9c5dd6fb6a062cfe28239cedbc92ee9883ea 100644 (file)
@@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return has_simdgroup_reduction;
+        case GGML_OP_SET:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
index 20880d9551e2043e31e7340ffb93635dc42e09ef..c04e9fc7ffa61d2e38ea4280315eee052da691f0 100644 (file)
@@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
             } break;
+        case GGML_OP_SET:
+            {
+                n_fuse = ggml_metal_op_set(ctx, idx);
+            } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -1609,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+    const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+    const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+    const size_t offs = ((const int32_t *) op->op_params)[3];
+
+    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+    if (!inplace) {
+        // run a separete kernel to cpy src->dst
+        // not sure how to avoid this
+        // TODO: make a simpler cpy_bytes kernel
+
+        //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
+        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+        ggml_metal_kargs_cpy args = {
+            /*.nk0  =*/ ne00,
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.nb0  =*/ nb0,
+            /*.nb1  =*/ nb1,
+            /*.nb2  =*/ nb2,
+            /*.nb3  =*/ nb3,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+        ggml_metal_op_concurrency_reset(ctx);
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
+
+    GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
+
+    int64_t nk0 = ne10;
+    if (ggml_is_quantized(op->src[1]->type)) {
+        nk0 = ne10/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne10/ggml_blck_size(op->type);
+    }
+
+    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    // when rows are small, we can batch them together in a single threadgroup
+    int nrptg = 1;
+
+    // TODO: relax this constraint in the future
+    if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
+
+            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+                nrptg--;
+            }
+        }
+    }
+
+    nth = std::min<int>(nth, nk0);
+
+    ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
+        /*.nb00 =*/ nb10,
+        /*.nb01 =*/ nb11,
+        /*.nb02 =*/ nb12,
+        /*.nb03 =*/ nb13,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
+        /*.nb0  =*/ ggml_element_size(op),
+        /*.nb1  =*/ pnb1,
+        /*.nb2  =*/ pnb2,
+        /*.nb3  =*/ pnb3,
+    };
+
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
+    bid_dst.offs += offs;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 29456d70d5ef16205c104066258b920811ae0ae4..f3e38c7aa9de287316a7c429bb17a718e212e11c 100644 (file)
@@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);
index 222b93584170b7818c9174f4dbbd3bd8564f5bfb..8816f6963fd4fcad0f720aca01f9cef17284a8b9 100644 (file)
@@ -2786,9 +2786,10 @@ struct test_set : public test_case {
     const ggml_type type_dst;
     const std::array<int64_t, 4> ne;
     const int dim;
+    const bool inplace;
 
     std::string vars() override {
-        return VARS_TO_STR4(type_src, type_dst, ne, dim);
+        return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace);
     }
 
     size_t op_size(ggml_tensor * t) override {
@@ -2796,8 +2797,8 @@ struct test_set : public test_case {
     }
 
     test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
-        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
+            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false)
+        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2808,7 +2809,7 @@ struct test_set : public test_case {
         for (int i = 0; i < dim; ++i) {
             ne_dst[i] *= 2;
         }
-        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
         ggml_set_param(dst);
         ggml_set_name(dst, "dst");
 
@@ -2816,9 +2817,16 @@ struct test_set : public test_case {
         for (int i = 0; i < dim; ++i) {
             offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
         }
-        ggml_tensor * out = ggml_set(ctx, dst, src,
-            // The backward pass requires setting a contiguous region:
-            src->nb[1], src->nb[2], src->nb[3], offset);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_set_inplace(ctx, dst, src,
+                    // The backward pass requires setting a contiguous region:
+                    src->nb[1], src->nb[2], src->nb[3], offset);
+        } else {
+            out = ggml_set(ctx, dst, src,
+                    // The backward pass requires setting a contiguous region:
+                    src->nb[1], src->nb[2], src->nb[3], offset);
+        }
         ggml_set_name(out, "out");
 
         return out;
@@ -7428,11 +7436,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {1, 2, 0, 3}));
 
     for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
-        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
+        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false));
+        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true));
     }
 
     for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
-        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false));
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true));
     }
 
     // same-type copy