]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-webgpu: Add supports for `GGML_OP_REPEAT` (llama/20230)
authorMasashi Yoshimura <redacted>
Wed, 11 Mar 2026 21:40:36 +0000 (06:40 +0900)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* Add GGML_OP_REPEAT to webgpu backend.

* Add i16 support for GGML_OP_REPEAT.

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl [new file with mode: 0644]

index 3c38b1a230ffad526bc11c014007ca1086e80303..3d7e59fddf322d5ccfee1367edc01f5fb8fb4382 100644 (file)
@@ -198,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash {
     }
 };
 
+/** Repeat **/
+
+struct ggml_webgpu_repeat_pipeline_key {
+    int type;
+
+    bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
+};
+
+struct ggml_webgpu_repeat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        return seed;
+    }
+};
+
 /** Binary **/
 
 struct ggml_webgpu_binary_pipeline_key {
@@ -431,6 +447,8 @@ class ggml_webgpu_shader_lib {
         binary_pipelines;                                              // type/op/inplace/overlap
     std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
         concat_pipelines;                                              // type
+    std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
+        repeat_pipelines;                                              // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
     std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -1147,7 +1165,7 @@ class ggml_webgpu_shader_lib {
         }
 
         std::vector<std::string> defines;
-        std::string variant = "concat";
+        std::string              variant = "concat";
 
         switch (key.type) {
             case GGML_TYPE_F32:
@@ -1164,15 +1182,56 @@ class ggml_webgpu_shader_lib {
 
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-        auto processed = preprocessor.preprocess(wgsl_concat, defines);
-        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-        decisions->wg_size = context.max_wg_size;
+        auto processed           = preprocessor.preprocess(wgsl_concat, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context = decisions;
-        concat_pipelines[key] = pipeline;
+        pipeline.context         = decisions;
+        concat_pipelines[key]    = pipeline;
         return concat_pipelines[key];
     }
 
+    webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_repeat_pipeline_key key = {
+            .type = context.dst->type,
+        };
+
+        auto it = repeat_pipelines.find(key);
+        if (it != repeat_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "repeat";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("TYPE_I32");
+                variant += "_i32";
+                break;
+            case GGML_TYPE_I16:
+                defines.push_back("TYPE_I16");
+                variant += "_i16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for repeat shader");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_repeat, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        repeat_pipelines[key]    = pipeline;
+        return repeat_pipelines[key];
+    }
+
     webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
         const bool has_mask  = context.src3 != nullptr;
         const bool has_sinks = context.src4 != nullptr;
index ccc34cb153f70e17fa7c3c8b140022747e36e62f..128b7dc3de8abd5380e32261fc41c5c53dee271b 100644 (file)
@@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
+static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    std::vector<uint32_t> params = { ne,
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
+                                                 ggml_type_size(src0->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->ne[0]),
+                                     (uint32_t) (src0->ne[1]),
+                                     (uint32_t) (src0->ne[2]),
+                                     (uint32_t) (src0->ne[3]),
+                                     (uint32_t) (dst->ne[0]),
+                                     (uint32_t) (dst->ne[1]),
+                                     (uint32_t) (dst->ne[2]) };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
 static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     int inplace = ggml_webgpu_tensor_equal(src, dst);
 
@@ -2158,6 +2200,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
             return ggml_webgpu_binary_op(ctx, src0, src1, node);
         case GGML_OP_CONCAT:
             return ggml_webgpu_concat(ctx, src0, src1, node);
+        case GGML_OP_REPEAT:
+            return ggml_webgpu_repeat(ctx, src0, node);
         case GGML_OP_RMS_NORM:
             return ggml_webgpu_rms_norm(ctx, src0, node);
         case GGML_OP_ROPE:
@@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
         /* .iface = */ {
                         /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
                         /* .alloc_buffer     = */
-            ggml_backend_webgpu_buffer_type_alloc_buffer,  /* .get_alignment    = */
-            ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size     = */
-            ggml_backend_webgpu_buffer_type_get_max_size,  /* .get_alloc_size   = */
-            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,                // defaults to false
+            ggml_backend_webgpu_buffer_type_alloc_buffer,                                    /* .get_alignment    = */
+            ggml_backend_webgpu_buffer_type_get_alignment,                                   /* .get_max_size     = */
+            ggml_backend_webgpu_buffer_type_get_max_size,                                    /* .get_alloc_size   = */
+            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,  // defaults to false
         },
         /* .device  = */
         dev,
@@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_CONCAT:
             supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
             break;
+        case GGML_OP_REPEAT:
+            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
+            break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl
new file mode 100644 (file)
index 0000000..6e2a1a8
--- /dev/null
@@ -0,0 +1,67 @@
+enable f16;
+
+struct Params {
+    ne: u32,
+
+    offset_src0: u32,
+    offset_dst: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
+
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+    a_ne3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+};
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_I32
+#define DataType i32
+#endif
+#ifdef TYPE_I16
+// same size (16-bit) is sufficient for repeat
+#define DataType f16
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        var i = gid.x;
+        let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+        i = i % (params.ne2 * params.ne1 * params.ne0);
+        let i2 = i / (params.ne1 * params.ne0);
+        i = i % (params.ne1 * params.ne0);
+        let i1 = i / params.ne0;
+        let i0 = i % params.ne0;
+
+        let a_i0 = i0 % params.a_ne0;
+        let a_i1 = i1 % params.a_ne1;
+        let a_i2 = i2 % params.a_ne2;
+        let a_i3 = i3 % params.a_ne3;
+
+        let a_index = a_i0 * params.stride_src0_0 +
+                           a_i1 * params.stride_src0_1 +
+                           a_i2 * params.stride_src0_2 +
+                           a_i3 * params.stride_src0_3;
+
+        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];
+    }
+}