ggml webgpu: support for rope,div,sub,glu,scale,cont operators (llama/16187)
author    Reese Levine <redacted>
          Tue, 30 Sep 2025 16:57:51 +0000 (09:57 -0700)
committer Georgi Gerganov <redacted>
          Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
* Work on rope

* Simplify inplace operation generation and combine mul/add generation

* Work on rope variants

* implement neox rope

* rope complete

* Add sub,div,glu operators

* implement scale op

* Update cpy shader to handle cont/more types

* formatting

* Update test vars printing for rope,rms_norm

* Avoid ROPE hardcoded constants

* Add TODO to change ROPE constants to enum

Co-authored-by: Georgi Gerganov <redacted>
* fix TODO comment

---------

Co-authored-by: Georgi Gerganov <redacted>
16 files changed:
ggml/include/ggml.h
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl [new file with mode: 0644]

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 36b23dc6d0d827407ac667119b52c656eab2a6d4..5028a9cebf260b657f199fe80e3e2c7dcb630cd3 100644 (file)
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
+// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
+#define GGML_ROPE_TYPE_NORMAL 0
 #define GGML_ROPE_TYPE_NEOX   2
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
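The rope type macros above are bit flags rather than consecutive values, which is why the TODO suggests converting them to an enum. A minimal sketch of how the variants are typically distinguished from the `mode` value stored in `dst->op_params[2]` (the helper names are hypothetical; the bit tests follow ggml's convention):

    #include <cstdint>

    // Hypothetical helpers; the constants mirror the defines above.
    constexpr int32_t ROPE_TYPE_NORMAL = 0;
    constexpr int32_t ROPE_TYPE_NEOX   = 2;
    constexpr int32_t ROPE_TYPE_MROPE  = 8;
    constexpr int32_t ROPE_TYPE_VISION = 24;

    static bool rope_is_neox(int32_t mode)   { return (mode & ROPE_TYPE_NEOX) != 0; }
    static bool rope_is_mrope(int32_t mode)  { return (mode & ROPE_TYPE_MROPE) != 0; }
    static bool rope_is_vision(int32_t mode) { return mode == ROPE_TYPE_VISION; }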
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index cee4b08366d79c387781a94b2e1bd5d7e87f8834..93200a4d29f53bc5c5f7043b0ec246f8d03f99ba 100644 (file)
@@ -130,13 +130,15 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline get_rows_pipeline[30];
     wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
-    wgpu::ComputePipeline cpy_pipeline;
-    wgpu::ComputePipeline add_pipeline[2];
-    wgpu::ComputePipeline add_ip_pipeline[2];
-    wgpu::ComputePipeline mul_pipeline[2];
-    wgpu::ComputePipeline mul_ip_pipeline[2];
-    wgpu::ComputePipeline rms_norm_pipeline;
-    wgpu::ComputePipeline rms_norm_ip_pipeline;
+    wgpu::ComputePipeline cpy_pipeline[2][2];      // src type, dst type
+    wgpu::ComputePipeline add_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline sub_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline mul_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline div_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline rms_norm_pipeline[2];    // inplace
+    wgpu::ComputePipeline rope_pipeline[2][2][2];  // type, ff, inplace
+    wgpu::ComputePipeline glu_pipeline[7][2][2];   // glu-op, type, split
+    wgpu::ComputePipeline scale_pipeline[2];       // inplace
 
     size_t memset_bytes_per_thread;
 
@@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
         (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
         (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Logical shape — same for both tensors even if permuted
-        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
+        // Logical shapes
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
@@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
 
     size_t   max_wg_size = ctx->max_wg_size_x;
     uint32_t wg_x        = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
 }
 
 static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
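Two idioms recur throughout this file: byte strides from `tensor->nb[]` are divided by `ggml_type_size()` so the shader can index a typed array directly, and dispatch sizes are rounded up with ceil division. A standalone sketch with assumed values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // nb[] is in bytes; WGSL indexes array<f32>/array<f16> in elements,
        // so the host divides by the element size before uploading params.
        const uint32_t nb1 = 512, type_size = 4;    // assumed f32 row stride
        const uint32_t stride1 = nb1 / type_size;   // 128 elements

        // One invocation per element, rounded up to whole workgroups.
        const uint32_t ne = 1000, max_wg_size = 256;
        const uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;  // 4

        std::printf("stride1=%u wg_x=%u\n", stride1, wg_x);
        return 0;
    }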
@@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context &        ctx,
                                   ggml_tensor *           src1,
                                   ggml_tensor *           dst,
                                   wgpu::ComputePipeline & pipeline,
-                                  bool                    in_place) {
+                                  bool                    inplace) {
     std::vector<uint32_t> params = {
         (uint32_t) ggml_nelements(dst),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context &        ctx,
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
     };
-    if (!in_place) {
+    if (!inplace) {
         entries.push_back({ .binding = 2,
                             .buffer  = ggml_webgpu_tensor_buf(dst),
                             .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
@@ -691,30 +695,23 @@ static void ggml_webgpu_binary_op(webgpu_context &        ctx,
 }
 
 static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool in_place = ggml_webgpu_tensor_equal(src, dst);
-
-    uint32_t eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
+    int inplace = ggml_webgpu_tensor_equal(src, dst);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        (uint32_t) src->ne[3],
+        *(uint32_t *) dst->op_params  // epsilon, treated as f32 in the shader
     };
-    if (!in_place) {
-        params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
-    }
-    params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
-    params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
-    params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
-    if (!in_place) {
-        params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
-        params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
-        params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
-    }
-    params.push_back((uint32_t) src->ne[0]);
-    params.push_back((uint32_t) src->ne[1]);
-    params.push_back((uint32_t) src->ne[2]);
-    params.push_back((uint32_t) src->ne[3]);
-    params.push_back(eps);  // epsilon, will be bitcast to float in shader
 
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
@@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
     };
-    if (!in_place) {
+    if (!inplace) {
         entries.push_back({ .binding = 1,
                             .buffer  = ggml_webgpu_tensor_buf(dst),
                             .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    wgpu::ComputePipeline pipeline;
-    if (in_place) {
-        pipeline = ctx->rms_norm_ip_pipeline;
-    } else {
-        pipeline = ctx->rms_norm_pipeline;
-    }
     size_t   max_wg_size = ctx->max_wg_size_x;
     uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
+}
+
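For reference, the operation behind this plumbing: RMS norm scales each row by the reciprocal root of its mean square plus epsilon, where epsilon arrives as the bit pattern appended to the params vector above. A plain C++ sketch of the per-row math (not the WGSL source, which this page does not show):

    #include <cmath>
    #include <cstddef>

    // Reference RMS norm over one row of n values; a sketch of what the
    // rms_norm shader computes per row.
    void rms_norm_row(const float * x, float * y, size_t n, float eps) {
        float sum_sq = 0.0f;
        for (size_t i = 0; i < n; i++) {
            sum_sq += x[i] * x[i];
        }
        const float scale = 1.0f / std::sqrt(sum_sq / n + eps);
        for (size_t i = 0; i < n; i++) {
            y[i] = x[i] * scale;  // the inplace variant writes back to x
        }
    }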
+static void ggml_webgpu_rope(webgpu_context & ctx,
+                             ggml_tensor *    src0,
+                             ggml_tensor *    src1,
+                             ggml_tensor *    src2,
+                             ggml_tensor *    dst) {
+    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
+    const int has_freq_factor = (src2 != nullptr);
+
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+    int sections[4];
+    memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
+
+    float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(src0) / 2,
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+        (uint32_t) n_dims,
+        (uint32_t) mode,
+        *(uint32_t *) &theta_scale,
+        *(uint32_t *) &attn_factor,
+        *(uint32_t *) &freq_scale,
+        *(uint32_t *) &ext_factor,
+        *(uint32_t *) &corr_dims[0],
+        *(uint32_t *) &corr_dims[1],
+        (uint32_t) sections[0],
+        (uint32_t) sections[1],
+        (uint32_t) sections[2],
+        (uint32_t) sections[3]
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(src1),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
+    };
+    uint32_t dst_binding = 2;
+    if (has_freq_factor) {
+        dst_binding = 3;
+        entries.push_back({ .binding = 2,
+                            .buffer  = ggml_webgpu_tensor_buf(src2),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
+    }
+    if (!inplace) {
+        entries.push_back({ .binding = dst_binding,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    wgpu::ComputePipeline pipeline    = ctx->rope_pipeline[dst->type][has_freq_factor][inplace];
+    size_t                max_wg_size = ctx->max_wg_size_x;
+    uint32_t              wg_x        = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
     ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
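The float parameters (`theta_scale`, `attn_factor`, `corr_dims`, ...) travel to the shader by bit pattern via `*(uint32_t *) &x`. A sketch of the same round-trip using `memcpy`, which sidesteps the strict-aliasing question raised by the pointer cast (standalone example with assumed values; the shader side presumably recovers the float with WGSL's bitcast, the rope shaders are not shown on this page):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Pack an f32 into a u32 uniform slot bit-for-bit.
    static uint32_t pack_f32(float f) {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return u;
    }

    int main() {
        const float freq_base   = 10000.0f;  // assumed value
        const int   n_dims      = 128;       // assumed value
        const float theta_scale = std::pow(freq_base, -2.0f / n_dims);  // as above
        std::printf("theta_scale bits: 0x%08x\n", pack_f32(theta_scale));
        return 0;
    }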
 
+static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+    const int split = (src1 != nullptr);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
+                          (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
+                          (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
+                          (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) ((int32_t *) dst->op_params)[1],  // swapped
+        *(uint32_t *) &dst->op_params[2],            // alpha, for swiglu_oai
+        *(uint32_t *) &dst->op_params[3],            // limit, for swiglu_oai
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+    };
+    uint32_t dst_binding = 1;
+    if (split) {
+        dst_binding = 2;
+        entries.push_back({ .binding = 1,
+                            .buffer  = ggml_webgpu_tensor_buf(src1),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
+    }
+    entries.push_back({ .binding = dst_binding,
+                        .buffer  = ggml_webgpu_tensor_buf(dst),
+                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+    wgpu::ComputePipeline pipeline    = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split];
+    size_t                max_wg_size = ctx->max_wg_size_x;
+    uint32_t              wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+}
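In the fused case the two gate halves share src0's innermost dimension and `op_params[1]` ("swapped") picks which half is the activation input; in the split case src1 supplies the second operand, and the stride parameters above fall back to src0's when src1 is absent. A C++ sketch of the half-selection done by the shader's `a_value`/`b_value` helpers (see glu.tmpl.wgsl further down):

    #include <cstdint>

    // Fused GLU row layout: [ a_0 .. a_{ne0-1} | b_0 .. b_{ne0-1} ].
    // 'swapped' flips which half feeds the activation (sketch of the
    // a_value/b_value helpers in the NO_SPLIT shader variant).
    struct glu_row {
        const float * row;   // 2*ne0 contiguous values
        uint32_t      ne0;
        bool          swapped;

        float a(uint32_t i) const { return row[i + (swapped ? ne0 : 0)]; }
        float b(uint32_t i) const { return row[i + (swapped ? 0 : ne0)]; }
    };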
+
+static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    int inplace = ggml_webgpu_tensor_equal(src, dst);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        *(uint32_t *) dst->op_params,     // scale
+        *(uint32_t *) &dst->op_params[1]  // bias
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
+    };
+    if (!inplace) {
+        entries.push_back({ .binding = 1,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
+}
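The two floats pulled from `dst->op_params` (commented "scale" and "bias" above) make this an element-wise dst = src * scale + bias. A minimal host-side reference of the assumed formula; the scale shader itself is not shown on this page:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Element-wise reference for GGML_OP_SCALE with bias, matching the two
    // op_params floats pushed in ggml_webgpu_scale above (sketch, not the shader).
    void scale_ref(const float * src, float * dst, size_t n, const int32_t * op_params) {
        float scale, bias;
        std::memcpy(&scale, &op_params[0], sizeof(float));
        std::memcpy(&bias,  &op_params[1], sizeof(float));
        for (size_t i = 0; i < n; i++) {
            dst[i] = src[i] * scale + bias;
        }
    }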
+
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
@@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
 
     ggml_tensor * src0 = node->src[0];
     ggml_tensor * src1 = node->src[1];
+    ggml_tensor * src2 = node->src[2];
 
     switch (node->op) {
             // no-ops
@@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_RESHAPE:
             return false;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
             ggml_webgpu_cpy(ctx, src0, node);
             break;
         case GGML_OP_SET_ROWS:
@@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
             ggml_webgpu_mul_mat(ctx, src0, src1, node);
             break;
         case GGML_OP_ADD:
-            if (ggml_webgpu_tensor_equal(src0, node)) {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
-            } else {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
+                break;
+            }
+        case GGML_OP_SUB:
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
+                break;
             }
-            break;
         case GGML_OP_MUL:
-            if (ggml_webgpu_tensor_equal(src0, node)) {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
-            } else {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
+                break;
+            }
+        case GGML_OP_DIV:
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
+                break;
             }
-            break;
         case GGML_OP_RMS_NORM:
             ggml_webgpu_rms_norm(ctx, src0, node);
             break;
+        case GGML_OP_ROPE:
+            ggml_webgpu_rope(ctx, src0, src1, src2, node);
+            break;
+        case GGML_OP_GLU:
+            ggml_webgpu_glu(ctx, src0, src1, node);
+            break;
+        case GGML_OP_SCALE:
+            ggml_webgpu_scale(ctx, src0, node);
+            break;
         default:
             return false;
     }
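Indexing the pipeline arrays directly with `node->type` works because ggml's type enum assigns F32 = 0 and F16 = 1, and `ggml_webgpu_tensor_equal` yields 0 or 1 for the inplace axis. A toy analog of the lookup:

    #include <cstdio>

    enum simple_type { TYPE_F32 = 0, TYPE_F16 = 1 };  // mirrors ggml's 0/1

    int main() {
        const char * add_pipeline[2][2] = {
            { "add_f32", "add_f32_inplace" },   // [TYPE_F32][inplace]
            { "add_f16", "add_f16_inplace" },   // [TYPE_F16][inplace]
        };
        int type    = TYPE_F16;
        int inplace = 1;  // ggml_webgpu_tensor_equal(src0, node) in the real code
        std::printf("%s\n", add_pipeline[type][inplace]);  // add_f16_inplace
        return 0;
    }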
@@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
+                                wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
+                                wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
+                                wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
+                                wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
+                                "add_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
+                                "add_f16_inplace", constants);
+}
+
+static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
                                 constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
                                 constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
-                                "add_in_place_f32", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
-                                "add_in_place_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
+                                "sub_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
+                                "sub_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
+                                "mul_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
+                                "mul_f16_inplace", constants);
+}
+
+static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
                                 constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
                                 constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
-                                "mul_in_place_f32", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
-                                "mul_in_place_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
+                                "div_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
+                                "div_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
                                 constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
-                                "rms_norm_in_place", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
+                                "rms_norm_inplace", constants);
+}
+
+static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
+                                "rope_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
+                                wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff,
+                                "rope_f32_ff", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1],
+                                wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16,
+                                "rope_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1],
+                                wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff,
+                                "rope_f16_ff", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1],
+                                wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
+}
+
+static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    // reglu
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
+                                wgsl_reglu_f32, "reglu_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0],
+                                wgsl_reglu_f16, "reglu_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1],
+                                wgsl_reglu_f32_split, "reglu_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1],
+                                wgsl_reglu_f16_split, "reglu_f16_split", constants);
+    // geglu
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0],
+                                wgsl_geglu_f32, "geglu_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0],
+                                wgsl_geglu_f16, "geglu_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1],
+                                wgsl_geglu_f32_split, "geglu_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1],
+                                wgsl_geglu_f16_split, "geglu_f16_split", constants);
+    // swiglu
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0],
+                                wgsl_swiglu_f32, "swiglu_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0],
+                                wgsl_swiglu_f16, "swiglu_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1],
+                                wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1],
+                                wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
+    // swiglu_oai
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0],
+                                wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1],
+                                wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
+    // geglu_erf
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0],
+                                wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0],
+                                wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1],
+                                wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1],
+                                wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
+    // geglu_quick
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0],
+                                wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0],
+                                wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1],
+                                wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1],
+                                wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
+}
+
+static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
+                                "scale_f32_inplace", constants);
 }
 
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
 
     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
+
     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
     if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
         (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
@@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             supports_op = true;
             break;
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
-            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
-                          (op->src[1]->type == op->type);
+        case GGML_OP_DIV:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
+                          (src1->type == op->type);
             break;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+            break;
         case GGML_OP_SET_ROWS:
             supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
             break;
         case GGML_OP_GET_ROWS:
-            if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
-                op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
+            if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
+                ggml_webgpu_supported_qtype(src0->type)) {
                 supports_op = (op->type == GGML_TYPE_F32);
             }
             break;
         case GGML_OP_MUL_MAT:
             {
-                switch (op->src[1]->type) {
+                switch (src1->type) {
                     case GGML_TYPE_F16:
-                        supports_op = (op->src[0]->type == GGML_TYPE_F16);
+                        supports_op |= (src0->type == GGML_TYPE_F16);
                         break;
                     case GGML_TYPE_F32:
-                        switch (op->src[0]->type) {
+                        switch (src0->type) {
                             case GGML_TYPE_F32:
                             case GGML_TYPE_F16:
                             case GGML_TYPE_Q4_0:
@@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                 break;
             }
         case GGML_OP_RMS_NORM:
-            supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
+            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+            break;
+        case GGML_OP_ROPE:
+            supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+                    break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    supports_op = op->type == GGML_TYPE_F32;
+                    break;
+                default:
+                    break;
+            }
+            break;
+        case GGML_OP_SCALE:
+            supports_op = op->type == GGML_TYPE_F32;
             break;
         default:
             break;
@@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_get_rows_pipeline(ctx);
     ggml_webgpu_init_cpy_pipeline(ctx);
     ggml_webgpu_init_add_pipeline(ctx);
+    ggml_webgpu_init_sub_pipeline(ctx);
     ggml_webgpu_init_mul_pipeline(ctx);
+    ggml_webgpu_init_div_pipeline(ctx);
     ggml_webgpu_init_rms_norm_pipeline(ctx);
+    ggml_webgpu_init_rope_pipeline(ctx);
+    ggml_webgpu_init_glu_pipeline(ctx);
+    ggml_webgpu_init_scale_pipeline(ctx);
 
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl
deleted file mode 100644 (file)
index f261cbb..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    }
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl
deleted file mode 100644 (file)
index 903f7bd..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    }
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl
new file mode 100644 (file)
index 0000000..1ce4d83
--- /dev/null
@@ -0,0 +1,188 @@
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_NAME": "add_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "+"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "add_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "+"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "add_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "+"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "add_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "+"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "*"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "*"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "*"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "*"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "-"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "-"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "-"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "-"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "/"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "/"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "/"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "/"
+    },
+    "DECLS": ["INPLACE"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+    dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
+}
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+    src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
+}
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+
+#define(SHADER)
+
+enable f16;
+
+#include "binary_head.tmpl"
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
+    }
+}
+
+#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
new file mode 100644 (file)
index 0000000..db1aa34
--- /dev/null
@@ -0,0 +1,101 @@
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "f32"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "f16"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f16",
+      "DST_TYPE": "f16"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f16",
+      "DST_TYPE": "f32"
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{SRC_TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+struct Params {
+    ne: u32,            // total number of elements
+    offset_src: u32,    // in elements
+    offset_dst: u32,    // in elements
+
+    // Strides (in elements) — may be permuted
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst0: u32,
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Logical shapes
+    src_ne0: u32,
+    src_ne1: u32,
+    src_ne2: u32,
+
+    dst_ne0: u32,
+    dst_ne1: u32,
+    dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    let i2 = i / (params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne1 * params.src_ne0);
+    let i1 = i / params.src_ne0;
+    let i0 = i % params.src_ne0;
+
+    var j = gid.x;
+    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    let j2 = j / (params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne1 * params.dst_ne0);
+    let j1 = j / params.dst_ne0;
+    let j0 = j % params.dst_ne0;
+
+    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                  i2 * params.stride_src2 + i3 * params.stride_src3;
+
+    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
+}
+#end(SHADER)
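The shader recovers 4D coordinates from the flat invocation index twice, once against the source shape and once against the destination shape, so a permuted source (CONT) can be read in logical order while the destination is written through its own strides. The same decomposition in C++ for reference:

    #include <cstdint>

    struct coord4 { uint32_t i0, i1, i2, i3; };

    // Decompose a flat element index into 4D coordinates given the logical
    // shape (ne0..ne2; the outermost extent is implied by the total count),
    // mirroring the index math in cpy.tmpl.wgsl above.
    static coord4 unflatten(uint32_t idx, uint32_t ne0, uint32_t ne1, uint32_t ne2) {
        coord4 c;
        c.i3 = idx / (ne2 * ne1 * ne0);
        idx %= ne2 * ne1 * ne0;
        c.i2 = idx / (ne1 * ne0);
        idx %= ne1 * ne0;
        c.i1 = idx / ne0;
        c.i0 = idx % ne0;
        return c;
    }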
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
deleted file mode 100644 (file)
index 6fe924c..0000000
+++ /dev/null
@@ -1,60 +0,0 @@
-enable f16;
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
-
-@group(0) @binding(1)
-var<storage, read_write> dst: array<f16>;
-
-struct Params {
-    ne: u32,            // total number of elements
-    offset_src: u32,    // in elements
-    offset_dst: u32,    // in elements
-
-    // Strides (in elements) — may be permuted
-    stride_src0: u32,
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_dst0: u32,
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Logical shape (same for both tensors)
-    ne0: u32,
-    ne1: u32,
-    ne2: u32,
-    ne3: u32,
-};
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne) {
-        return;
-    }
-
-    var i = gid.x;
-
-    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
-    i = i % (params.ne2 * params.ne1 * params.ne0);
-
-    let i2 = i / (params.ne1 * params.ne0);
-    i = i % (params.ne1 * params.ne0);
-
-    let i1 = i / params.ne0;
-    let i0 = i % params.ne0;
-
-    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
-                  i2 * params.stride_src2 + i3 * params.stride_src3;
-
-    let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
-                  i2 * params.stride_dst2 + i3 * params.stride_dst3;
-
-    dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
-}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
index d9dfd7d6f4ffc1b1dec7e5a347af87c68c618d03..251051eaeca0f8c74ac7db560ed392b8431ba978 100755 (executable)
@@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                     raise ValueError(f"DECLS key '{key}' not found.")
                 decls_code += decls_map[key] + "\n\n"
 
-            shader_variant = replace_placeholders(shader_template, variant["REPLS"])
-            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
+            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
+            if "REPLS" in variant:
+                final_shader = replace_placeholders(final_shader, variant["REPLS"])
             final_shader = expand_includes(final_shader, input_dir)
 
-            if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
+            if "SHADER_NAME" in variant:
+                output_name = variant["SHADER_NAME"]
+            elif "SHADER_SUFFIX" in variant:
+                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
+            elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
                 output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
-            elif "TYPE_SUFFIX" in variant["REPLS"]:
-                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"]
-            elif "TYPE" in variant["REPLS"]:
+            elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
+                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
+            elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
                 output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
             else:
                 output_name = shader_base_name
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
index e3fe311b26b8301bd4bfd604df5d735475658e98..f80ce1fc550609174ef511bc0d4dbecd51ae6c4f 100644 (file)
@@ -2,9 +2,9 @@
 
 [
   {
+    "SHADER_SUFFIX": "f32_vec",
     "REPLS": {
       "TYPE" : "vec4<f32>",
-      "TYPE_SUFFIX": "f32_vec",
       "DST_TYPE": "vec4<f32>",
       "BLOCK_SIZE": 4
     },
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl
new file mode 100644 (file)
index 0000000..03fcd54
--- /dev/null
@@ -0,0 +1,323 @@
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_NAME": "reglu_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_SPLIT", "REGLU"]
+  },
+  {
+    "SHADER_NAME": "reglu_f32_split",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["SPLIT", "REGLU"]
+  },
+  {
+    "SHADER_NAME": "reglu_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_SPLIT", "REGLU"]
+  },
+  {
+    "SHADER_NAME": "reglu_f16_split",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["SPLIT", "REGLU"]
+  },
+  {
+    "SHADER_NAME": "geglu_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_SPLIT", "GEGLU"]
+  },
+  {
+    "SHADER_NAME": "geglu_f32_split",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["SPLIT", "GEGLU"]
+  },
+  {
+    "SHADER_NAME": "geglu_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_SPLIT", "GEGLU"]
+  },
+  {
+    "SHADER_NAME": "geglu_f16_split",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["SPLIT", "GEGLU"]
+  },
+  {
+    "SHADER_NAME": "swiglu_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_SPLIT", "SWIGLU"]
+  },
+  {
+    "SHADER_NAME": "swiglu_f32_split",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["SPLIT", "SWIGLU"]
+  },
+  {
+    "SHADER_NAME": "swiglu_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_SPLIT", "SWIGLU"]
+  },
+  {
+    "SHADER_NAME": "swiglu_f16_split",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["SPLIT", "SWIGLU"]
+  },
+  {
+    "SHADER_NAME": "swiglu_oai_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
+  },
+  {
+    "SHADER_NAME": "swiglu_oai_f32_split",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["SPLIT", "SWIGLU_OAI"]
+  },
+  {
+    "SHADER_NAME": "geglu_erf_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
+  },
+  {
+    "SHADER_NAME": "geglu_erf_f32_split",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["SPLIT", "GEGLU_ERF"]
+  },
+  {
+    "SHADER_NAME": "geglu_erf_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
+  },
+  {
+    "SHADER_NAME": "geglu_erf_f16_split",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["SPLIT", "GEGLU_ERF"]
+  },
+  {
+    "SHADER_NAME": "geglu_quick_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
+  },
+  {
+    "SHADER_NAME": "geglu_quick_f32_split",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["SPLIT", "GEGLU_QUICK"]
+  },
+  {
+    "SHADER_NAME": "geglu_quick_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
+  },
+  {
+    "SHADER_NAME": "geglu_quick_f16_split",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["SPLIT", "GEGLU_QUICK"]
+  },
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(REGLU)
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+    return max(a, 0) * b;
+}
+#enddecl(REGLU)
+
+#decl(GEGLU)
+const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
+const GELU_COEF_A: {{TYPE}} = 0.044715;
+
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
+    return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
+}
+#enddecl(GEGLU)
+
+#decl(SWIGLU)
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+    return a / (1.0 + exp(-a)) * b;
+}
+#enddecl(SWIGLU)
+
+#decl(SWIGLU_OAI)
+fn op(a: f32, b: f32) -> f32 {
+  let xi = min(a, params.limit);
+  let gi = max(min(b, params.limit), -params.limit);
+  var out_glu = xi / (1.0 + exp(-xi * params.alpha));
+  out_glu = out_glu * (1.0 + gi);
+  return out_glu;
+}
+#enddecl(SWIGLU_OAI)
+
+#decl(GEGLU_ERF)
+const p_erf: {{TYPE}} = 0.3275911;
+const a1_erf: {{TYPE}} = 0.254829592;
+const a2_erf: {{TYPE}} = -0.284496736;
+const a3_erf: {{TYPE}} = 1.421413741;
+const a4_erf: {{TYPE}} = -1.453152027;
+const a5_erf: {{TYPE}} = 1.061405429;
+const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
+
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+  let a_div_sqr2 = a * SQRT_2_INV;
+  let sign_x = sign(a_div_sqr2);
+  let x = abs(a_div_sqr2);
+  let t = 1.0 / (1.0 + p_erf * x);
+  let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
+  let erf_approx = sign_x * y;
+  return 0.5 * a * (1.0 + erf_approx) * b;
+}
+#enddecl(GEGLU_ERF)
+
+#decl(GEGLU_QUICK)
+const GELU_QUICK_COEF: {{TYPE}} = -1.702;
+
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
+}
+#enddecl(GEGLU_QUICK)
+
+#decl(NO_SPLIT)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> {{TYPE}} {
+    let offset: u32 = select(0, params.ne0, params.swapped != 0);
+    return src0[base + offset];
+}
+
+fn b_value(base: u32) -> {{TYPE}} {
+    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
+    return src0[base + offset];
+}
+#enddecl(NO_SPLIT)
+
+#decl(SPLIT)
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> {{TYPE}} {
+    return src0[base];
+}
+
+fn b_value(base: u32) -> {{TYPE}} {
+    return src1[base];
+}
+#enddecl(SPLIT)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_src11: u32,
+    stride_src12: u32,
+    stride_src13: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // shape of dst
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    swapped: u32,
+    alpha: f32,
+    limit: f32,
+}
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
+    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    dst[i_dst] = op(a_value(i_a), b_value(i_b));
+}
+
+#end(SHADER)
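
As a rough CPU sketch of the NO_SPLIT addressing above: the gate and value live in the two halves of one fused src0 row, and params.swapped picks which half feeds the activation (the SPLIT variants read the same pair from two separate bindings instead). The names and shapes below are illustrative, not part of the commit:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    static float swiglu(float a, float b) {
        return a / (1.0f + std::exp(-a)) * b;   // silu(a) * b, matching the SWIGLU decl
    }

    int main() {
        const int ne0 = 4;                      // dst row length (illustrative)
        const bool swapped = false;             // mirrors params.swapped
        std::vector<float> src0 = {1, 2, 3, 4,  // first half of the fused row
                                   5, 6, 7, 8}; // second half
        for (int i0 = 0; i0 < ne0; ++i0) {
            const float a = src0[i0 + (swapped ? ne0 : 0)];  // a_value(): gate half
            const float b = src0[i0 + (swapped ? 0 : ne0)];  // b_value(): value half
            std::printf("%f\n", swiglu(a, b));
        }
        return 0;
    }
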
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl
deleted file mode 100644 (file)
index 12506e1..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    }
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl
deleted file mode 100644 (file)
index e467e59..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    }
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
-    }
-}
-
-#end(SHADER)
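
The only difference between the two deleted shaders above is the write target: the standalone variant writes a separate dst binding, while the in-place variant stores back into src0 at the dst offset. A CPU analogue of the two write paths, with the src1_index() broadcast helper elided (illustrative only):

    #include <cstdio>

    // Elementwise product; passing dst == src0 gives the in-place behavior.
    static void mul(const float * src0, const float * src1, float * dst, unsigned ne) {
        for (unsigned i = 0; i < ne; ++i) dst[i] = src0[i] * src1[i];
    }

    int main() {
        float a[4]   = {1, 2, 3, 4};
        float b[4]   = {2, 2, 2, 2};
        float out[4];
        mul(a, b, out, 4);  // out-of-place: dst is a distinct binding
        mul(a, b, a, 4);    // in-place: results written back into src0
        for (unsigned i = 0; i < 4; ++i) std::printf("%f %f\n", out[i], a[i]);
        return 0;
    }
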
index f919a51336343182083c8cbcc473ab6b05a60324..a275eeb9783daa4eaee4fb2e9e46f5da6815d4d2 100644 (file)
@@ -1,9 +1,48 @@
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
+#define(VARIANTS)
+
+[
+  {
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_SUFFIX": "inplace",
+    "DECLS": ["INPLACE"]
+  },
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+    dst[dst_offset] = scale * src[src_offset];
+}
 
 @group(0) @binding(1)
 var<storage, read_write> dst: array<f32>;
 
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+    src[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
 struct Params {
     offset_src: u32, // in elements
     offset_dst: u32, // in elements
@@ -23,11 +62,13 @@ struct Params {
     ne2: u32,
     ne3: u32,
 
-    eps: u32
+    eps: f32
 };
 
-@group(0) @binding(2)
-var<uniform> params: Params;
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
 
 override wg_size: u32;
 @compute @workgroup_size(wg_size)
@@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     for (var j: u32 = 0; j < params.ne0; j++) {
         sum += src[i_src_row + j] * src[i_src_row + j];
     }
-    let eps = bitcast<f32>(params.eps);
-    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
     for (var j: u32 = 0; j < params.ne0; j++) {
-        dst[i_dst_row + j] = scale * src[i_src_row + j];
+        update(i_src_row + j, i_dst_row + j, scale);
     }
 }
+#end(SHADER)
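
With eps now carried as a plain f32 uniform rather than bitcast out of a u32, one row of the kernel reduces to the following CPU reference (a sketch with made-up shape and values):

    #include <cmath>
    #include <cstdio>

    int main() {
        const unsigned ne0 = 4;
        const float eps = 1e-6f;  // now a plain f32 uniform; no bitcast needed
        float row[ne0] = {1.0f, 2.0f, 3.0f, 4.0f};
        float sum = 0.0f;
        for (unsigned j = 0; j < ne0; ++j) sum += row[j] * row[j];
        const float scale = 1.0f / std::sqrt(sum / float(ne0) + eps);
        for (unsigned j = 0; j < ne0; ++j) row[j] *= scale;  // the INPLACE update() path
        for (unsigned j = 0; j < ne0; ++j) std::printf("%f\n", row[j]);
        return 0;
    }
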
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl
deleted file mode 100644 (file)
index ae84f55..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-@group(0) @binding(0)
-var<storage, read_write> a: array<f32>;
-
-struct Params {
-    offset: u32, // in elements
-
-    // Strides (in elements)
-    stride1: u32,
-    stride2: u32,
-    stride3: u32,
-
-    // Shape
-    ne0: u32,
-    ne1: u32,
-    ne2: u32,
-    ne3: u32,
-
-    eps: u32
-};
-
-@group(0) @binding(1)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
-        return;
-    }
-
-    // one thread per row
-    var i = gid.x;
-    let i3 = i / (params.ne2 * params.ne1);
-    i = i % (params.ne2 * params.ne1);
-    let i2 = i / params.ne1;
-    let i1 = i % params.ne1;
-    let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;
-
-    var sum = 0.0f;
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        sum += a[i_row + j] * a[i_row + j];
-    }
-    let eps = bitcast<f32>(params.eps);
-    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        a[i_row + j] = scale * a[i_row + j];
-    }
-}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
new file mode 100644 (file)
index 0000000..9a6ff41
--- /dev/null
@@ -0,0 +1,282 @@
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
+  },
+  {
+    "SHADER_SUFFIX": "f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
+  },
+  {
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
+  },
+  {
+   "SHADER_SUFFIX": "f32_ff",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
+  },
+  {
+   "SHADER_SUFFIX": "f32_ff_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+    },
+    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_ff",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_ff_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+    },
+    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(ROTATE)
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    dst[i_dst0] = {{TYPE}}(out0);
+    dst[i_dst1] = {{TYPE}}(out1);
+}
+#enddecl(ROTATE)
+
+#decl(ROTATE_INPLACE)
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    src0[i_dst0] = {{TYPE}}(out0);
+    src0[i_dst1] = {{TYPE}}(out1);
+}
+#enddecl(ROTATE_INPLACE)
+
+#decl(NO_FF_FUNC)
+fn freq_factor(i: u32) -> f32 {
+    return 1.0f;
+}
+#enddecl(NO_FF_FUNC)
+
+#decl(FF_FUNC)
+fn freq_factor(i: u32) -> f32 {
+    return src2[params.offset_src2 + i/2];
+}
+#enddecl(FF_FUNC)
+
+#decl(NO_FF_BINDINGS)
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(NO_FF_BINDINGS)
+
+#decl(NO_FF_BINDINGS_INPLACE)
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(NO_FF_BINDINGS_INPLACE)
+
+#decl(FF_BINDINGS)
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#enddecl(FF_BINDINGS)
+
+#decl(FF_BINDINGS_INPLACE)
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(FF_BINDINGS_INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_src2: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    n_threads: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    n_dims: u32,
+    mode: u32,
+    theta_scale: f32,
+    attn_factor: f32,
+    freq_scale: f32,
+    ext_factor: f32,
+    corr_dim0: f32,
+    corr_dim1: f32,
+    sections0: u32,
+    sections1: u32,
+    sections2: u32,
+    sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+DECLS
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+    let y = (f32(i / 2) - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check performance of computing this once on the CPU and passing it in as a buffer, since it is recomputed per row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+    var mscale = params.attn_factor;
+    var theta = params.freq_scale * theta_extrap;
+    if (params.ext_factor != 0.0f) {
+        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+    }
+    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+    if (div_2) {
+        return i0 / 2;
+    } else {
+        return i0;
+    }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+    if (is_vision) {
+        return params.n_dims;
+    } else if (is_neox || is_mrope) {
+        return params.n_dims / 2;
+    } else {
+        return 1;
+    }
+}
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    // two elements per thread
+    if (gid.x >= params.n_threads) {
+        return;
+    }
+
+    let is_neox = bool(params.mode & 2);
+    let is_mrope = bool(params.mode & 8);
+    let is_vision = params.mode == 24;
+
+    var i = gid.x * 2; // start index for this thread
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    if (i0 >= params.n_dims && !is_vision) {
+        let i_src = i_src_row + i0;
+        let i_dst = i_dst_row + i0;
+        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+        return;
+    }
+
+    var theta_base_mult: u32 = 0;
+    var theta_scale_pwr: u32 = i0 / 2;
+    if (is_mrope) {
+        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+        let sec_w = params.sections1 + params.sections0;
+        let sec_e = params.sections2 + sec_w;
+        let sector = (i0 / 2) % sect_dims;
+        if (sector >= params.sections0 && sector < sec_w) {
+            theta_base_mult = 1;
+            if (is_vision) {
+                theta_scale_pwr = sector - params.sections0;
+            }
+        } else if (sector >= sec_w && sector < sec_e) {
+            theta_base_mult = 2;
+            if (is_vision) {
+                theta_scale_pwr = sector - sec_w;
+            }
+        } else if (sector >= sec_e) {
+            if (is_vision) {
+                theta_scale_pwr = (i0 / 2) % sec_e;
+            }
+            theta_base_mult = 3;
+        } else if (is_vision) {
+            theta_scale_pwr = sector;
+        }
+    }
+    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+    let x0 = f32(src0[i_src]);
+    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+}
+
+#end(SHADER)
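
The pair addressing above is the crux of the mode handling: NORMAL mode rotates adjacent elements (i0, i0+1), NEOX/MROPE rotate (i0/2, i0/2 + n_dims/2), and VISION rotates (i0/2, i0/2 + n_dims). A small CPU sketch of the pair_base()/pair_offset() logic under the same mode flags (illustrative):

    #include <cstdio>

    int main() {
        const unsigned n_dims  = 8;
        const unsigned modes[] = {0, 2, 8, 24};  // NORMAL, NEOX, MROPE, VISION
        for (unsigned mode : modes) {
            const bool is_neox   = (mode & 2) != 0;
            const bool is_mrope  = (mode & 8) != 0;
            const bool is_vision = mode == 24;
            for (unsigned i0 = 0; i0 < n_dims; i0 += 2) {
                const unsigned base = (is_neox || is_mrope || is_vision) ? i0 / 2 : i0;  // pair_base()
                const unsigned off  = is_vision ? n_dims                                 // pair_offset()
                                    : (is_neox || is_mrope) ? n_dims / 2
                                    : 1;
                std::printf("mode=%2u i0=%u -> rotates (%u, %u)\n", mode, i0, base, base + off);
            }
        }
        return 0;
    }
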
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
new file mode 100644 (file)
index 0000000..040e80d
--- /dev/null
@@ -0,0 +1,90 @@
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_NAME": "scale_f32",
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "scale_f32_inplace",
+    "DECLS": ["INPLACE"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    dst[offset] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    src[offset] = val;
+}
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+struct Params {
+    offset_src: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    scale: f32,
+    bias: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    store_scale(src[i_src] * params.scale + params.bias, i_dst);
+}
+#end(SHADER)
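
The scale kernel itself is a per-element fused multiply-add, dst[i] = src[i] * params.scale + params.bias, with the INPLACE variant storing through the src binding. A CPU equivalent (illustrative):

    #include <cstdio>

    int main() {
        const float scale = 0.5f, bias = 1.0f;       // mirror params.scale / params.bias
        float src[4] = {1, 2, 3, 4};
        for (float & x : src) x = x * scale + bias;  // the INPLACE store_scale() path
        for (float x : src) std::printf("%f\n", x);
        return 0;
    }
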