]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ggml_repeat_4d (#13824)
authorXuan-Son Nguyen <redacted>
Tue, 27 May 2025 13:53:55 +0000 (15:53 +0200)
committerGitHub <redacted>
Tue, 27 May 2025 13:53:55 +0000 (15:53 +0200)
ggml/include/ggml.h
ggml/src/ggml.c

index bff7dea3a539bc7b2dde44c0eec6a75a16aa6b42..21ee8a1410a4e1c3fd607196f913ee1724ce4913 100644 (file)
@@ -935,6 +935,15 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // repeat a to the specified shape
+    GGML_API struct ggml_tensor * ggml_repeat_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+                       int64_t    ne0,
+                       int64_t    ne1,
+                       int64_t    ne2,
+                       int64_t    ne3);
+
     // sums repetitions in a into shape of b
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
index 57d3e39adf7581d4f413d9f54b9ee4b0e06b6a7a..fb0d379dc8d68284c168e0c091395caad686fce9 100644 (file)
@@ -2312,6 +2312,26 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+struct ggml_tensor * ggml_repeat_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+    const bool can_repeat = ggml_is_empty(a) || (
+        (ne0 % a->ne[0] == 0) &&
+        (ne1 % a->ne[1] == 0) &&
+        (ne2 % a->ne[2] == 0) &&
+        (ne3 % a->ne[3] == 0)
+    );
+    GGML_ASSERT(can_repeat);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+
+    result->op     = GGML_OP_REPEAT;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_repeat_back
 
 struct ggml_tensor * ggml_repeat_back(