]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sycl: add REPEAT_BACK operation support (#16734)
authorshani-f <redacted>
Mon, 27 Oct 2025 01:19:50 +0000 (03:19 +0200)
committerGitHub <redacted>
Mon, 27 Oct 2025 01:19:50 +0000 (09:19 +0800)
* SYCL repeat_back v1 — add core op + switch case

* Implement repeat_back SYCL operation and minor fixes

* Update ggml/src/ggml-sycl/repeat_back.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update ggml/src/ggml-sycl/repeat_back.hpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Sigbjørn Skjæret <redacted>
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/repeat_back.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/repeat_back.hpp [new file with mode: 0644]

index b695ba051b0257cd10b6f5ecbe7b6d918791cf2c..e6bcc596a4a4441df816d166eb2ac99da70d539f 100644 (file)
@@ -48,6 +48,7 @@
 #include "ggml-sycl/set.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/repeat_back.hpp"
 #include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_repeat_back(ctx, dst);
+}
 
 static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_sycl_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_sycl_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_sycl_get_rows(ctx, dst);
             break;
@@ -4516,6 +4524,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             }
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type == GGML_TYPE_F32;
+            }
         case GGML_OP_DUP:
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
diff --git a/ggml/src/ggml-sycl/repeat_back.cpp b/ggml/src/ggml-sycl/repeat_back.cpp
new file mode 100644 (file)
index 0000000..abcd4ce
--- /dev/null
@@ -0,0 +1,56 @@
+#include "repeat_back.hpp"
+
+#include "common.hpp"
+
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const float * src0_dd = (const float *) dst->src[0]->data;
+    float *       dst_dd  = (float *) dst->data;
+
+    const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
+    const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
+                  ne03 = dst->src[0]->ne[3];
+
+    const int nr0 = (int) (ne00 / ne0);
+    const int nr1 = (int) (ne01 / ne1);
+    const int nr2 = (int) (ne02 / ne2);
+    const int nr3 = (int) (ne03 / ne3);
+
+    const size_t total      = ne0 * ne1 * ne2 * ne3;
+    const int    BLOCK_SIZE = 256;
+    const int    num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    queue_ptr stream = ctx.stream();
+
+    stream->parallel_for(
+        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            const size_t i = item_ct1.get_global_linear_id();
+            if (i >= total) {
+                return;
+            }
+
+            const int i0 = i % ne0;
+            const int i1 = (i / ne0) % ne1;
+            const int i2 = (i / (ne0 * ne1)) % ne2;
+            const int i3 = i / (ne0 * ne1 * ne2);
+
+            float acc = 0.0f;
+
+            for (int j3 = 0; j3 < nr3; ++j3) {
+                for (int j2 = 0; j2 < nr2; ++j2) {
+                    for (int j1 = 0; j1 < nr1; ++j1) {
+                        for (int j0 = 0; j0 < nr0; ++j0) {
+                            acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
+                                           (i3 + j3 * ne3) * ne00 * ne01 * ne02];
+                        }
+                    }
+                }
+            }
+
+            dst_dd[i] = acc;
+        });
+}
diff --git a/ggml/src/ggml-sycl/repeat_back.hpp b/ggml/src/ggml-sycl/repeat_back.hpp
new file mode 100644 (file)
index 0000000..17a87f3
--- /dev/null
@@ -0,0 +1,8 @@
+#ifndef GGML_SYCL_REPEAT_BACK_HPP
+#define GGML_SYCL_REPEAT_BACK_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif  // GGML_SYCL_REPEAT_BACK_HPP