git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-webgpu: Add supports for `GGML_OP_REPEAT` (#20230)
authorMasashi Yoshimura <redacted>
Wed, 11 Mar 2026 21:40:36 +0000 (06:40 +0900)
committerGitHub <redacted>
Wed, 11 Mar 2026 21:40:36 +0000 (14:40 -0700)
* Add GGML_OP_REPEAT to webgpu backend.

* Add i16 support for GGML_OP_REPEAT.

docs/ops.md
docs/ops/WebGPU.csv
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl [new file with mode: 0644]

index 37329d56a8d2483be6c2ac57f78cd06bce5dba4b..f914c2b7d20ddfb2621401c2e9270a1ca8b7a330 100644 (file)
@@ -80,7 +80,7 @@ Legend:
 |                          POOL_2D | โŒ | ๐ŸŸก | โœ… | โœ… | โœ… | โŒ | โœ… | โœ… | โŒ | โŒ | โŒ |
 |                            REGLU | โŒ | โœ… | โœ… | โœ… | ๐ŸŸก | โœ… | โœ… | ๐ŸŸก | โœ… | โŒ | โŒ |
 |                             RELU | โŒ | โœ… | โœ… | ๐ŸŸก | ๐ŸŸก | ๐ŸŸก | โœ… | ๐ŸŸก | โœ… | โŒ | โŒ |
-|                           REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ |
+|                           REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                      REPEAT_BACK | โŒ | โŒ | โœ… | โœ… | โŒ | โŒ | โœ… | โœ… | โŒ | โŒ | โŒ |
 |                         RMS_NORM | โŒ | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โŒ | โŒ |
 |                    RMS_NORM_BACK | โŒ | โŒ | โœ… | โœ… | โŒ | โŒ | โœ… | โœ… | โŒ | โŒ | โŒ |
index 9e081e7605fc85c0fc66fcb20921ad9a2d8939a1..b7761b9dd3f91ac2f3d01a663d71dee373980838 100644 (file)
 "WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
 "WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU"
 "WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU"
 "WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU"
index 3c38b1a230ffad526bc11c014007ca1086e80303..3d7e59fddf322d5ccfee1367edc01f5fb8fb4382 100644 (file)
@@ -198,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash {
     }
 };
 
+/** Repeat **/
+
+struct ggml_webgpu_repeat_pipeline_key {
+    int type;
+
+    bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
+};
+
+struct ggml_webgpu_repeat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        return seed;
+    }
+};
+
 /** Binary **/
 
 struct ggml_webgpu_binary_pipeline_key {
@@ -431,6 +447,8 @@ class ggml_webgpu_shader_lib {
         binary_pipelines;                                              // type/op/inplace/overlap
     std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
         concat_pipelines;                                              // type
+    std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
+        repeat_pipelines;                                              // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
     std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -1147,7 +1165,7 @@ class ggml_webgpu_shader_lib {
         }
 
         std::vector<std::string> defines;
-        std::string variant = "concat";
+        std::string              variant = "concat";
 
         switch (key.type) {
             case GGML_TYPE_F32:
@@ -1164,15 +1182,56 @@ class ggml_webgpu_shader_lib {
 
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-        auto processed = preprocessor.preprocess(wgsl_concat, defines);
-        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-        decisions->wg_size = context.max_wg_size;
+        auto processed           = preprocessor.preprocess(wgsl_concat, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context = decisions;
-        concat_pipelines[key] = pipeline;
+        pipeline.context         = decisions;
+        concat_pipelines[key]    = pipeline;
         return concat_pipelines[key];
     }
 
+    webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_repeat_pipeline_key key = {
+            .type = context.dst->type,
+        };
+
+        auto it = repeat_pipelines.find(key);
+        if (it != repeat_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "repeat";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("TYPE_I32");
+                variant += "_i32";
+                break;
+            case GGML_TYPE_I16:
+                defines.push_back("TYPE_I16");
+                variant += "_i16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for repeat shader");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_repeat, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        repeat_pipelines[key]    = pipeline;
+        return repeat_pipelines[key];
+    }
+
     webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
         const bool has_mask  = context.src3 != nullptr;
         const bool has_sinks = context.src4 != nullptr;
index ccc34cb153f70e17fa7c3c8b140022747e36e62f..128b7dc3de8abd5380e32261fc41c5c53dee271b 100644 (file)
@@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
+static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    std::vector<uint32_t> params = { ne,
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
+                                                 ggml_type_size(src0->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->ne[0]),
+                                     (uint32_t) (src0->ne[1]),
+                                     (uint32_t) (src0->ne[2]),
+                                     (uint32_t) (src0->ne[3]),
+                                     (uint32_t) (dst->ne[0]),
+                                     (uint32_t) (dst->ne[1]),
+                                     (uint32_t) (dst->ne[2]) };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
 static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     int inplace = ggml_webgpu_tensor_equal(src, dst);
 
@@ -2158,6 +2200,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
             return ggml_webgpu_binary_op(ctx, src0, src1, node);
         case GGML_OP_CONCAT:
             return ggml_webgpu_concat(ctx, src0, src1, node);
+        case GGML_OP_REPEAT:
+            return ggml_webgpu_repeat(ctx, src0, node);
         case GGML_OP_RMS_NORM:
             return ggml_webgpu_rms_norm(ctx, src0, node);
         case GGML_OP_ROPE:
@@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
         /* .iface = */ {
                         /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
                         /* .alloc_buffer     = */
-            ggml_backend_webgpu_buffer_type_alloc_buffer,  /* .get_alignment    = */
-            ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size     = */
-            ggml_backend_webgpu_buffer_type_get_max_size,  /* .get_alloc_size   = */
-            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,                // defaults to false
+            ggml_backend_webgpu_buffer_type_alloc_buffer,                                    /* .get_alignment    = */
+            ggml_backend_webgpu_buffer_type_get_alignment,                                   /* .get_max_size     = */
+            ggml_backend_webgpu_buffer_type_get_max_size,                                    /* .get_alloc_size   = */
+            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,  // defaults to false
         },
         /* .device  = */
         dev,
@@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_CONCAT:
             supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
             break;
+        case GGML_OP_REPEAT:
+            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
+            break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl
new file mode 100644 (file)
index 0000000..6e2a1a8
--- /dev/null
@@ -0,0 +1,67 @@
+enable f16;
+
+struct Params {
+    ne: u32,
+
+    offset_src0: u32,
+    offset_dst: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
+
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+    a_ne3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+};
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_I32
+#define DataType i32
+#endif
+#ifdef TYPE_I16
+// same size (16-bit) is sufficient for repeat
+#define DataType f16
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        var i = gid.x;
+        let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+        i = i % (params.ne2 * params.ne1 * params.ne0);
+        let i2 = i / (params.ne1 * params.ne0);
+        i = i % (params.ne1 * params.ne0);
+        let i1 = i / params.ne0;
+        let i0 = i % params.ne0;
+
+        let a_i0 = i0 % params.a_ne0;
+        let a_i1 = i1 % params.a_ne1;
+        let a_i2 = i2 % params.a_ne2;
+        let a_i3 = i3 % params.a_ne3;
+
+        let a_index = a_i0 * params.stride_src0_0 +
+                           a_i1 * params.stride_src0_1 +
+                           a_i2 * params.stride_src0_2 +
+                           a_i3 * params.stride_src0_3;
+
+        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];
+    }
+}