]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-webgpu: port all AOT operators to JIT (llama/20728)
authorAbhijit Ramesh <redacted>
Wed, 1 Apr 2026 09:58:53 +0000 (12:58 +0300)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
* port cpy pipeline to shader lib with JIT compilation
* port glu pipeline to shader lib with JIT compilation
* port rope pipeline to shader lib with JIT compilation
* port soft_max pipeline to shader lib with JIT compilation
* removed unused functions from embed_wgsl.py which were used for
old AOT template expansion

src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
src/ggml-webgpu/ggml-webgpu.cpp
src/ggml-webgpu/wgsl-shaders/cpy.wgsl [new file with mode: 0644]
src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
src/ggml-webgpu/wgsl-shaders/glu.wgsl [new file with mode: 0644]
src/ggml-webgpu/wgsl-shaders/rope.wgsl [new file with mode: 0644]
src/ggml-webgpu/wgsl-shaders/soft_max.wgsl [new file with mode: 0644]

index 59861ac16cc3b807a249acb4aa0e36431f76699f..97863f40412d65b914235257e0dfb8e23ac8fa0a 100644 (file)
@@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions {
     uint32_t mul_mat_wg_size;
 };
 
+/** Cpy **/
+
+struct ggml_webgpu_cpy_pipeline_key {
+    ggml_type src_type;
+    ggml_type dst_type;
+
+    bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
+        return src_type == other.src_type && dst_type == other.dst_type;
+    }
+};
+
+struct ggml_webgpu_cpy_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src_type);
+        ggml_webgpu_hash_combine(seed, key.dst_type);
+        return seed;
+    }
+};
+
+/** Glu **/
+
+struct ggml_webgpu_glu_pipeline_key {
+    ggml_glu_op glu_op;
+    ggml_type   type;
+    bool        split;
+
+    bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
+        return glu_op == other.glu_op && type == other.type && split == other.split;
+    }
+};
+
+struct ggml_webgpu_glu_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.glu_op);
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.split);
+        return seed;
+    }
+};
+
+/** Rope **/
+
+struct ggml_webgpu_rope_pipeline_key {
+    ggml_type type;
+    bool      inplace;
+    bool      has_ff;
+
+    bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
+        return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
+    }
+};
+
+struct ggml_webgpu_rope_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.has_ff);
+        return seed;
+    }
+};
+
+/** SoftMax **/
+
+struct ggml_webgpu_soft_max_pipeline_key {
+    ggml_type mask_type;
+    bool      has_mask;
+    bool      has_sink;
+    bool      inplace;
+
+    bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
+        return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
+               inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_soft_max_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.mask_type);
+        ggml_webgpu_hash_combine(seed, key.has_mask);
+        ggml_webgpu_hash_combine(seed, key.has_sink);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
 class ggml_webgpu_shader_lib {
     wgpu::Device           device;
     pre_wgsl::Preprocessor preprocessor;
@@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib {
     std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
         set_rows_pipelines;
     std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
+    std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
+    std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
+    std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
+        rope_pipelines;
+    std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
+        soft_max_pipelines;
 
   public:
     ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib {
         return flash_attn_pipelines[key];
     }
 
+    webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_cpy_pipeline_key key = {
+            .src_type = context.src0->type,
+            .dst_type = context.dst->type,
+        };
+
+        auto it = cpy_pipelines.find(key);
+        if (it != cpy_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "cpy";
+
+        switch (key.src_type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src type for cpy shader");
+        }
+
+        switch (key.dst_type) {
+            case GGML_TYPE_F32:
+                defines.push_back("DST_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("DST_F16");
+                variant += "_f16";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("DST_I32");
+                variant += "_i32";
+                break;
+            default:
+                GGML_ABORT("Unsupported dst type for cpy shader");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_cpy, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        cpy_pipelines[key]       = pipeline;
+        return cpy_pipelines[key];
+    }
+
+    webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_glu_pipeline_key key = {
+            .glu_op = ggml_get_glu_op(context.dst),
+            .type   = context.dst->type,
+            .split  = (context.src1 != nullptr),
+        };
+
+        auto it = glu_pipelines.find(key);
+        if (it != glu_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "glu";
+
+        switch (key.glu_op) {
+            case GGML_GLU_OP_REGLU:
+                defines.push_back("OP_REGLU");
+                variant += "_reglu";
+                break;
+            case GGML_GLU_OP_GEGLU:
+                defines.push_back("OP_GEGLU");
+                variant += "_geglu";
+                break;
+            case GGML_GLU_OP_SWIGLU:
+                defines.push_back("OP_SWIGLU");
+                variant += "_swiglu";
+                break;
+            case GGML_GLU_OP_SWIGLU_OAI:
+                defines.push_back("OP_SWIGLU_OAI");
+                variant += "_swiglu_oai";
+                break;
+            case GGML_GLU_OP_GEGLU_ERF:
+                defines.push_back("OP_GEGLU_ERF");
+                variant += "_geglu_erf";
+                break;
+            case GGML_GLU_OP_GEGLU_QUICK:
+                defines.push_back("OP_GEGLU_QUICK");
+                variant += "_geglu_quick";
+                break;
+            default:
+                GGML_ABORT("Unsupported GLU op");
+        }
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for GLU shader");
+        }
+
+        if (key.split) {
+            variant += "_split";
+        } else {
+            defines.push_back("NO_SPLIT");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_glu, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        glu_pipelines[key]       = pipeline;
+        return glu_pipelines[key];
+    }
+
+    webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_rope_pipeline_key key = {
+            .type    = context.dst->type,
+            .inplace = context.inplace,
+            .has_ff  = (context.src2 != nullptr),
+        };
+
+        auto it = rope_pipelines.find(key);
+        if (it != rope_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "rope";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for ROPE shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        if (key.has_ff) {
+            defines.push_back("FF_FUNC");
+            variant += "_ff";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_rope, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        rope_pipelines[key]      = pipeline;
+        return rope_pipelines[key];
+    }
+
+    webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_soft_max_pipeline_key key = {
+            .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
+            .has_mask  = (context.src1 != nullptr),
+            .has_sink  = (context.src2 != nullptr),
+            .inplace   = context.inplace,
+        };
+
+        auto it = soft_max_pipelines.find(key);
+        if (it != soft_max_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "soft_max";
+
+        if (key.has_mask) {
+            defines.push_back("HAS_MASK");
+            switch (key.mask_type) {
+                case GGML_TYPE_F32:
+                    defines.push_back("MASK_F32");
+                    variant += "_mask_f32";
+                    break;
+                case GGML_TYPE_F16:
+                    defines.push_back("MASK_F16");
+                    variant += "_mask_f16";
+                    break;
+                default:
+                    GGML_ABORT("Unsupported type for SOFT_MAX shader");
+            }
+        }
+
+        if (key.has_sink) {
+            defines.push_back("HAS_SINK");
+            variant += "_sink";
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_soft_max, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        soft_max_pipelines[key]  = pipeline;
+        return soft_max_pipelines[key];
+    }
+
   private:
     static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
                                                        std::string    shader_code,
index 5e16f84ddd260cf0acb59c9e8a1e06e467dbbe4b..fa3c492a7a5537c7376a85831bf355ffe5b20697 100644 (file)
@@ -364,13 +364,6 @@ struct webgpu_context_struct {
     wgpu::Buffer    set_rows_dev_error_buf;
     wgpu::Buffer    set_rows_host_error_buf;
 
-    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
-
-    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
-    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split
-
-    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines;  // mask_type, has_sink, inplace
-
     size_t memset_bytes_per_thread;
 };
 
@@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
 }
 
 static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
     std::vector<uint32_t> params = {
@@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
-                                     params, entries, wg_x);
+    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
                                        ggml_tensor *    src1,
                                        ggml_tensor *    src2,
                                        ggml_tensor *    dst) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .src2        = src2,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = ggml_webgpu_tensor_equal(src0, dst),
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
     const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
     const int has_freq_factor = (src2 != nullptr);
 
@@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
-    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
     const int split = (src1 != nullptr);
 
     std::vector<uint32_t> params = {
@@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
                         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
-    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
@@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                                            ggml_tensor *    src1,
                                            ggml_tensor *    src2,
                                            ggml_tensor *    dst) {
-    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);
-    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here
-    const int has_sink  = (src2 != nullptr);
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .src2        = src2,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = ggml_webgpu_tensor_equal(src0, dst),
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
+
+    const int inplace  = ggml_webgpu_tensor_equal(src0, dst);
+    const int has_mask = (src1 != nullptr);
+    const int has_sink = (src2 != nullptr);
     float     max_bias;
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
     float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
@@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
         has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
         (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
         (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
-        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
-        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
-        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+        has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+        has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+        has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
         (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
@@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
         (uint32_t) src0->ne[0],
         (uint32_t) src0->ne[1],
         (uint32_t) src0->ne[2],
-        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
-        mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+        has_mask ? (uint32_t) src1->ne[2] : 0,
+        has_mask ? (uint32_t) src1->ne[3] : 0,
         *(uint32_t *) dst->op_params,  // scale
         *(uint32_t *) &max_bias,
         *(uint32_t *) &n_head_log2,
@@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
     };
     uint32_t binding_num = 1;
-    if (mask_type < 2) {
+    if (has_mask) {
         entries.push_back({ .binding = binding_num,
                             .buffer  = ggml_webgpu_tensor_buf(src1),
                             .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
@@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
-                                     ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
-                                     ggml_nrows(dst));
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst));
 }
 
 static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -2885,139 +2918,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
     ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
 }
 
-static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
-    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
-    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
-    webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
-    webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
-}
-
-static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
-
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
-}
-
-static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    // REGLU
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
-
-    // GEGLU
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
-
-    // SWIGLU
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
-
-    // SWIGLU_OAI
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
-
-    // GEGLU_ERF
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
-
-    // GEGLU_QUICK
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
-}
-
-static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
-
-    // f32 (no mask)
-    webgpu_ctx->soft_max_pipelines[2][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
-    webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
-    webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
-    webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
-
-    // f32 mask (mask_type = 0)
-    webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
-    webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
-    webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
-    webgpu_ctx->soft_max_pipelines[0][1][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
-                                    "soft_max_f32_mask_f32_sink_inplace", constants);
-
-    // f16 mask (mask_type = 1)
-    webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
-    webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
-    webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
-    webgpu_ctx->soft_max_pipelines[1][1][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
-                                    "soft_max_f32_mask_f16_sink_inplace", constants);
-}
-
 static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     wgpu::RequestAdapterOptions options = {};
 
@@ -3183,10 +3083,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
                               WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
                               wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
 
-    ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
-    ggml_webgpu_init_rope_pipeline(webgpu_ctx);
-    ggml_webgpu_init_glu_pipeline(webgpu_ctx);
-    ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
     ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
diff --git a/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
new file mode 100644 (file)
index 0000000..fa3bdf4
--- /dev/null
@@ -0,0 +1,81 @@
+enable f16;
+
+#ifdef SRC_F32
+#define SRC_TYPE f32
+#elif defined(SRC_F16)
+#define SRC_TYPE f16
+#endif
+
+#ifdef DST_F32
+#define DST_TYPE f32
+#elif defined(DST_F16)
+#define DST_TYPE f16
+#elif defined(DST_I32)
+#define DST_TYPE i32
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<SRC_TYPE>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DST_TYPE>;
+
+struct Params{
+    ne: u32,
+    offset_src: u32,
+    offset_dst: u32,
+
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+
+    stride_dst0: u32,
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    src_ne0: u32,
+    src_ne1: u32,
+    src_ne2: u32,
+
+    dst_ne0: u32,
+    dst_ne1: u32,
+    dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    let i2 = i / (params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne1 * params.src_ne0);
+    let i1 = i / params.src_ne0;
+    let i0 = i % params.src_ne0;
+
+    var j = gid.x;
+    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    let j2 = j / (params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne1 * params.dst_ne0);
+    let j1 = j / params.dst_ne0;
+    let j0 = j % params.dst_ne0;
+
+    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                  i2 * params.stride_src2 + i3 * params.stride_src3;
+
+    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+    dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
+}
+
index 8b5cfe715e75488ec35e6229d3bab30e3a9f564e..79a3a9597abe1f424304c1f15c00edef63fe2af0 100755 (executable)
@@ -1,41 +1,8 @@
 import os
 import re
-import ast
 import argparse
 
 
-def extract_block(text, name):
-    pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
-    match = re.search(pattern, text, re.DOTALL)
-    if not match:
-        raise ValueError(f"Missing block: {name}")
-    return match.group(1).strip()
-
-
-def parse_decls(decls_text):
-    decls = {}
-    for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
-        decls[name.strip()] = code.strip()
-    return decls
-
-
-def replace_repl_placeholders(variant, template_map):
-    for repl, code in variant["REPLS"].items():
-        for key, val in template_map.items():
-            # Match "key" and avoid matching subsequences using by using \b
-            code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
-        variant["REPLS"][repl] = code
-    return variant
-
-
-def replace_placeholders(shader_text, replacements):
-    for key, val in replacements.items():
-        # Match {{KEY}} literally, where KEY is escaped
-        pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
-        shader_text = re.sub(pattern, str(val), shader_text)
-    return shader_text
-
-
 def expand_includes(shader, input_dir):
     """
     Replace #include "file" lines in the text with the contents of that file.
@@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
         outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
 
 
-def generate_variants(fname, input_dir, output_dir, outfile):
-    shader_path = os.path.join(input_dir, fname)
-    shader_base_name = fname.split(".")[0]
-
-    with open(shader_path, "r", encoding="utf-8") as f:
-        text = f.read()
-
-    try:
-        variants = ast.literal_eval(extract_block(text, "VARIANTS"))
-    except ValueError:
-        write_shader(shader_base_name, text, output_dir, outfile, input_dir)
-    else:
-        try:
-            decls_map = parse_decls(extract_block(text, "DECLS"))
-        except ValueError:
-            decls_map = {}
-        try:
-            templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
-        except ValueError:
-            templates_map = {}
-
-        for fname in sorted(os.listdir(input_dir)):
-            if fname.endswith(".tmpl"):
-                tmpl_path = os.path.join(input_dir, fname)
-                with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
-                    decls = f_tmpl.read()
-                    decls_map.update(parse_decls(decls))
-
-        shader_template = extract_block(text, "SHADER")
-        for variant in variants:
-            if "DECLS" in variant:
-                decls = variant["DECLS"]
-            else:
-                decls = []
-            decls_code = ""
-            for key in decls:
-                if key not in decls_map:
-                    raise ValueError(f"DECLS key '{key}' not found.")
-                decls_code += decls_map[key] + "\n\n"
-            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
-            if "REPLS" in variant:
-                variant = replace_repl_placeholders(variant, templates_map)
-                final_shader = replace_placeholders(final_shader, variant["REPLS"])
-                # second run to expand placeholders in repl_template
-                final_shader = replace_placeholders(final_shader, variant["REPLS"])
-            final_shader = expand_includes(final_shader, input_dir)
-
-            if "SHADER_NAME" in variant:
-                output_name = variant["SHADER_NAME"]
-            elif "SHADER_SUFFIX" in variant:
-                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
-            elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
-                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
-            elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
-                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
-            elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
-                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
-            else:
-                output_name = shader_base_name
-            write_shader(output_name, final_shader, output_dir, outfile, input_dir)
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_dir", required=True)
     parser.add_argument("--output_file", required=True)
-    parser.add_argument("--output_dir")
     args = parser.parse_args()
 
-    if args.output_dir:
-        os.makedirs(args.output_dir, exist_ok=True)
-
     with open(args.output_file, "w", encoding="utf-8") as out:
         out.write("// Auto-generated shader embedding\n")
         out.write("#include <string>\n\n")
         for fname in sorted(os.listdir(args.input_dir)):
             if fname.endswith(".wgsl"):
-                generate_variants(fname, args.input_dir, args.output_dir, out)
+                shader_path = os.path.join(args.input_dir, fname)
+                shader_name = fname.replace(".wgsl", "")
+
+                with open(shader_path, "r", encoding="utf-8") as f:
+                    shader_code = f.read()
+
+                write_shader(shader_name, shader_code, None, out, args.input_dir)
 
 
 if __name__ == "__main__":
diff --git a/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/src/ggml-webgpu/wgsl-shaders/glu.wgsl
new file mode 100644 (file)
index 0000000..e6d7608
--- /dev/null
@@ -0,0 +1,155 @@
+enable f16;
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+#ifdef OP_REGLU
+fn op(a: DataType, b: DataType) -> DataType {
+    return max(a, 0) * b;
+}
+#endif
+
+#ifdef OP_GEGLU
+const SQRT_2_OVER_PI: DataType =  0.79788456080286535587989211986876;
+const GELU_COEF_A: DataType = 0.044715;
+
+fn op(a: DataType, b: DataType) -> DataType {
+    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
+    return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
+}
+#endif
+
+#ifdef OP_SWIGLU
+fn op(a: DataType, b: DataType) -> DataType {
+    return a / (1.0 + exp(-a)) * b;
+}
+#endif
+#ifdef OP_SWIGLU_OAI
+fn op(a: f32, b: f32) -> f32 {
+    let xi = min(a, params.limit);
+    let gi = max(min(b, params.limit), -params.limit);
+    var out_glu = xi / (1.0 + exp(-xi * params.alpha));
+    out_glu = out_glu * (1.0 + gi);
+    return out_glu;
+}
+#endif
+#ifdef OP_GEGLU_ERF
+const p_erf: DataType = 0.3275911;
+const a1_erf: DataType = 0.254829592;
+const a2_erf: DataType = -0.284496736;
+const a3_erf: DataType = 1.421413741;
+const a4_erf: DataType = -1.453152027;
+const a5_erf: DataType = 1.061405429;
+const SQRT_2_INV: DataType = 0.7071067811865476;
+
+fn op(a: DataType, b: DataType) -> DataType {
+    let a_div_sqr2 = a * SQRT_2_INV;
+    let sign_x = sign(a_div_sqr2);
+    let x = abs(a_div_sqr2);
+    let t = 1.0 / (1.0 + p_erf * x);
+    let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
+    let erf_approx = sign_x * y;
+    return 0.5 * a * (1.0 + erf_approx) * b;
+}
+#endif
+#ifdef OP_GEGLU_QUICK
+const GELU_QUICK_COEF: DataType = -1.702;
+
+fn op(a: DataType, b: DataType) -> DataType {
+    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
+}
+#endif
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_src11: u32,
+    stride_src12: u32,
+    stride_src13: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // shape of dst
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    swapped: u32,
+    alpha: f32,
+    limit: f32,
+}
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+#ifdef NO_SPLIT
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> DataType {
+    let offset: u32 = select(0, params.ne0, params.swapped != 0);
+    return src0[base + offset];
+}
+
+fn b_value(base: u32) -> DataType {
+    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
+    return src0[base + offset];
+}
+
+#else
+@group(0) @binding(1)
+var<storage, read_write> src1: array<DataType>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> DataType {
+    return src0[base];
+}
+
+fn b_value(base: u32) -> DataType {
+    return src1[base];
+}
+
+#endif
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
+    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    dst[i_dst] = op(a_value(i_a), b_value(i_b));
+}
diff --git a/src/ggml-webgpu/wgsl-shaders/rope.wgsl b/src/ggml-webgpu/wgsl-shaders/rope.wgsl
new file mode 100644 (file)
index 0000000..1c874e1
--- /dev/null
@@ -0,0 +1,224 @@
+enable f16;
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_src2: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    n_threads: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    n_dims: u32,
+    mode: u32,
+    theta_scale: f32,
+    attn_factor: f32,
+    freq_scale: f32,
+    ext_factor: f32,
+    corr_dim0: f32,
+    corr_dim1: f32,
+    sections0: u32,
+    sections1: u32,
+    sections2: u32,
+    sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+#ifdef INPLACE
+
+#ifdef FF_FUNC
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#endif
+
+#else
+
+#ifdef FF_FUNC
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#ifdef FF_FUNC
+fn freq_factor(i: u32) -> f32 {
+    return src2[params.offset_src2 + i/2];
+}
+
+#else
+fn freq_factor(i: u32) -> f32 {
+    return 1.0f;
+}
+#endif
+#ifdef INPLACE
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    src0[i_dst0] = DataType(out0);
+    src0[i_dst1] = DataType(out1);
+}
+#else
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    dst[i_dst0] = DataType(out0);
+    dst[i_dst1] = DataType(out1);
+}
+#endif
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+    let y = (f32(i / 2) - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+    var mscale = params.attn_factor;
+    var theta  = params.freq_scale * theta_extrap;
+    if (params.ext_factor != 0.0f) {
+        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+    }
+    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+    if (div_2) {
+        return i0 / 2;
+    } else {
+        return i0;
+    }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+    if (is_vision) {
+        return params.n_dims;
+    } else if (is_neox || is_mrope) {
+        return params.n_dims / 2;
+    } else {
+        return 1;
+    }
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    // two elements per n_threads
+    if (gid.x >= params.n_threads) {
+        return;
+    }
+
+    let is_neox = bool(params.mode & 2);
+    let is_mrope = bool(params.mode & 8);
+    let is_imrope = params.mode == 40;
+    let is_vision = params.mode == 24;
+
+    var i = gid.x * 2; // start index for this thread
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    if (i0 >= params.n_dims && !is_vision) {
+        let i_src = i_src_row + i0;
+        let i_dst = i_dst_row + i0;
+        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+        return;
+    }
+
+    var theta_base_mult: u32 = 0;
+    var theta_scale_pwr: u32 = i0 / 2;
+    if (is_mrope) {
+        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+        let sec_w = params.sections1 + params.sections0;
+        let sec_e = params.sections2 + sec_w;
+        let sector = (i0 / 2) % sect_dims;
+        if (is_imrope) {
+          if (sector % 3 == 1 && sector < 3 * params.sections1) {
+              theta_base_mult = 1;
+          } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+              theta_base_mult = 2;
+          } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+              theta_base_mult = 0;
+          } else {
+              theta_base_mult = 3;
+          }
+        } else {
+          if (sector >= params.sections0 && sector < sec_w) {
+              theta_base_mult = 1;
+              if (is_vision) {
+                  theta_scale_pwr = sector - params.sections0;
+              }
+          } else if (sector >= sec_w && sector < sec_e) {
+              theta_base_mult = 2;
+              if (is_vision) {
+                  theta_scale_pwr = sector - sec_w;
+              }
+          } else if (sector >= sec_e) {
+              if (is_vision) {
+                  theta_scale_pwr = sector - sec_e;
+                  theta_scale_pwr = (i0 / 2) % sec_e;
+              }
+              theta_base_mult = 3;
+          } else if (is_vision) {
+              theta_scale_pwr = sector;
+          }
+        }
+    }
+    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+    let x0 = f32(src0[i_src]);
+    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+
+}
diff --git a/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl b/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl
new file mode 100644 (file)
index 0000000..10edf13
--- /dev/null
@@ -0,0 +1,245 @@
+enable f16;
+
+#ifdef MASK_F32
+#define MaskType f32
+#endif
+#ifdef MASK_F16
+#define MaskType f16
+#endif
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_src11: u32,
+    stride_src12: u32,
+    stride_src13: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // shape of src0/dst
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    // shape of src1
+    ne12: u32,
+    ne13: u32,
+
+    scale: f32,
+    max_bias: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+#ifdef HAS_MASK
+#ifdef HAS_SINK
+@group(0) @binding(1)
+var<storage, read_write> mask: array<MaskType>;
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+#ifdef INPLACE
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(4)
+var<uniform> params: Params;
+#endif
+
+#else
+@group(0) @binding(1)
+var<storage, read_write> mask: array<MaskType>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#else
+#ifdef HAS_SINK
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+
+#else
+#ifdef INPLACE
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(2)
+var<uniform> params: Params;
+#endif
+#endif
+#endif
+
+#ifdef INPLACE
+fn inter_value(i: u32) -> f32 {
+    return src[i];
+}
+fn update(i: u32, val: f32) {
+    src[i] = val;
+}
+
+#else
+fn inter_value(i: u32) -> f32 {
+    return dst[i];
+}
+fn update(i: u32, val: f32) {
+    dst[i] = val;
+}
+#endif
+
+#ifdef HAS_MASK
+fn mask_val(i: u32) -> f32 {
+    return f32(mask[i]);
+}
+
+#else
+fn mask_val(i: u32) -> f32 {
+    return 0.0;
+}
+#endif
+
+#ifdef HAS_SINK
+fn lower_max_bound(i2: u32) -> f32 {
+    return sinks[params.offset_sinks + i2];
+}
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+    return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#else
+fn lower_max_bound(i2: u32) -> f32 {
+    return -1e30;
+}
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+    return val;
+}
+#endif
+
+const CACHE_SIZE: u32 = 16;
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+    let head = f32(i2);
+    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
+
+    var cache: array<f32, CACHE_SIZE>;
+
+    var max_val = lower_max_bound(i2);
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+        max_val = max(max_val, val);
+        if (col < CACHE_SIZE) {
+            cache[col] = val;
+        }
+        col += WG_SIZE;
+    }
+
+    scratch[lid.x] = max_val;
+    workgroupBarrier();
+    var offset: u32 = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    let row_max = scratch[0];
+    workgroupBarrier();
+
+    var sum = 0.0f;
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+                         cache[col], col < CACHE_SIZE);
+        let ex = exp(val - row_max);
+        sum += ex;
+        if (col < CACHE_SIZE) {
+            cache[col] = ex;
+        } else {
+            update(i_dst_row + col, ex);
+        }
+        col += WG_SIZE;
+    }
+
+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    offset = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    let row_sum = add_sinks(scratch[0], i2, row_max);
+
+    let sum_recip = 1.0 / row_sum;
+    col = lid.x;
+    for  (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+        col += WG_SIZE;
+    }
+}
+