]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda: add SET operation support (#16804)
authorYaelGitAccount <redacted>
Tue, 28 Oct 2025 19:10:28 +0000 (21:10 +0200)
committerGitHub <redacted>
Tue, 28 Oct 2025 19:10:28 +0000 (20:10 +0100)
* feat(cuda): add GGML_OP_SET support

Implement CUDA kernel for SET operation with f32 support.

All tests passing (14598/14598).

* cuda(set): add I32 support; keep F32

* refactor(cuda): use ggml_cuda_cpy to unify SET operator logic and remove code duplication

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update ggml/src/ggml-cuda/set.cu

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Sigbjørn Skjæret <redacted>
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/set.cu [new file with mode: 0644]
ggml/src/ggml-cuda/set.cuh [new file with mode: 0644]

index 94ab1ec0f5a908bd5d1ae7469eb2e09f8f600345..be505748af5a4287379404b64ad05ec645957e99 100644 (file)
@@ -50,6 +50,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml.h"
@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SET_ROWS:
             ggml_cuda_op_set_rows(ctx, dst);
             break;
+        case GGML_OP_SET:
+            ggml_cuda_op_set(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                        op->src[0]->type == GGML_TYPE_F32 &&
                        (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
             } break;
+        case GGML_OP_SET:
+            {
+                const ggml_type t = op->type;
+                return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
+                    t == op->src[0]->type &&
+                    t == op->src[1]->type;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
new file mode 100644 (file)
index 0000000..04bfe07
--- /dev/null
@@ -0,0 +1,39 @@
+#include "set.cuh"
+#include "cpy.cuh"
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
+    GGML_ASSERT(src1->type == src0->type);
+    GGML_ASSERT(dst ->type == src0->type);
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const size_t nb1    = ((int32_t *) dst->op_params)[0];
+    const size_t nb2    = ((int32_t *) dst->op_params)[1];
+    const size_t nb3    = ((int32_t *) dst->op_params)[2];
+    const size_t offset = ((int32_t *) dst->op_params)[3];
+    const bool   inplace= (bool)     ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        ggml_cuda_cpy(ctx, src0, dst);
+    }
+
+    ggml_tensor dst_view = *dst;
+    dst_view.data  = (void *)((char *)dst->data + offset);
+    dst_view.ne[0] = src1->ne[0];
+    dst_view.ne[1] = src1->ne[1];
+    dst_view.ne[2] = src1->ne[2];
+    dst_view.ne[3] = src1->ne[3];
+
+    dst_view.nb[0] = ggml_element_size(dst);
+    dst_view.nb[1] = nb1;
+    dst_view.nb[2] = nb2;
+    dst_view.nb[3] = nb3;
+
+    ggml_cuda_cpy(ctx, src1, &dst_view);
+}
diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh
new file mode 100644 (file)
index 0000000..dd09529
--- /dev/null
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_BLOCK_SIZE 256
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);