]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml webgpu: shader library organization (#19530)
authorReese Levine <redacted>
Wed, 18 Feb 2026 14:51:02 +0000 (07:51 -0700)
committerGitHub <redacted>
Wed, 18 Feb 2026 14:51:02 +0000 (07:51 -0700)
* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* flashattention and matrix multiplication moved to new format

* clean up preprocessing

* Formatting

* remove duplicate constants

* Split large shaders into multiple static strings

---------

Co-authored-by: neha-ha <redacted>
17 files changed:
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl [new file with mode: 0644]

index 63f797f142d0c14c1e40d0403d042665647bb188..0d5a818dacb598e5e29947f59b8f0a9fe1c3929b 100644 (file)
@@ -1,11 +1,16 @@
 #ifndef GGML_WEBGPU_SHADER_LIB_HPP
 #define GGML_WEBGPU_SHADER_LIB_HPP
 
+#include "ggml-wgsl-shaders.hpp"
 #include "ggml.h"
 #include "pre_wgsl.hpp"
 
+#include <webgpu/webgpu_cpp.h>
+
+#include <algorithm>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #define GGML_WEBGPU_F16_SIZE_BYTES                   2
 
 #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
 
-struct ggml_webgpu_processed_shader {
-    std::string           wgsl;
-    std::string           variant;
-    std::shared_ptr<void> decisions;
-};
+// Matrix multiplication parameters
+
+// Register tiling parameters
+#define WEBGPU_MUL_MAT_TILE_M    8
+#define WEBGPU_MUL_MAT_TILE_N    8
+#define WEBGPU_MUL_MAT_WG_SIZE_M 8
+#define WEBGPU_MUL_MAT_WG_SIZE_N 8
+#define WEBGPU_MUL_MAT_TILE_K    32
+
+// Subgroup matrix parameters
+// The number of subgroups in the M dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_M        2
+// The number of subgroups in the N dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_N        2
+// The number of subgroup matrices each subgroup accumulates over
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
+
+// Matrix-vector multiplication parameters
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
+// Must be multiple of 4 to work with vectorized paths, and must divide
+// mul_mat_vec wg size
+#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_TILE_K         256
+
+// default size for legacy matrix multiplication
+#define WEBGPU_MUL_MAT_WG_SIZE 256
 
 // Same hash combine function as in boost
 template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
     seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 }
 
+struct ggml_webgpu_shader_lib_context {
+    ggml_tensor * src0;
+    ggml_tensor * src1;
+    ggml_tensor * src2;
+    ggml_tensor * src3;
+    ggml_tensor * src4;
+    ggml_tensor * dst;
+
+    uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes       = 0;
+    bool     inplace                  = false;
+    bool     overlap                  = false;
+    bool     supports_subgroup_matrix = false;
+    uint32_t sg_mat_m                 = 0;
+    uint32_t sg_mat_n                 = 0;
+    uint32_t sg_mat_k                 = 0;
+    uint32_t max_subgroup_size        = 0;
+};
+
+struct webgpu_pipeline {
+    wgpu::ComputePipeline pipeline;
+    std::string           name;
+    std::shared_ptr<void> context = nullptr;
+};
+
+struct ggml_webgpu_generic_shader_decisions {
+    uint32_t wg_size = 0;
+};
+
+/** Argsort **/
+
+struct ggml_webgpu_argsort_shader_lib_context {
+    uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes;
+    int32_t  order;
+};
+
+/** Set Rows **/
+
+struct ggml_webgpu_set_rows_pipeline_key {
+    int dst_type;
+    int vec4;
+    int i64_idx;
+
+    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
+        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
+    }
+};
+
+struct ggml_webgpu_set_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.dst_type);
+        ggml_webgpu_hash_combine(seed, key.vec4);
+        ggml_webgpu_hash_combine(seed, key.i64_idx);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_set_rows_shader_decisions {
+    bool     vec4;
+    bool     i64_idx;
+    uint32_t wg_size;
+};
+
+/** Get Rows **/
+
+struct ggml_webgpu_get_rows_pipeline_key {
+    ggml_type src_type;
+    int       vectorized;
+
+    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
+        return src_type == other.src_type && vectorized == other.vectorized;
+    }
+};
+
+struct ggml_webgpu_get_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        return seed;
+    }
+};
+
+/** Pad **/
+struct ggml_webgpu_pad_pipeline_key {
+    bool circular;
+
+    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+};
+
+struct ggml_webgpu_pad_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.circular);
+        return seed;
+    }
+};
+
+/** Scale **/
+
+struct ggml_webgpu_scale_pipeline_key {
+    int inplace;
+
+    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
+};
+
+struct ggml_webgpu_scale_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
+/** Binary **/
+
+struct ggml_webgpu_binary_pipeline_key {
+    int  type;
+    int  op;
+    bool inplace;
+    bool overlap;
+
+    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
+        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+    }
+};
+
+struct ggml_webgpu_binary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.overlap);
+        return seed;
+    }
+};
+
+/** Unary **/
+
+struct ggml_webgpu_unary_pipeline_key {
+    int  type;
+    int  op;
+    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
+    bool inplace;
+
+    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
+        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_unary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.is_unary);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
 /** FlashAttention */
 
 struct ggml_webgpu_flash_attn_pipeline_key {
@@ -100,439 +291,941 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
     return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
 }
 
-static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
-    const size_t limit_bytes = context.wg_mem_limit_bytes;
-    const size_t q_tile      = context.sg_mat_m;
-    const size_t base_q_bytes =
-        (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
-        2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
-    size_t bytes_per_kv = 0;
-    if (!context.key.kv_direct) {
-        bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
-    }
-    if (context.key.has_mask) {
-        bytes_per_kv += q_tile;
-    }
-    bytes_per_kv += q_tile;
-    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
-    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
-    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
-}
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
-    pre_wgsl::Preprocessor &                          preprocessor,
-    const char *                                      shader_src,
-    const ggml_webgpu_flash_attn_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "flash_attn";
-
-    switch (context.key.kv_type) {
-        case GGML_TYPE_F32:
-            defines.push_back("KV_F32");
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("KV_F16");
-            break;
-        case GGML_TYPE_Q4_0:
-            defines.push_back("KV_Q4_0");
-            break;
-        case GGML_TYPE_Q8_0:
-            defines.push_back("KV_Q8_0");
-            break;
-        default:
-            GGML_ABORT("Unsupported KV type for flash attention shader");
-    }
-    variant += std::string("_") + ggml_type_name(context.key.kv_type);
-
-    if (context.key.has_mask) {
-        defines.push_back("MASK");
-        variant += "_mask";
-    }
-    if (context.key.has_sinks) {
-        defines.push_back("SINKS");
-        variant += "_sinks";
-    }
-    if (context.key.uses_logit_softcap) {
-        defines.push_back("LOGIT_SOFTCAP");
-        variant += "_lgsc";
-    }
-
-    if (context.key.kv_direct) {
-        defines.push_back("KV_DIRECT");
-        variant += "_kvdirect";
-    }
-
-    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
-    variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
-
-    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
-    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
-    // For now these are not part of the variant name
-    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
-    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
-    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
-
-    // Add chosen Q/KV tile sizes
-    uint32_t q_tile  = context.sg_mat_m;
-    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
-                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-    if (context.key.kv_direct) {
-        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
-        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
-        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
-            kv_tile -= context.sg_mat_n;
-        }
-    }
-
-    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
-    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
-
-    // workgroup size
-    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
-    decisions->q_tile  = q_tile;
-    decisions->kv_tile = kv_tile;
-    decisions->wg_size = wg_size;
-    result.decisions   = decisions;
-    return result;
-}
+/** Matrix Multiplication **/
 
-/** Generic **/
+struct ggml_webgpu_legacy_mul_mat_pipeline_key {
+    ggml_type src0_type;
+    ggml_type src1_type;
 
-struct ggml_webgpu_generic_shader_lib_context {
-    int      vec4;
-    uint32_t max_wg_size;
+    bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
+        return src0_type == other.src0_type && src1_type == other.src1_type;
+    }
 };
 
-struct ggml_webgpu_generic_shader_decisions {
-    uint32_t wg_size;
+struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src0_type);
+        ggml_webgpu_hash_combine(seed, key.src1_type);
+        return seed;
+    }
 };
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
-    pre_wgsl::Preprocessor &                       preprocessor,
-    const char *                                   shader_src,
-    const ggml_webgpu_generic_shader_lib_context & context,
-    const std::string &                            base_variant) {
-    std::vector<std::string> defines;
-    std::string              variant = base_variant;
+struct ggml_webgpu_mul_mat_vec_pipeline_key {
+    ggml_type src0_type;
+    ggml_type src1_type;
+    int       vectorized;
 
-    if (context.vec4) {
-        defines.push_back("VEC4");
-        variant += "_vec";
+    bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
+        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
     }
+};
 
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl    = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    return result;
-}
+struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src0_type);
+        ggml_webgpu_hash_combine(seed, key.src1_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        return seed;
+    }
+};
 
-/** Pad **/
+struct ggml_webgpu_mul_mat_vec_shader_decisions {
+    uint32_t wg_size;
+    uint32_t tile_k;
+    uint32_t outputs_per_wg;
+    uint32_t vec_size;
+};
 
-struct ggml_webgpu_pad_pipeline_key {
-    bool circular;
+struct ggml_webgpu_mul_mat_pipeline_key {
+    ggml_type src0_type;
+    ggml_type src1_type;
+    int       vectorized;
+    int       use_subgroup_matrix;
 
-    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+    bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
+        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
+               use_subgroup_matrix == other.use_subgroup_matrix;
+    }
 };
 
-struct ggml_webgpu_pad_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+struct ggml_webgpu_mul_mat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
         size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.circular);
+        ggml_webgpu_hash_combine(seed, key.src0_type);
+        ggml_webgpu_hash_combine(seed, key.src1_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
         return seed;
     }
 };
 
-struct ggml_webgpu_pad_shader_lib_context {
-    ggml_webgpu_pad_pipeline_key key;
-    uint32_t                     max_wg_size;
+struct ggml_webgpu_mul_mat_shader_decisions {
+    uint32_t tile_k;
+    uint32_t wg_size_m;
+    uint32_t wg_size_n;
+    uint32_t wg_size;
+    uint32_t outputs_per_wg;
+    int      use_subgroup_matrix;
+
+    uint32_t tile_m;
+    uint32_t tile_n;
+
+    // Subgroup matrix parameters
+    uint32_t subgroup_m;
+    uint32_t subgroup_n;
+    uint32_t subgroup_matrix_m;
+    uint32_t subgroup_matrix_n;
+
+    uint32_t mul_mat_wg_size;
 };
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
-    pre_wgsl::Preprocessor &                   preprocessor,
-    const char *                               shader_src,
-    const ggml_webgpu_pad_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "pad";
+class ggml_webgpu_shader_lib {
+    wgpu::Device           device;
+    pre_wgsl::Preprocessor preprocessor;
+
+    std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines;       // key is fixed, no variants yet
+    std::unordered_map<int, webgpu_pipeline> argmax_pipelines;         // key is vec4
+    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order
+    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order
+    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet
+    std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
+        get_rows_pipelines;                                            // src_type, vectorized
+    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
+        unary_pipelines;                                               // type/op/inplace
+    std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
+        scale_pipelines;                                               // inplace
+    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
+        pad_pipelines;                                                 // circular/non-circular
+    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
+        binary_pipelines;                                              // type/op/inplace/overlap
+    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
+        flash_attn_pipelines;
+    std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
+                       webgpu_pipeline,
+                       ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
+        mul_mat_legacy_pipelines;  // legacy mul_mat (non-subgroup/non-regtile/non-vec)
+    std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
+        mul_mat_vec_pipelines;     // fast mat-vec (n==1)
+    std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
+        mul_mat_fast_pipelines;    // fast mat-mat (reg-tile or subgroup)
+
+    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
+        set_rows_pipelines;
+
+  public:
+    ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
+
+    webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        auto it = sum_rows_pipelines.find(1);
+        if (it != sum_rows_pipelines.end()) {
+            return it->second;
+        }
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-    if (context.key.circular) {
-        defines.push_back("CIRCULAR");
-        variant += "_circular";
+        auto processed        = preprocessor.preprocess(wgsl_sum_rows, defines);
+        sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
+        return sum_rows_pipelines[1];
     }
 
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+    webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool vec4 = context.src0->ne[0] % 4 == 0;
 
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions   = decisions;
-    return result;
-}
+        auto it = argmax_pipelines.find(vec4);
+        if (it != argmax_pipelines.end()) {
+            return it->second;
+        }
+        std::string              variant = "argmax";
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+        if (vec4) {
+            defines.push_back("VEC4");
+            variant += "_vec4";
+        }
 
-/** Argsort **/
+        auto processed         = preprocessor.preprocess(wgsl_argmax, defines);
+        argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return argmax_pipelines.at(vec4);
+    }
 
-struct ggml_webgpu_argsort_shader_lib_context {
-    uint32_t max_wg_size;
-    size_t   wg_mem_limit_bytes;
-    int32_t  order;
-};
+    webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
+                                                  .vec4     = context.src0->ne[0] % 4 == 0,
+                                                  .i64_idx  = context.src1->type == GGML_TYPE_I64 };
 
-struct ggml_webgpu_argsort_shader_decisions {
-    uint32_t wg_size = 0;
-};
+        auto it = set_rows_pipelines.find(key);
+        if (it != set_rows_pipelines.end()) {
+            return it->second;
+        }
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
-    pre_wgsl::Preprocessor &                       preprocessor,
-    const char *                                   shader_src,
-    const ggml_webgpu_argsort_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "argsort";
-    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
-    variant += std::string("_order") + std::to_string(context.order);
-    uint32_t wg_size = 1;
-    while (wg_size * 2 <= context.max_wg_size &&
-           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
-        wg_size *= 2;
-    }
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
-    decisions->wg_size = wg_size;
-    result.decisions   = decisions;
-    return result;
-}
+        std::vector<std::string> defines;
+        std::string              variant = "set_rows";
+
+        switch (context.dst->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("DST_F32");
+                variant += "_dstf32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("DST_F16");
+                variant += "_dstf16";
+                break;
+            default:
+                GGML_ABORT("Unsupported dst type for set_rows shader");
+        }
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
-    pre_wgsl::Preprocessor &                       preprocessor,
-    const char *                                   shader_src,
-    const ggml_webgpu_argsort_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "argsort_merge";
-    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
-    variant += std::string("_order") + std::to_string(context.order);
-    uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
-    decisions->wg_size = wg_size;
-    result.decisions   = decisions;
-    return result;
-}
+        if (key.vec4) {
+            defines.push_back("VEC4");
+            variant += "_vec4";
+        }
+        if (key.i64_idx) {
+            defines.push_back("I64_IDX");
+            variant += "_i64idx";
+        }
 
-/** Set Rows **/
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-struct ggml_webgpu_set_rows_pipeline_key {
-    int dst_type;
-    int vec4;
-    int i64_idx;
+        auto processed                  = preprocessor.preprocess(wgsl_set_rows, defines);
+        auto decisions                  = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
+        decisions->vec4                 = key.vec4;
+        decisions->i64_idx              = key.i64_idx;
+        decisions->wg_size              = context.max_wg_size;
+        set_rows_pipelines[key]         = ggml_webgpu_create_pipeline(device, processed, variant);
+        set_rows_pipelines[key].context = decisions;
+        return set_rows_pipelines[key];
+    }
 
-    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
-        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
+    webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        auto it = cumsum_pipelines.find(1);
+        if (it != cumsum_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed      = preprocessor.preprocess(wgsl_cumsum, defines);
+        cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
+        return cumsum_pipelines[1];
     }
-};
 
-struct ggml_webgpu_set_rows_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.dst_type);
-        ggml_webgpu_hash_combine(seed, key.vec4);
-        ggml_webgpu_hash_combine(seed, key.i64_idx);
-        return seed;
+    webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool          is_top_k = context.dst->op == GGML_OP_TOP_K;
+        // ascending order is 0, descending order is 1
+        const int32_t order =
+            is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
+
+        auto it = argsort_pipelines.find(order);
+        if (it != argsort_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "argsort";
+        defines.push_back(std::string("ORDER=") + std::to_string(order));
+        variant += std::string("_order") + std::to_string(order);
+        uint32_t wg_size = 1;
+        while (wg_size * 2 <= context.max_wg_size &&
+               wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
+            wg_size *= 2;
+        }
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        auto processed                   = preprocessor.preprocess(wgsl_argsort, defines);
+        auto decisions                   = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size               = wg_size;
+        argsort_pipelines[order]         = ggml_webgpu_create_pipeline(device, processed, variant);
+        argsort_pipelines[order].context = decisions;
+        return argsort_pipelines[order];
     }
-};
 
-struct ggml_webgpu_set_rows_shader_lib_context {
-    ggml_webgpu_set_rows_pipeline_key key;
-    uint32_t                          max_wg_size;
-};
+    webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool          is_top_k = context.dst->op == GGML_OP_TOP_K;
+        // ascending order is 0, descending order is 1
+        const int32_t order =
+            is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
-    pre_wgsl::Preprocessor &                        preprocessor,
-    const char *                                    shader_src,
-    const ggml_webgpu_set_rows_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "set_rows";
-
-    switch (context.key.dst_type) {
-        case GGML_TYPE_F32:
-            defines.push_back("DST_F32");
-            variant += "_dstf32";
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("DST_F16");
-            variant += "_dstf16";
-            break;
-        default:
-            GGML_ABORT("Unsupported dst type for set_rows shader");
-    }
-
-    if (context.key.vec4) {
-        defines.push_back("VEC4");
-        variant += "_vec";
-    }
-    if (context.key.i64_idx) {
-        defines.push_back("I64_IDX");
-        variant += "_i64idx";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions   = decisions;
-    return result;
-}
+        auto it = argsort_merge_pipelines.find(order);
+        if (it != argsort_merge_pipelines.end()) {
+            return it->second;
+        }
 
-struct ggml_webgpu_unary_pipeline_key {
-    int  type;
-    int  op;
-    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
-    bool inplace;
+        std::vector<std::string> defines;
+        std::string              variant = "argsort_merge";
+        defines.push_back(std::string("ORDER=") + std::to_string(order));
+        variant += std::string("_order") + std::to_string(order);
+        uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
-    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
-        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+        auto processed                 = preprocessor.preprocess(wgsl_argsort_merge, defines);
+        argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return argsort_merge_pipelines[order];
     }
-};
 
-struct ggml_webgpu_unary_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.type);
-        ggml_webgpu_hash_combine(seed, key.op);
-        ggml_webgpu_hash_combine(seed, key.is_unary);
-        ggml_webgpu_hash_combine(seed, key.inplace);
-        return seed;
+    webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool vectorized                 = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
+        ggml_webgpu_get_rows_pipeline_key key = {
+            .src_type   = context.src0->type,
+            .vectorized = (int) vectorized,
+        };
+
+        auto it = get_rows_pipelines.find(key);
+        if (it != get_rows_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "get_rows";
+
+        const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
+        const char *                    type_str    = type_traits->type_name;
+
+        switch (key.src_type) {
+            case GGML_TYPE_F32:
+                if (key.vectorized) {
+                    defines.push_back("F32_VEC");
+                    defines.push_back("SRC_TYPE=vec4<f32>");
+                    defines.push_back("DST_TYPE=vec4<f32>");
+                    defines.push_back("BLOCK_SIZE=4u");
+                } else {
+                    defines.push_back("F32");
+                    defines.push_back("SRC_TYPE=f32");
+                    defines.push_back("DST_TYPE=f32");
+                    defines.push_back("BLOCK_SIZE=1u");
+                }
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("F16");
+                defines.push_back("SRC_TYPE=f16");
+                defines.push_back("DST_TYPE=f32");
+                defines.push_back("BLOCK_SIZE=1u");
+                variant += "_f16";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("I32");
+                defines.push_back("SRC_TYPE=i32");
+                defines.push_back("DST_TYPE=i32");
+                defines.push_back("BLOCK_SIZE=1u");
+                variant += "_i32";
+                break;
+            default:
+                {
+                    std::string type_upper = type_str;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back(type_upper + "_T");
+                    defines.push_back(type_upper);
+                    defines.push_back(type_upper + "_SCALE_MIN");
+                    defines.push_back(type_upper + "_TABLES");
+                    defines.push_back(type_upper + "_GRID");
+
+                    variant += "_";
+                    variant += type_str;
+
+                    defines.push_back(std::string("SRC_TYPE=") + type_str);
+                    defines.push_back("DST_TYPE=f32");
+
+                    if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
+                        key.src_type == GGML_TYPE_IQ4_NL) {
+                        defines.push_back("BLOCK_SIZE=32u");
+                    } else if (key.src_type >= GGML_TYPE_Q2_K) {
+                        defines.push_back("BLOCK_SIZE=256u");
+                    } else {
+                        defines.push_back("BLOCK_SIZE=1u");
+                    }
+                    break;
+                }
+        }
+
+        if (key.vectorized) {
+            variant += "_vec";
+        }
+
+        defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_get_rows, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        get_rows_pipelines[key]  = pipeline;
+        return get_rows_pipelines[key];
     }
-};
 
-struct ggml_webgpu_unary_shader_lib_context {
-    ggml_webgpu_unary_pipeline_key key;
-    uint32_t                       max_wg_size;
-};
+    webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
-    pre_wgsl::Preprocessor &                     preprocessor,
-    const char *                                 shader_src,
-    const ggml_webgpu_unary_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
-                                                              ggml_op_name((ggml_op) context.key.op);
-    // Operation-specific behavior
-    defines.push_back(variant);
-
-    switch (context.key.type) {
-        case GGML_TYPE_F32:
-            defines.push_back("TYPE_F32");
-            variant += "_f32";
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("TYPE_F16");
-            variant += "_f16";
-            break;
-        default:
-            GGML_ABORT("Unsupported type for unary shader");
-    }
-
-    if (context.key.inplace) {
-        defines.push_back("INPLACE");
-        variant += "_inplace";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions   = decisions;
-    return result;
-}
+        auto it = scale_pipelines.find(key);
+        if (it != scale_pipelines.end()) {
+            return it->second;
+        }
 
-/** Binary **/
+        std::vector<std::string> defines;
+        std::string              variant = "scale";
 
-struct ggml_webgpu_binary_pipeline_key {
-    int  type;
-    int  op;
-    bool inplace;
-    bool overlap;
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
 
-    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
-        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_scale, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        scale_pipelines[key]     = pipeline;
+        return scale_pipelines[key];
     }
-};
 
-struct ggml_webgpu_binary_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.type);
-        ggml_webgpu_hash_combine(seed, key.op);
-        ggml_webgpu_hash_combine(seed, key.inplace);
-        ggml_webgpu_hash_combine(seed, key.overlap);
-        return seed;
+    webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
+
+        auto it = pad_pipelines.find(key);
+        if (it != pad_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "pad";
+
+        if (key.circular) {
+            defines.push_back("CIRCULAR");
+            variant += "_circular";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_pad, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        pad_pipelines[key]       = pipeline;
+        return pad_pipelines[key];
     }
-};
 
-struct ggml_webgpu_binary_shader_lib_context {
-    ggml_webgpu_binary_pipeline_key key;
-    uint32_t                        max_wg_size;
+    webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_mul_mat_vec_pipeline_key key = {
+            .src0_type  = context.src0->type,
+            .src1_type  = context.src1->type,
+            // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
+            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
+                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                              1 :
+                              0,
+        };
+
+        auto it = mul_mat_vec_pipelines.find(key);
+        if (it != mul_mat_vec_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "mul_mat_vec";
+
+        // src1 type (vector)
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_INNER_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_INNER_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
+        }
+
+        // src0 type (matrix row)
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_INNER_TYPE=f32");
+                defines.push_back("MUL_ACC_FLOAT");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_INNER_TYPE=f16");
+                defines.push_back("MUL_ACC_FLOAT");
+                break;
+            default:
+                {
+                    // Quantized types: use helpers but accumulate in f16
+                    const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+                    std::string                     src0_name   = src0_traits->type_name;
+                    std::string                     type_upper  = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back("MUL_ACC_" + type_upper);
+
+                    // For fast path we always dequantize from f16 inside the shader
+                    defines.push_back("SRC0_INNER_TYPE=f16");
+                    break;
+                }
+        }
+
+        // VEC/SCALAR controls
+        defines.push_back(key.vectorized ? "VEC" : "SCALAR");
+
+        uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
+        uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_TILE_K;
+        uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
+        defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
+
+        auto processed            = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
+        auto decisions            = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
+        decisions->wg_size        = wg_size;
+        decisions->tile_k         = tile_k;
+        decisions->outputs_per_wg = outputs_per_wg;
+        decisions->vec_size       = key.vectorized ? 4 : 1;
+
+        webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context           = decisions;
+        mul_mat_vec_pipelines[key] = pipeline;
+        return mul_mat_vec_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_mul_mat_pipeline_key key = {
+            .src0_type  = context.src0->type,
+            .src1_type  = context.src1->type,
+            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
+                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                              1 :
+                              0,
+            .use_subgroup_matrix = context.supports_subgroup_matrix
+        };
+
+        auto it = mul_mat_fast_pipelines.find(key);
+        if (it != mul_mat_fast_pipelines.end()) {
+            return it->second;
+        }
+
+        const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
+        std::vector<std::string> defines;
+        std::string              variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
+
+        // src1 type
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_INNER_TYPE=f32");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_INNER_TYPE=f16");
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
+        }
+
+        // src0 type
+        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+        const char *                    src0_name   = src0_traits->type_name;
+
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_INNER_TYPE=f32");
+                defines.push_back("FLOAT");
+                defines.push_back("MUL_ACC_FLOAT");
+                defines.push_back("INIT_SRC0_SHMEM_FLOAT");
+                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_INNER_TYPE=f16");
+                defines.push_back("FLOAT");
+                defines.push_back("MUL_ACC_FLOAT");
+                defines.push_back("INIT_SRC0_SHMEM_FLOAT");
+                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    std::string type_upper = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back("MUL_ACC_" + type_upper);
+                    defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
+                    defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+
+                    // Use f16 inside the shader for quantized types
+                    defines.push_back("SRC0_INNER_TYPE=f16");
+
+                    variant += std::string("_") + src0_name;
+                    break;
+                }
+        }
+
+        // VEC/SCALAR controls
+        defines.push_back(key.vectorized ? "VEC" : "SCALAR");
+
+        // Tiles
+        defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
+        defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
+        defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
+
+        // Subgroup matrix specifics
+        if (key.use_subgroup_matrix) {
+            defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
+            defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
+            defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
+            defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
+            defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
+            defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
+            defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
+            defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
+        }
+
+        // variant suffix for src1 type
+        variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
+        if (key.vectorized) {
+            variant += "_vectorized";
+        }
+
+        if (!key.use_subgroup_matrix) {
+            defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
+            defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
+        }
+
+        auto processed = preprocessor.preprocess(shader_src, defines);
+
+        auto decisions                 = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
+        decisions->tile_k              = WEBGPU_MUL_MAT_TILE_K;
+        decisions->tile_m              = WEBGPU_MUL_MAT_TILE_M;
+        decisions->tile_n              = WEBGPU_MUL_MAT_TILE_N;
+        decisions->use_subgroup_matrix = key.use_subgroup_matrix;
+        if (key.use_subgroup_matrix) {
+            decisions->subgroup_m        = WEBGPU_MUL_MAT_SUBGROUP_M;
+            decisions->subgroup_n        = WEBGPU_MUL_MAT_SUBGROUP_N;
+            decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
+            decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
+            decisions->wg_size           = context.max_subgroup_size;
+        } else {
+            decisions->wg_size_m       = WEBGPU_MUL_MAT_WG_SIZE_M;
+            decisions->wg_size_n       = WEBGPU_MUL_MAT_WG_SIZE_N;
+            decisions->wg_size         = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
+            decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
+        }
+
+        webgpu_pipeline pipeline    = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context            = decisions;
+        mul_mat_fast_pipelines[key] = pipeline;
+        return mul_mat_fast_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
+                                                        .src1_type = context.src1->type };
+
+        auto it = mul_mat_legacy_pipelines.find(key);
+        if (it != mul_mat_legacy_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "mul_mat";
+
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
+        }
+
+        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+        const char *                    src0_name   = src0_traits->type_name;
+
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_TYPE=f32");
+                defines.push_back("FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_TYPE=f16");
+                defines.push_back("FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    // quantized types
+                    std::string type_upper = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back(std::string("SRC0_TYPE=") + src0_name);
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back(type_upper + "_T");
+                    defines.push_back(type_upper);
+                    defines.push_back(type_upper + "_SCALE_MIN");
+                    defines.push_back(type_upper + "_TABLES");
+                    defines.push_back(type_upper + "_GRID");
+
+                    variant += std::string("_") + src0_name;
+                    break;
+                }
+        }
+
+        auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
+
+        auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
+
+        webgpu_pipeline pipeline      = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context              = decisions;
+        mul_mat_legacy_pipelines[key] = pipeline;
+        return mul_mat_legacy_pipelines[key];
+    }
+
+    webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool                     is_unary = context.dst->op == GGML_OP_UNARY;
+        const int                      op       = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
+        ggml_webgpu_unary_pipeline_key key      = {
+                 .type     = context.dst->type,
+                 .op       = op,
+                 .is_unary = is_unary,
+                 .inplace  = context.inplace,
+        };
+
+        auto it = unary_pipelines.find(key);
+        if (it != unary_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant =
+            key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);
+        defines.push_back(variant);
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for unary shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_unary, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        unary_pipelines[key]     = pipeline;
+        return unary_pipelines[key];
+    }
+
+    webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_binary_pipeline_key key = {
+            .type    = context.dst->type,
+            .op      = context.dst->op,
+            .inplace = context.inplace,
+            .overlap = context.overlap,
+        };
+
+        auto it = binary_pipelines.find(key);
+        if (it != binary_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              op_name = ggml_op_name((ggml_op) key.op);
+        std::string              variant = op_name;
+
+        defines.push_back(std::string("OP_") + op_name);
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for binary shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        } else if (key.overlap) {
+            defines.push_back("OVERLAP");
+            variant += "_overlap";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_binary, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        binary_pipelines[key]    = pipeline;
+        return binary_pipelines[key];
+    }
+
+    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool has_mask  = context.src3 != nullptr;
+        const bool has_sinks = context.src4 != nullptr;
+
+        bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
+                         (context.src1->ne[1] % context.sg_mat_n == 0);
+
+        ggml_webgpu_flash_attn_pipeline_key key = {
+            .kv_type            = context.src1->type,
+            .head_dim_qk        = (uint32_t) context.src0->ne[0],
+            .head_dim_v         = (uint32_t) context.src2->ne[0],
+            .kv_direct          = kv_direct,
+            .has_mask           = has_mask,
+            .has_sinks          = has_sinks,
+            .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
+        };
+
+        auto it = flash_attn_pipelines.find(key);
+        if (it != flash_attn_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "flash_attn";
+
+        switch (key.kv_type) {
+            case GGML_TYPE_F32:
+                defines.push_back("KV_F32");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("KV_F16");
+                break;
+            case GGML_TYPE_Q4_0:
+                defines.push_back("KV_Q4_0");
+                break;
+            case GGML_TYPE_Q8_0:
+                defines.push_back("KV_Q8_0");
+                break;
+            default:
+                GGML_ABORT("Unsupported KV type for flash attention shader");
+        }
+        variant += std::string("_") + ggml_type_name(key.kv_type);
+
+        if (key.has_mask) {
+            defines.push_back("MASK");
+            variant += "_mask";
+        }
+        if (key.has_sinks) {
+            defines.push_back("SINKS");
+            variant += "_sinks";
+        }
+        if (key.uses_logit_softcap) {
+            defines.push_back("LOGIT_SOFTCAP");
+            variant += "_lgsc";
+        }
+        if (key.kv_direct) {
+            defines.push_back("KV_DIRECT");
+            variant += "_kvdirect";
+        }
+
+        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
+        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
+
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+
+        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+
+        uint32_t q_tile = context.sg_mat_m;
+        uint32_t kv_tile =
+            std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
+                                                          context.wg_mem_limit_bytes, context.max_subgroup_size }),
+                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+        if (key.kv_direct) {
+            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+                kv_tile -= context.sg_mat_n;
+            }
+        }
+
+        defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
+        defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+
+        uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+
+        auto processed     = preprocessor.preprocess(wgsl_flash_attn, defines);
+        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
+        decisions->q_tile  = q_tile;
+        decisions->kv_tile = kv_tile;
+        decisions->wg_size = wg_size;
+
+        webgpu_pipeline pipeline  = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context          = decisions;
+        flash_attn_pipelines[key] = pipeline;
+        return flash_attn_pipelines[key];
+    }
+
+  private:
+    static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
+                                                       std::string    shader_code,
+                                                       std::string    label) {
+        wgpu::ShaderSourceWGSL shader_source;
+        shader_source.code = shader_code.c_str();
+
+        wgpu::ShaderModuleDescriptor shader_desc;
+        shader_desc.nextInChain = &shader_source;
+
+        wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+        wgpu::ComputePipelineDescriptor pipeline_desc;
+        pipeline_desc.label              = label.c_str();
+        pipeline_desc.compute.module     = shader_module;
+        pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
+        pipeline_desc.layout             = nullptr;  // nullptr means auto layout
+        return { device.CreateComputePipeline(&pipeline_desc), label };
+    }
+
+    static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+        const size_t limit_bytes = context.wg_mem_limit_bytes;
+        const size_t q_tile      = context.sg_mat_m;
+        const size_t base_q_bytes =
+            (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+        size_t bytes_per_kv = 0;
+        if (!context.key.kv_direct) {
+            bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
+        }
+        if (context.key.has_mask) {
+            bytes_per_kv += q_tile;
+        }
+        bytes_per_kv += q_tile;
+        bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+        const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
+        return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+    }
 };
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
-    pre_wgsl::Preprocessor &                      preprocessor,
-    const char *                                  shader_src,
-    const ggml_webgpu_binary_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              op_name = ggml_op_name((ggml_op) context.key.op);
-    std::string              variant = op_name;
-
-    defines.push_back(std::string("OP_") + op_name);
-
-    switch (context.key.type) {
-        case GGML_TYPE_F32:
-            defines.push_back("TYPE_F32");
-            variant += "_f32";
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("TYPE_F16");
-            variant += "_f16";
-            break;
-        default:
-            GGML_ABORT("Unsupported type for binary shader");
-    }
-
-    if (context.key.inplace) {
-        defines.push_back("INPLACE");
-        variant += "_inplace";
-    } else if (context.key.overlap) {
-        defines.push_back("OVERLAP");
-        variant += "_overlap";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-    ggml_webgpu_processed_shader result;
-    result.wgsl        = preprocessor.preprocess(shader_src, defines);
-    result.variant     = variant;
-    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions   = decisions;
-    return result;
-}
 #endif  // GGML_WEBGPU_SHADER_LIB_HPP
index 32e120266a9bb477ee8a0979694eda729260dbd6..17bb2f47126247c228fd2b54cb41be9b011715fa 100644 (file)
@@ -8,7 +8,6 @@
 #include "ggml-backend-impl.h"
 #include "ggml-impl.h"
 #include "ggml-webgpu-shader-lib.hpp"
-#include "ggml-wgsl-shaders.hpp"
 #include "pre_wgsl.hpp"
 
 #ifdef __EMSCRIPTEN__
@@ -23,6 +22,7 @@
 #include <cstring>
 #include <iostream>
 #include <map>
+#include <memory>
 #include <mutex>
 #include <optional>
 #include <string>
 
 /* Constants */
 
-// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
-#define WEBGPU_MAX_WG_SIZE 288
-
-#define WEBGPU_MUL_MAT_WG_SIZE               256
 #define WEBGPU_NUM_PARAM_BUFS                16u
 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     8u
 #define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
-// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
+// Maximum number of in-flight submissions per-thread, to avoid exhausting the
+// parameter buffer pool
 #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
 #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       16
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
 #define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
 
-// For operations which process a row in parallel, this seems like a reasonable default
+// For operations which process a row in parallel, this seems like a reasonable
+// default
 #define WEBGPU_ROW_SPLIT_WG_SIZE 64
 
-// Matrix multiplication parameters
-
-// Register tiling parameters
-#define WEBGPU_MUL_MAT_TILE_M    8
-#define WEBGPU_MUL_MAT_TILE_N    8
-#define WEBGPU_MUL_MAT_WG_SIZE_M 8
-#define WEBGPU_MUL_MAT_WG_SIZE_N 8
-#define WEBGPU_MUL_MAT_TILE_K    32
-
-// Subgroup matrix parameters
-// The number of subgroups in the M dimension
-#define WEBGPU_MUL_MAT_SUBGROUP_M        2
-// The number of subgroups in the N dimension
-#define WEBGPU_MUL_MAT_SUBGROUP_N        2
-// The number of subgroup matrices each subgroup accumulates over
-#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
-#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
-
-// Matrix-vector multiplication parameters
-#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
-// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_TILE_K         256
+// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
+// implementations so this can be removed, necessary only for get_rows right now
+#define WEBGPU_MAX_WG_SIZE 288
 
 /* End Constants */
 
-// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
+// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
+// their locations.
 static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
 
 // Always returns the base offset of a tensor, regardless of views.
@@ -263,12 +242,6 @@ struct webgpu_gpu_profile_buf_pool {
 };
 #endif
 
-struct webgpu_pipeline {
-    wgpu::ComputePipeline pipeline;
-    std::string           name;
-    std::shared_ptr<void> context = nullptr;
-};
-
 struct webgpu_command {
     wgpu::CommandBuffer             commands;
     std::vector<webgpu_pool_bufs>   params_bufs;
@@ -353,41 +326,18 @@ struct webgpu_context_struct {
     // Points to global instances owned by ggml_backend_webgpu_reg_context
     webgpu_global_context global_ctx;
 
-    pre_wgsl::Preprocessor p;
+    std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
 
     webgpu_buf_pool param_buf_pool;
     webgpu_buf_pool set_rows_error_buf_pool;
 
-    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;  // src0_type, src1_type, vectorized
-    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
-        mul_mat_vec_pipelines;                                                       // src0_type, src1_type, vectorized
-
-    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
-        flash_attn_pipelines;
-
-    std::unordered_map<int, webgpu_pipeline> argmax_pipelines;         // key is vec4
-    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order (asc/desc)
-    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order (asc/desc)
-    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet
-    std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines;       // key is fixed, no variants yet
-
-    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
-                                                  set_rows_pipelines;
-    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;  // src_type, vectorized
-
-    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;       // src_type, dst_type
-
-    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
-        binary_pipelines;
+    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
 
     std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split
-    std::map<int, webgpu_pipeline>                               scale_pipelines;     // inplace
+
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines;  // mask_type, has_sink, inplace
-    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
-        unary_pipelines;
-    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
 
     size_t memset_bytes_per_thread;
 };
@@ -429,25 +379,6 @@ struct ggml_backend_webgpu_buffer_context {
 
 /* WebGPU object initializations */
 
-// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
-// the corresponding values provided in `repls`.
-static std::string ggml_webgpu_process_shader_repls(const char *                               src,
-                                                    const std::map<std::string, std::string> & repls) {
-    if (!src) {
-        return std::string();
-    }
-    std::string s = src;
-    for (const auto & kv : repls) {
-        std::string token = "{{" + kv.first + "}}";
-        size_t      pos   = 0;
-        while ((pos = s.find(token, pos)) != std::string::npos) {
-            s.replace(pos, token.length(), kv.second);
-            pos += kv.second.length();
-        }
-    }
-    return s;
-}
-
 static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device &                           device,
                                                    const char *                             shader_code,
                                                    const char *                             label,
@@ -495,8 +426,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 static void ggml_backend_webgpu_wait(webgpu_global_context &                  ctx,
                                      std::vector<webgpu_submission_futures> & futures,
                                      bool                                     block = true) {
-    // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
-    // inflight_max may be 0, meaning that we must wait on all futures.
+    // If we have too many in-flight submissions, wait on the oldest one first. If
+    // there are many threads, inflight_max may be 0, meaning that we must wait on
+    // all futures.
     uint64_t timeout_ms       = block ? UINT64_MAX : 0;
     uint32_t inflight_threads = ctx->inflight_threads;
     uint32_t inflight_max     = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
@@ -681,7 +613,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
         encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
     }
 
-    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
+    // If there are SET_ROWS operations in this submission, copy their error
+    // buffers to the host.
     if (set_rows_error_bufs) {
         encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
                                    set_rows_error_bufs->host_buf.GetSize());
@@ -900,24 +833,11 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
 }
 
 static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
-
-    ggml_webgpu_pad_pipeline_key       pipeline_key   = { .circular = circular };
-    ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
-        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->pad_pipelines.find(pipeline_key);
-    if (it != ctx->pad_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->pad_pipelines.emplace(pipeline_key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
@@ -971,36 +891,25 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
                                                           ggml_tensor *    src,
                                                           ggml_tensor *    idx,
                                                           ggml_tensor *    dst) {
-    // For set rows specifically, we need to check if src and idx are empty tensors.
+    // For set rows specifically, we need to check if src and idx are empty
+    // tensors.
     if (ggml_is_empty(src) || ggml_is_empty(idx)) {
         return std::nullopt;
     }
 
-    ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
-                                              .vec4     = src->ne[0] % 4 == 0,
-                                              .i64_idx  = idx->type == GGML_TYPE_I64 };
-
-    ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
-        .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = idx,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->set_rows_pipelines.find(key);
-    if (it != ctx->set_rows_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->set_rows_pipelines.emplace(key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
 
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
 
     std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
-    if (key.i64_idx) {
+    if (decisions->i64_idx) {
         error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
         if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
             error_bufs->host_buf.Unmap();
@@ -1038,13 +947,13 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    if (key.i64_idx) {
+    if (decisions->i64_idx) {
         entries.push_back(
             { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
     }
 
     uint32_t threads;
-    if (key.vec4) {
+    if (decisions->vec4) {
         threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
     } else {
         threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
@@ -1054,26 +963,47 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
                                      error_bufs);
 }
 
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key   = "wg_size";
+    constants[0].value = wg_size;
+    return constants;
+}
+
 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                            ggml_tensor *    src,
                                            ggml_tensor *    idx,
                                            ggml_tensor *    dst) {
-    std::vector<uint32_t> params = {
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        // Convert byte-strides to element-strides
-        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
-        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
-        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
-        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
-        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Shape of dst
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
-        // Shape of idx
-        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = WEBGPU_MAX_WG_SIZE,
     };
 
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+                                     (uint32_t) dst->ne[0],
+                                     (uint32_t) dst->ne[1],
+                                     (uint32_t) dst->ne[2],
+                                     (uint32_t) dst->ne[3],
+                                     (uint32_t) (idx->ne[1]),
+                                     (uint32_t) (idx->ne[2]) };
+
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
@@ -1089,10 +1019,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
+    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
 
-    uint32_t        vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
-    webgpu_pipeline pipeline   = ctx->get_rows_pipelines[src->type][vectorized];
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
@@ -1100,25 +1028,74 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
                                           ggml_tensor *    src0,
                                           ggml_tensor *    src1,
                                           ggml_tensor *    dst) {
+    // Determine if this is a mat-vec operation
+    bool is_vec = (dst->ne[1] == 1);
+
+    // Determine if we should use fast path
+    bool use_fast = false;
+    switch (src1->type) {
+        case GGML_TYPE_F16:
+            use_fast = (src0->type == GGML_TYPE_F16);
+            break;
+        case GGML_TYPE_F32:
+            switch (src0->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                    use_fast = true;
+                    break;
+                default:
+                    break;
+            }
+            break;
+        default:
+            break;
+    }
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0                     = src0,
+        .src1                     = src1,
+        .dst                      = dst,
+        .max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
+        .sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m,
+        .sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n,
+        .sg_mat_k                 = ctx->global_ctx->capabilities.sg_mat_k,
+        .max_subgroup_size        = ctx->global_ctx->capabilities.max_subgroup_size,
+    };
+
+    // Get or create pipeline
+    webgpu_pipeline pipeline;
+
+    if (use_fast && is_vec) {
+        pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
+    } else if (use_fast) {
+        pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
+    } else {
+        pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
+    }
+
+    // Build params
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        (uint32_t) dst->ne[0],                                  // number of rows in result (M, transposed)
-        (uint32_t) dst->ne[1],                                  // number of columns in result (N)
-        (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
-        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
-        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
-        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 2
-        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 2
-        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 3
-        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 3
-        (uint32_t) src0->ne[2],                                 // batch size in dimension 2
-        (uint32_t) src0->ne[3],                                 // batch size in dimension 3
-        (uint32_t) (src1->ne[2] / src0->ne[2]),                 // broadcast in dimension 2
-        (uint32_t) (src1->ne[3] / src0->ne[3])                  // broadcast in dimension 3
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) src0->ne[0],
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+        (uint32_t) src0->ne[2],
+        (uint32_t) src0->ne[3],
+        (uint32_t) (src1->ne[2] / src0->ne[2]),
+        (uint32_t) (src1->ne[3] / src0->ne[3])
     };
 
+    // Build bind group entries
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
@@ -1134,68 +1111,44 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  },
     };
 
-    webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
-
-    uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
+    // Calculate workgroup dimensions
+    uint32_t wg_x = 1;
     uint32_t wg_y = 1;
 
-    bool use_fast = false;
-    switch (src1->type) {
-        case GGML_TYPE_F16:
-            use_fast = (src0->type == GGML_TYPE_F16);
-            break;
-        case GGML_TYPE_F32:
-            switch (src0->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q4_0:
-                    use_fast = true;
-                    break;
-                default:
-                    break;
-            }
-            break;
-        default:
-            break;
-    }
-
-    if (use_fast) {
-        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
-        if (dst->ne[1] == 1) {
-            // We don't support vectorized mul_mat_vec for quantized types
-            vectorized             = vectorized && (src0->type < 2);
-            pipeline               = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
-            uint32_t batches       = dst->ne[2] * dst->ne[3];
-            uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
-            uint32_t total_wg      = output_groups * batches;
-            wg_x                   = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
-            wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
+    if (use_fast && is_vec) {
+        auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
+
+        uint32_t batches       = dst->ne[2] * dst->ne[3];
+        uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
+        uint32_t total_wg      = output_groups * batches;
+        wg_x                   = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+        wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
+    } else if (use_fast) {
+        auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
+
+        // Fast-path tiled/subgroup calculations
+        uint32_t wg_m, wg_n;
+        if (decisions->use_subgroup_matrix) {
+            uint32_t wg_m_sg_tile =
+                decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
+            wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
+            uint32_t wg_n_sg_tile =
+                decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
+            wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
         } else {
-            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
-            uint32_t wg_m;
-            uint32_t wg_n;
-#ifndef __EMSCRIPTEN__
-            if (ctx->global_ctx->capabilities.supports_subgroup_matrix) {
-                // The total number of subgroups/workgroups needed per matrix.
-                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M *
-                                        ctx->global_ctx->capabilities.sg_mat_m;
-                wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
-                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N *
-                                        ctx->global_ctx->capabilities.sg_mat_n;
-                wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
-            } else {
-#endif
-                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
-                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
-                wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
-                wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
-#ifndef __EMSCRIPTEN__
-            }
-#endif
-
-            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+            uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
+            uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
+            wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
+            wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
         }
+        wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+    } else {  // legacy
+        auto     decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+        uint32_t wg_size   = decisions->wg_size;
+        wg_x               = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
+        wg_y               = 1;
     }
+
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
 }
 
@@ -1283,40 +1236,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
-                     (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
-
-    ggml_webgpu_flash_attn_pipeline_key key = {
-        .kv_type            = K->type,
-        .head_dim_qk        = (uint32_t) Q->ne[0],
-        .head_dim_v         = (uint32_t) V->ne[0],
-        .kv_direct          = kv_direct,
-        .has_mask           = static_cast<bool>(has_mask),
-        .has_sinks          = static_cast<bool>(has_sinks),
-        .uses_logit_softcap = logit_softcap != 0.0f,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0               = Q,
+        .src1               = K,
+        .src2               = V,
+        .src3               = mask,
+        .src4               = sinks,
+        .dst                = dst,
+        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+        .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
+        .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
+        .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
+        .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size,
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->flash_attn_pipelines.find(key);
-    if (it != ctx->flash_attn_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
-            .key                = key,
-            .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
-            .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
-            .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
-            .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
-            .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size
-        };
-
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->flash_attn_pipelines.emplace(key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
 
@@ -1329,27 +1264,16 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
     bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
-    int  op       = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
 
-    ggml_webgpu_unary_pipeline_key pipeline_key = {
-        .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
-    };
-    ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
-        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = inplace,
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->unary_pipelines.find(pipeline_key);
-    if (it != ctx->unary_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->unary_pipelines.emplace(pipeline_key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
@@ -1421,30 +1345,18 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
                                             ggml_tensor *    dst) {
     binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
 
-    ggml_webgpu_binary_pipeline_key pipeline_key = {
-        .type    = dst->type,
-        .op      = dst->op,
-        .inplace = flags.inplace,
-        .overlap = flags.overlap,
-    };
-    ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
-        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = flags.inplace,
+        .overlap     = flags.overlap,
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->binary_pipelines.find(pipeline_key);
-    if (it != ctx->binary_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->binary_pipelines.emplace(pipeline_key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
 
-    auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
@@ -1669,8 +1581,20 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
 }
 
 static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    int inplace = ggml_webgpu_tensor_equal(src, dst);
+    bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = inplace,
+    };
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+    // params unchanged
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1688,12 +1612,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
         *(uint32_t *) &dst->op_params[1]  // bias
     };
 
+    // bindgroups unchanged
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
     };
+
     if (!inplace) {
         entries.push_back({ .binding = 1,
                             .buffer  = ggml_webgpu_tensor_buf(dst),
@@ -1701,9 +1627,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params,
-                                     entries, wg_x);
+    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -1796,63 +1721,30 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
-        .vec4        = src->ne[0] % 4 == 0,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
-    if (it != ctx->argmax_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
-    }
-    uint32_t wg_x = ggml_nelements(dst);
+    webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
+    uint32_t        wg_x     = ggml_nelements(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool          is_top_k = dst->op == GGML_OP_TOP_K;
-    // ascending order is 0, descending order is 1
-    const int32_t order    = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
+    bool is_top_k = dst->op == GGML_OP_TOP_K;
 
-    ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0               = src,
+        .src1               = nullptr,
+        .dst                = dst,
         .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
         .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
-        .order              = order
     };
 
-    webgpu_pipeline argsort_pipeline;
-    auto            it = ctx->argsort_pipelines.find(order);
-    if (it != ctx->argsort_pipelines.end()) {
-        argsort_pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
-        argsort_pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        argsort_pipeline.context = processed.decisions;
-        ctx->argsort_pipelines.emplace(order, argsort_pipeline);
-    }
-    auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
-
-    webgpu_pipeline argsort_merge_pipeline;
-    it = ctx->argsort_merge_pipelines.find(order);
-    if (it != ctx->argsort_merge_pipelines.end()) {
-        argsort_merge_pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
-        argsort_merge_pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        argsort_merge_pipeline.context = processed.decisions;
-        ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
-    }
+    webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
+    auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
+
+    webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
 
     const uint32_t src_ne0 = (uint32_t) src->ne[0];
     const uint32_t nrows   = (uint32_t) ggml_nrows(src);
@@ -2011,22 +1903,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
-        .vec4        = false,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
-    webgpu_pipeline pipeline;
-    auto            it = ctx->cumsum_pipelines.find(1);
-    if (it != ctx->cumsum_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        ctx->cumsum_pipelines.emplace(1, pipeline);
-    }
-    uint32_t wg_x = ggml_nrows(dst);
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
+    uint32_t        wg_x     = ggml_nrows(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
@@ -2052,22 +1937,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
-        .vec4        = false,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->sum_rows_pipelines.find(1);
-    if (it != ctx->sum_rows_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        ctx->sum_rows_pipelines.emplace(1, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
+
     uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
@@ -2233,7 +2108,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
                                                      size_t                offset,
                                                      size_t                size) {
     if (size == 0) {
-        WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
+        WEBGPU_LOG_DEBUG(
+            "ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
+            "nothing to do.");
         return;
     }
 
@@ -2310,7 +2187,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
 
     size_t final_size = size;
     if (size % 4 != 0) {
-        // If size is not a multiple of 4, we need to round it up to the next multiple of 4
+        // If size is not a multiple of 4, we need to round it up to the next
+        // multiple of 4
         final_size = size + (4 - (size % 4));
     }
 
@@ -2364,7 +2242,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
     /* .get_tensor      = */ ggml_backend_webgpu_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,  // TODO: optional, implement this
     /* .clear           = */ ggml_backend_webgpu_buffer_clear,
-    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with .init_tensor
+    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with
+                                    // .init_tensor
 };
 
 /* End GGML Backend Buffer Interface */
@@ -2401,7 +2280,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_
     return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
 }
 
-// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
+// maxBufferSize might be larger, but you can't bind more than
+// maxStorageBufferBindingSize to a single binding.
 static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
     ggml_backend_webgpu_device_context * dev_ctx =
         static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
@@ -2487,14 +2367,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
     return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }
 
-// Workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = wg_size;
-    return constants;
-}
-
 static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
     // we use the maximum workgroup size for the memset pipeline
     size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
@@ -2509,207 +2381,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
     ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
 }
 
-static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
-    // Q4/Q5/Q8 classic quantizations
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
-
-    // K-quantizations
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
-
-    // IQ quantizations (2-, 3-, 4-bit variants)
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
-
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
-
-    // 1-bit and 4-bit IQ variants
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
-
-    std::string proc_mul_mat_f32_f32;
-    std::string proc_mul_mat_f32_f32_vec;
-    std::string proc_mul_mat_f16_f32;
-    std::string proc_mul_mat_f16_f32_vec;
-    std::string proc_mul_mat_f16_f16;
-    std::string proc_mul_mat_f16_f16_vec;
-    std::string proc_mul_mat_q4_0_f32;
-    std::string proc_mul_mat_q4_0_f32_vec;
-
-    std::vector<wgpu::ConstantEntry> mul_mat_constants;
-#ifndef __EMSCRIPTEN__
-    if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
-        std::map<std::string, std::string> sg_matrix_repls;
-        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
-            std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
-        sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
-        sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
-        sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
-        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
-        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
-        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
-        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
-        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
-        proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
-        proc_mul_mat_f32_f32_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
-        proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
-        proc_mul_mat_f16_f32_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
-        proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
-        proc_mul_mat_f16_f16_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
-        proc_mul_mat_q4_0_f32 =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
-        proc_mul_mat_q4_0_f32_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
-    } else {
-#endif
-        mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
-        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
-        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
-
-        std::map<std::string, std::string> reg_repls;
-        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
-        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
-
-        proc_mul_mat_f32_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
-        proc_mul_mat_f32_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
-        proc_mul_mat_f16_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
-        proc_mul_mat_f16_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
-        proc_mul_mat_f16_f16      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
-        proc_mul_mat_f16_f16_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
-        proc_mul_mat_q4_0_f32     = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
-        proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
-#ifndef __EMSCRIPTEN__
-    }
-#endif
-
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
-
-    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
-    mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
-    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-    mul_mat_vec_constants[1].key   = "TILE_K";
-    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
-    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
-    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
-
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
-}
-
-static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
-}
-
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
@@ -2816,15 +2487,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
         webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
 }
 
-static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->scale_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
-    webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
-                                                                 "scale_f32_inplace", constants);
-}
-
 static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
 
@@ -3015,6 +2677,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
     ggml_backend_webgpu_device_context * dev_ctx    = (ggml_backend_webgpu_device_context *) dev->context;
     webgpu_context                       webgpu_ctx = std::make_shared<webgpu_context_struct>();
     webgpu_ctx->global_ctx                          = dev_ctx->webgpu_global_ctx;
+    webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
     webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
                                     wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
                                     wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
@@ -3023,13 +2686,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
                                              wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
                                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
 
-    ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
-    ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
     ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
     ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
     ggml_webgpu_init_rope_pipeline(webgpu_ctx);
     ggml_webgpu_init_glu_pipeline(webgpu_ctx);
-    ggml_webgpu_init_scale_pipeline(webgpu_ctx);
     ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
@@ -3071,11 +2731,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
     static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
         /* .iface = */ {
                         /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
-                        /* .alloc_buffer     = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
-                        /* .get_alignment    = */ ggml_backend_webgpu_buffer_type_get_alignment,
-                        /* .get_max_size     = */ ggml_backend_webgpu_buffer_type_get_max_size,
-                        /* .get_alloc_size   = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
-                        /* .is_host          = */ NULL,  // defaults to false
+                        /* .alloc_buffer     = */
+            ggml_backend_webgpu_buffer_type_alloc_buffer,  /* .get_alignment    = */
+            ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size     = */
+            ggml_backend_webgpu_buffer_type_get_max_size,  /* .get_alloc_size   = */
+            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,                // defaults to false
         },
         /* .device  = */
         dev,
index 389c97bb51b9a01dc79b8e0a3309393254895213..9a5b18ebc07b3263d35ec3a5138b91ebf46929ef 100644 (file)
@@ -1,5 +1,4 @@
-#decl(BYTE_HELPERS)
-
+#ifdef BYTE_HELPERS
 fn get_byte(value: u32, index: u32) -> u32 {
     return (value >> (index * 8)) & 0xFF;
 }
@@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
 fn get_byte_i32(value: u32, index: u32) -> i32 {
     return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
 }
+#endif
 
-#enddecl(BYTE_HELPERS)
-
-#decl(Q4_0_T)
+#ifdef Q4_0_T
 struct q4_0 {
     d: f16,
     qs: array<f16, 8>
 };
-#enddecl(Q4_0_T)
+#endif
 
-#decl(Q4_1_T)
+#ifdef Q4_1_T
 struct q4_1 {
     d: f16,
     m: f16,
     qs: array<u32, 4>
 };
-#enddecl(Q4_1_T)
+#endif
 
-#decl(Q5_0_T)
+#ifdef Q5_0_T
 struct q5_0 {
     d: f16,
     qh: array<f16, 2>,
     qs: array<f16, 8>
 };
-#enddecl(Q5_0_T)
+#endif
 
-#decl(Q5_1_T)
+#ifdef Q5_1_T
 struct q5_1 {
     d: f16,
     m: f16,
     qh: u32,
     qs: array<u32, 4>
 };
-#enddecl(Q5_1_T)
+#endif
 
-#decl(Q8_0_T)
+#ifdef Q8_0_T
 struct q8_0 {
     d: f16,
     qs: array<f16, 16>
 };
-#enddecl(Q8_0_T)
+#endif
 
-#decl(Q8_1_T)
+#ifdef Q8_1_T
 struct q8_1 {
     d: f16,
     m: f16,
     qs: array<u32, 8>
 };
-#enddecl(Q8_1_T)
+#endif
 
-#decl(Q2_K_T)
-struct q2_k {
+#ifdef Q2_K_T
+struct q2_K {
     scales: array<u32, 4>,
     qs: array<u32, 16>,
     d: f16,
     dmin: f16
 };
-#enddecl(Q2_K_T)
+#endif
 
-#decl(Q3_K_T)
-struct q3_k {
+#ifdef Q3_K_T
+struct q3_K {
     hmask: array<f16, 16>,
     qs: array<f16, 32>,
     scales: array<f16, 6>,
     d: f16
 };
-#enddecl(Q3_K_T)
-
-#decl(Q45_K_SCALE_MIN)
+#endif
 
+#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
 fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
     if (is < 4) {
         let sc_byte = get_byte(scales[is / 4], is % 4);
@@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
         return vec2(f32(sc), f32(m));
     }
 }
-
-#enddecl(Q45_K_SCALE_MIN)
-
-#decl(Q4_K_T)
-struct q4_k {
+#endif
+#ifdef Q4_K_T
+struct q4_K {
     d: f16,
     dmin: f16,
     scales: array<u32, 3>,
     qs: array<u32, 32>
 };
-#enddecl(Q4_K_T)
+#endif
 
-#decl(Q5_K_T)
-struct q5_k {
+#ifdef Q5_K_T
+struct q5_K {
     d: f16,
     dmin: f16,
     scales: array<u32, 3>,
     qh: array<u32, 8>,
     qs: array<u32, 32>
 };
-#enddecl(Q5_K_T)
+#endif
 
-#decl(Q6_K_T)
-struct q6_k {
+#ifdef Q6_K_T
+struct q6_K {
     ql: array<f16, 64>,
     qh: array<f16, 32>,
     scales: array<f16, 8>,
     d: f16
 };
-#enddecl(Q6_K_T)
+#endif
 
-#decl(IQ2_XXS_T)
+#ifdef IQ2_XXS_T
 struct iq2_xxs {
     d: f16,
     qs: array<f16, 32>
 };
-#enddecl(IQ2_XXS_T)
+#endif
 
-#decl(IQ2_XS_T)
+#ifdef IQ2_XS_T
 struct iq2_xs {
     d: f16,
     qs: array<f16, 32>,
     scales: array<f16, 4>
 };
-#enddecl(IQ2_XS_T)
+#endif
 
-#decl(IQ2_S_T)
+#ifdef IQ2_S_T
 struct iq2_s {
     d: f16,
     qs: array<f16, 32>,
     qh: array<f16, 4>,
     scales: array<f16, 4>
 };
-#enddecl(IQ2_S_T)
+#endif
 
-#decl(IQ3_XSS_T)
+#ifdef IQ3_XXS_T
 struct iq3_xxs {
     d: f16,
     qs: array<f16, 48>
 };
-#enddecl(IQ3_XSS_T)
+#endif
 
-#decl(IQ3_S_T)
+#ifdef IQ3_S_T
 struct iq3_s {
     d: f16,
     qs: array<f16, 32>,
@@ -161,41 +156,41 @@ struct iq3_s {
     signs: array<f16, 16>,
     scales: array<f16, 2>
 };
-#enddecl(IQ3_S_T)
+#endif
 
-#decl(IQ1_S_T)
+#ifdef IQ1_S_T
 struct iq1_s {
     d: f16,
     qs: array<f16, 16>,
     qh: array<f16, 8>
 };
-#enddecl(IQ1_S_T)
+#endif
 
-#decl(IQ1_M_T)
+#ifdef IQ1_M_T
 struct iq1_m {
     qs: array<u32, 8>,
     qh: array<u32, 4>,
     scales: array<u32, 2>
 };
-#enddecl(IQ1_M_T)
+#endif
 
-#decl(IQ4_NL_T)
+#ifdef IQ4_NL_T
 struct iq4_nl {
     d: f16,
     qs: array<f16, 8>,
 };
-#enddecl(IQ4_NL_T)
+#endif
 
-#decl(IQ4_XS_T)
+#ifdef IQ4_XS_T
 struct iq4_xs {
     d: f16,
     scales_h: f16,
     scales_l: u32,
     qs: array<u32, 32>
 };
-#enddecl(IQ4_XS_T)
+#endif
 
-#decl(IQ23_TABLES)
+#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
 const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
     0x08040201u, // 1, 2, 4, 8
     0x80402010u  // 16, 32, 64, 128
@@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
     0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
     0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
 );
-#enddecl(IQ23_TABLES)
+#endif
 
-#decl(IQ2_XXS_GRID)
+#ifdef IQ2_XXS_GRID
 const iq2xxs_grid = array<u32, 512>(
     0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
     0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
     0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
     0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
 );
-#enddecl(IQ2_XXS_GRID)
+#endif
 
-#decl(IQ2_XS_GRID)
+#ifdef IQ2_XS_GRID
 const iq2xs_grid = array<u32, 1024>(
     0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
     0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
     0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
     0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
 );
-#enddecl(IQ2_XS_GRID)
+#endif
 
-#decl(IQ2_S_GRID)
+#ifdef IQ2_S_GRID
 const iq2s_grid = array<u32, 2048>(
     0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
     0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
     0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
     0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
 );
-#enddecl(IQ2_S_GRID)
-
-#decl(IQ3_XSS_GRID)
+#endif
 
+#ifdef IQ3_XXS_GRID
 const iq3xxs_grid = array<u32, 256>(
     0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
     0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
     0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
 );
-#enddecl(IQ3_XSS_GRID)
-
-#decl(IQ3_S_GRID)
+#endif
 
+#ifdef IQ3_S_GRID
 const iq3s_grid = array<u32, 512>(
     0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
     0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
     0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
 );
-#enddecl(IQ3_S_GRID)
+#endif
 
-#decl(IQ1_GRID)
+#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
 
 const IQ1_DELTA: f32 = 0.125;
 
@@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
     0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
 );
 
-#enddecl(IQ1_GRID)
+#endif
 
-#decl(IQ4_GRID)
+#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
 
 const kvalues_iq4nl = array<i32, 16>(
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
 );
 
-#enddecl(IQ4_GRID)
+#endif
index d61df5bb9e544253be94a6399aaba03c2cd2f9b8..8b5cfe715e75488ec35e6229d3bab30e3a9f564e 100755 (executable)
@@ -56,12 +56,46 @@ def expand_includes(shader, input_dir):
     return include_pattern.sub(replacer, shader)
 
 
-def write_shader(shader_name, shader_code, output_dir, outfile):
+def chunk_shader(shader_code, max_chunk_len=60000):
+    """Split shader_code into safe raw-string sized chunks."""
+    return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]
+
+
+def raw_delim(shader_code):
+    """Pick a raw-string delimiter that does not appear in the shader."""
+    delim = "wgsl"
+    while f"){delim}\"" in shader_code:
+        delim += "_x"
+    return delim
+
+
+def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
+    shader_code = expand_includes(shader_code, input_dir)
+
     if output_dir:
         wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
         with open(wgsl_filename, "w", encoding="utf-8") as f_out:
             f_out.write(shader_code)
-    outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
+
+    delim = raw_delim(shader_code)
+    chunks = chunk_shader(shader_code)
+
+    if len(chunks) == 1:
+        outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
+    else:
+        for idx, chunk in enumerate(chunks):
+            outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n')
+        outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n')
+        outfile.write('    static const std::string s = []{\n')
+        outfile.write('        std::string tmp;\n')
+        outfile.write(f'        tmp.reserve({len(shader_code)});\n')
+        for idx in range(len(chunks)):
+            outfile.write(f'        tmp.append(wgsl_{shader_name}_part{idx});\n')
+        outfile.write('        return tmp;\n')
+        outfile.write('    }();\n')
+        outfile.write('    return s;\n')
+        outfile.write('}\n')
+        outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
 
 
 def generate_variants(fname, input_dir, output_dir, outfile):
@@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
     try:
         variants = ast.literal_eval(extract_block(text, "VARIANTS"))
     except ValueError:
-        write_shader(shader_base_name, text, output_dir, outfile)
+        write_shader(shader_base_name, text, output_dir, outfile, input_dir)
     else:
         try:
             decls_map = parse_decls(extract_block(text, "DECLS"))
@@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                 output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
             else:
                 output_name = shader_base_name
-            write_shader(output_name, final_shader, output_dir, outfile)
+            write_shader(output_name, final_shader, output_dir, outfile, input_dir)
 
 
 def main():
@@ -137,7 +171,8 @@ def main():
         os.makedirs(args.output_dir, exist_ok=True)
 
     with open(args.output_file, "w", encoding="utf-8") as out:
-        out.write("// Auto-generated shader embedding\n\n")
+        out.write("// Auto-generated shader embedding\n")
+        out.write("#include <string>\n\n")
         for fname in sorted(os.listdir(args.input_dir)):
             if fname.endswith(".wgsl"):
                 generate_variants(fname, args.input_dir, args.output_dir, out)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
deleted file mode 100644 (file)
index f80ce1f..0000000
+++ /dev/null
@@ -1,874 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "SHADER_SUFFIX": "f32_vec",
-    "REPLS": {
-      "TYPE" : "vec4<f32>",
-      "DST_TYPE": "vec4<f32>",
-      "BLOCK_SIZE": 4
-    },
-    "DECLS": ["F32_VEC"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 1
-    },
-    "DECLS": ["F32"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 1
-    },
-    "DECLS": ["F16"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "i32",
-      "DST_TYPE": "i32",
-      "BLOCK_SIZE": 1
-    },
-    "DECLS": ["I32"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q4_0",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q4_1",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q5_0",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q5_1",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q8_0",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q2_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q3_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q4_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q5_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q6_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "iq2_xxs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "iq2_xs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq2_s",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq3_xxs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq3_s",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq1_s",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq1_m",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq4_nl",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq4_xs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(F32_VEC)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
-}
-#enddecl(F32_VEC)
-
-#decl(F32)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    dst[dst_base + offset] = src[src_base + offset];
-}
-#enddecl(F32)
-
-#decl(F16)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    dst[dst_base + offset] = f32(src[src_base + offset]);
-}
-#enddecl(F16)
-
-#decl(I32)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    dst[dst_base + offset] = src[src_base + offset];
-}
-#enddecl(I32)
-
-#decl(Q4_0)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block_q4_0 = src[src_base + offset];
-    let d = f32(block_q4_0.d);
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
-            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
-            let dst_offset = dst_base + offset * 32 + j * 4 + k;
-            dst[dst_offset] = q_lo;
-            dst[dst_offset + 16] = q_hi;
-        }
-    }
-}
-#enddecl(Q4_0)
-
-#decl(Q4_1)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block_q4_1 = src[src_base + offset];
-    let d = f32(block_q4_1.d);
-    let m = f32(block_q4_1.m);
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = block_q4_1.qs[j];
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
-            let q_lo = f32(q_byte & 0xF) * d + m;
-            let dst_offset = dst_base + offset * 32 + j * 4 + k;
-            dst[dst_offset] = q_lo;
-            dst[dst_offset + 16] = q_hi;
-        }
-    }
-}
-#enddecl(Q4_1)
-
-#decl(Q5_0)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block_q5_0 = src[src_base + offset];
-    let d = f32(block_q5_0.d);
-    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
-            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
-            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
-            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
-            let dst_offset = dst_base + offset * 32 + j * 4 + k;
-            dst[dst_offset] = q_lo;
-            dst[dst_offset + 16] = q_hi;
-        }
-    }
-}
-
-#enddecl(Q5_0)
-
-#decl(Q5_1)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block_q5_1 = src[src_base + offset];
-    let d = f32(block_q5_1.d);
-    let m = f32(block_q5_1.m);
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = block_q5_1.qs[j];
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
-            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
-            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
-            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
-            let dst_offset = dst_base + offset * 32 + j * 4 + k;
-            dst[dst_offset] = q_lo;
-            dst[dst_offset + 16] = q_hi;
-        }
-    }
-}
-#enddecl(Q5_1)
-
-#decl(Q8_0)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block_q8_0 = src[src_base + offset];
-    let d = f32(block_q8_0.d);
-    for (var j: u32 = 0; j < 8; j++) {
-        let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte_i32(q_packed, k);
-            let q_val = f32(q_byte) * d;
-            let dst_offset = dst_base + offset * 32 + j * 4 + k;
-            dst[dst_offset] = q_val;
-        }
-    }
-}
-#enddecl(Q8_0)
-
-#decl(Q2_K)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    let m = f32(block.dmin);
-    var dst_i = dst_base + offset * 256;
-    var is: u32 = 0;
-    // 2 halves of the block (128 elements each)
-    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
-        // 4 groups (each group has 2 blocks of 16 elements)
-        for (var shift: u32 = 0; shift < 8; shift += 2) {
-            // 2 blocks
-            for (var k: u32 = 0; k < 32; k += 16) {
-                let sc = get_byte(block.scales[is / 4], is % 4);
-                is++;
-                let dl = d * f32(sc & 0xF);
-                let ml = m * f32(sc >> 4);
-                for (var l: u32 = 0u; l < 16; l++) {
-                    let q_idx = q_b_idx + k + l;
-                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
-                    let qs_val = (q_byte >> shift) & 3;
-                    dst[dst_i] = (f32(qs_val) * dl - ml);
-                    dst_i++;
-                }
-            }
-        }
-    }
-}
-#enddecl(Q2_K)
-
-#decl(Q3_K)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-
-    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
-    // and 2-bits from the last 4 bytes
-    let kmask1: u32 = 0x03030303;
-    let kmask2: u32 = 0x0f0f0f0f;
-    var scale_vals: array<u32, 4>;
-    for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
-    }
-    var tmp: u32 = scale_vals[2];
-    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
-    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-
-    // convert arrays of f16 -> u32
-    var hmask_vals: array<u32, 8>;
-    for (var i: u32 = 0; i < 8; i++) {
-        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
-    }
-    var qs_vals: array<u32, 16>;
-    for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
-    }
-
-    var dst_i = dst_base + offset * 256;
-    var is: u32 = 0;
-    var m: u32 = 1;
-    // 2 halves of the block (128 elements each)
-    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
-        // 4 groups (each group has 2 blocks of 16 elements)
-        for (var shift: u32 = 0; shift < 8; shift += 2) {
-            // 2 blocks
-            for (var k: u32 = 0; k < 32; k += 16) {
-                let sc = get_byte(scale_vals[is / 4], is % 4);
-                is++;
-                let dl = d * (f32(sc) - 32.0);
-                for (var l: u32 = 0u; l < 16u; l++) {
-                    let q_idx = q_b_idx + k + l;
-                    let hm_idx = k + l;
-                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
-                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
-                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
-                    let qs_val = (q_byte >> shift) & 3;
-                    dst[dst_i] = (f32(qs_val) - hm) * dl;
-                    dst_i++;
-                }
-            }
-            m <<= 1;
-        }
-    }
-}
-#enddecl(Q3_K)
-
-#decl(Q4_K)
-// 8 blocks of 32 elements each
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    let m = f32(block.dmin);
-    var dst_i = dst_base + offset * 256;
-    var is: u32 = 0;
-    // 2 blocks each iteration
-    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
-        for (var shift: u32 = 0; shift < 8; shift += 4) {
-            let scale_min = get_scale_min(is, block.scales);
-            is++;
-            let dl = d * scale_min.x;
-            let ml = m * scale_min.y;
-            for (var l: u32 = 0; l < 32; l++) {
-                let q_idx = q_b_idx + l;
-                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
-                let qs_val = (q_byte >> shift) & 0xF;
-                dst[dst_i] = (f32(qs_val) * dl - ml);
-                dst_i++;
-            }
-        }
-    }
-}
-#enddecl(Q4_K)
-
-#decl(Q5_K)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    let m = f32(block.dmin);
-    var dst_i = dst_base + offset * 256;
-    var is: u32 = 0;
-    var u: u32 = 1;
-    // 2 blocks each iteration
-    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
-        for (var shift: u32 = 0; shift < 8; shift += 4) {
-            let scale_min = get_scale_min(is, block.scales);
-            is++;
-            let dl = d * scale_min.x;
-            let ml = m * scale_min.y;
-            for (var l: u32 = 0; l < 32; l++) {
-                let q_idx = q_b_idx + l;
-                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
-                let qh_byte = get_byte(block.qh[l / 4], l % 4);
-                let qs_val = (q_byte >> shift) & 0xF;
-                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
-                dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml;
-                dst_i++;
-            }
-            u <<= 1;
-        }
-    }
-}
-#enddecl(Q5_K)
-
-#decl(Q6_K)
-// 16 blocks of 16 elements each
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-
-    // convert arrays of f16 -> u32
-    var ql_vals: array<u32, 32>;
-    for (var i: u32 = 0; i < 32; i++) {
-        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
-    }
-    var qh_vals: array<u32, 16>;
-    for (var i: u32 = 0; i < 16; i++) {
-        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
-    }
-    var scale_vals: array<u32, 4>;
-    for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
-    }
-
-    var dst_i = dst_base + offset * 256;
-    var qh_b_idx: u32 = 0;
-    var sc_b_idx: u32 = 0;
-    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
-        for (var l: u32 = 0; l < 32; l++) {
-            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
-            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
-            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
-
-            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
-            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
-            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
-            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
-
-            let is = l/16;
-            let is1 = sc_b_idx + is;
-            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
-            let is2 = sc_b_idx + is + 2;
-            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
-            let is3 = sc_b_idx + is + 4;
-            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
-            let is4 = sc_b_idx + is + 6;
-            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
-
-            dst[dst_i + l] = (q1 * f32(sc1)) * d;
-            dst[dst_i + l + 32] = (q2 * f32(sc2)) * d;
-            dst[dst_i + l + 64] = (q3 * f32(sc3)) * d;
-            dst[dst_i + l + 96] = (q4 * f32(sc4)) * d;
-        }
-        dst_i += 128;
-        qh_b_idx += 32;
-        sc_b_idx += 8;
-    }
-}
-
-#enddecl(Q6_K)
-
-#decl(IQ2_XXS)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 256;
-    for (var ib: u32 = 0; ib < 32; ib += 4) {
-        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
-        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
-        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
-        for (var l: u32 = 0; l < 4; l++) {
-            let ig = get_byte(aux0, l) * 8;
-            let is = (aux1 >> (7 * l)) & 127;
-            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            for (var j: u32 = 0; j < 8; j++) {
-                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
-                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
-                dst[dst_i] = db * f32(g) * m;
-                dst_i++;
-            }
-        }
-    }
-}
-#enddecl(IQ2_XXS)
-
-#decl(IQ2_XS)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 256;
-    var scale_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
-        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
-    );
-    for (var ib: u32 = 0; ib < 32; ib += 4) {
-        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
-        let db = array<f32, 2>(
-            d * (0.5 + f32(s & 0xF)) * 0.25,
-            d * (0.5 + f32(s >> 4)) * 0.25
-        );
-        for (var l: u32 = 0; l < 4; l++) {
-            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
-            let ig = (qs_val & 511) * 8;
-            let is = qs_val >> 9;
-            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let dl = db[l/2];
-            for (var j: u32 = 0; j < 8; j++) {
-                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
-                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
-                dst[dst_i] = dl * f32(g) * m;
-                dst_i++;
-            }
-        }
-    }
-}
-#enddecl(IQ2_XS)
-
-#decl(IQ2_S)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 256;
-    var qs_vals : array<u32, 16>;
-    for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
-    }
-    var qh_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
-        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
-    );
-    var scale_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
-        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
-    );
-    for (var ib: u32 = 0; ib < 8; ib ++) {
-        let s = get_byte(scale_vals[ib / 4], ib % 4);
-        let db = array<f32, 2>(
-            d * (0.5 + f32(s & 0xF)) * 0.25,
-            d * (0.5 + f32(s >> 4)) * 0.25
-        );
-        let qs_w = qs_vals[ib];
-        for (var l: u32 = 0; l < 4; l++) {
-            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
-            let ig = (get_byte(qs_w, l) | qh_b) * 8;
-            let signs = get_byte(qs_vals[ib + 8], l);
-            let dl = db[l/2];
-            for (var j: u32 = 0; j < 8; j++) {
-                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
-                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
-                dst[dst_i] = dl * f32(g) * m;
-                dst_i++;
-            }
-        }
-    }
-}
-
-#enddecl(IQ2_S)
-
-#decl(IQ3_XSS)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 256;
-    for (var ib: u32 = 0; ib < 16; ib += 2) {
-        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
-        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
-        for (var l: u32 = 0; l < 4; l++) {
-            let is = (sc_sign >> (7 * l)) & 127;
-            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
-            let ig1 = get_byte(ig_val, 0);
-            let ig2 = get_byte(ig_val, 1);
-            for (var j: u32 = 0; j < 4; j++) {
-                let g1 = get_byte(iq3xxs_grid[ig1], j);
-                let g2 = get_byte(iq3xxs_grid[ig2], j);
-                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
-                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
-                dst[dst_i] = db * f32(g1) * m1;
-                dst[dst_i + 4] = db * f32(g2) * m2;
-                dst_i++;
-            }
-            dst_i += 4;
-        }
-    }
-}
-#enddecl(IQ3_XSS)
-
-#decl(IQ3_S)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 256;
-    var qh_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
-        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
-    );
-    var sign_vals: array<u32, 8>;
-    for (var i: u32 = 0; i < 8; i++) {
-        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
-    }
-    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
-    for (var ib: u32 = 0; ib < 4; ib++) {
-        let s = get_byte(scale_vals, ib);
-        let db = array<f32, 2>(
-            d * (1.0 + 2.0 * f32(s & 0xF)),
-            d * (1.0 + 2.0 * f32(s >> 4))
-        );
-        for (var k: u32 = 0; k < 2; k++) {
-            let dl = db[k];
-            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
-            let sign_w = sign_vals[ib * 2 + k];
-            for (var l: u32 = 0; l < 4; l++) {
-                let signs = get_byte(sign_w, l);
-                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
-                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
-                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
-                for (var j: u32 = 0; j < 4; j++) {
-                    let g1 = get_byte(iq3s_grid[ig1], j);
-                    let g2 = get_byte(iq3s_grid[ig2], j);
-                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
-                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
-                    dst[dst_i] = dl * f32(g1) * m1;
-                    dst[dst_i + 4] = dl * f32(g2) * m2;
-                    dst_i++;
-                }
-                dst_i += 4;
-            }
-        }
-    }
-}
-#enddecl(IQ3_S)
-
-#decl(IQ1_S)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 256;
-    for (var ib: u32 = 0; ib < 8; ib++) {
-        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
-        let dl = d * (2 * f32((qh >> 12) & 7) + 1);
-        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
-        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
-        for (var l: u32 = 0; l < 4; l++) {
-            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
-            for (var j: u32 = 0; j < 8; j++) {
-                let gw = iq1_grid[(ig + j) / 16];
-                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
-                let gs = bitcast<i32>(g << 30) >> 30;
-                dst[dst_i] = dl * (f32(gs) + delta);
-                dst_i++;
-            }
-        }
-    }
-}
-
-#enddecl(IQ1_S)
-
-#decl(IQ1_M)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-
-    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
-    let d = f32(bitcast<vec2<f16>>(scale).x);
-    var dst_i = dst_base + offset * 256;
-    for (var ib: u32 = 0; ib < 8; ib++) {
-        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
-        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
-        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
-        var dl = array<f32, 2>(
-            d * f32(2 * s1 + 1),
-            d * f32(2 * s2 + 1)
-        );
-
-        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
-        var idx = array<u32, 4>(
-            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
-            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
-            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
-            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
-        );
-        var delta = array<f32, 4>(
-            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
-            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
-            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
-            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
-        );
-        for (var l: u32 = 0; l < 4; l++) {
-            let ig = idx[l] * 8;
-            for (var j: u32 = 0; j < 8; j++) {
-                let gw = iq1_grid[(ig + j) / 16];
-                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
-                let gs = bitcast<i32>(g << 30) >> 30;
-                dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]);
-                dst_i++;
-            }
-        }
-    }
-}
-
-#enddecl(IQ1_M)
-
-#decl(IQ4_NL)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    var dst_i = dst_base + offset * 32;
-    var qs: array<u32, 4>;
-    for (var i: u32 = 0; i < 4; i++) {
-        qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
-    }
-    for (var j: u32 = 0; j < 16; j++) {
-        let qsb = get_byte(qs[j / 4], j % 4);
-        dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]);
-        dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]);
-        dst_i++;
-    }
-}
-#enddecl(IQ4_NL)
-
-#decl(IQ4_XS)
-fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
-    let block = src[src_base + offset];
-    let d = f32(block.d);
-    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
-    var dst_i = dst_base + offset * 256;
-    for (var ib: u32 = 0; ib < 8; ib++) {
-        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
-        let dl = d * (f32(ls) - 32.0);
-        for (var j: u32 = 0; j < 16; j++) {
-            let iqs = ib * 16 + j;
-            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
-            dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]);
-            dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]);
-            dst_i++;
-        }
-        dst_i += 16;
-    }
-}
-#enddecl(IQ4_XS)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-DECLS
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> idx: array<i32>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{DST_TYPE}}>;
-
-struct Params {
-    offset_src: u32, // in elements
-    offset_idx: u32, // in elements
-    offset_dst: u32, // in elements
-
-    // Strides (in elements)
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_idx0: u32,
-    stride_idx1: u32,
-    stride_idx2: u32,
-
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Shape of dst
-    ne0: u32,
-    n_rows: u32,
-    ne2: u32,
-    ne3: u32,
-
-    // Shape of idx
-    idx1: u32,
-    idx2: u32,
-};
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
-        return;
-    }
-    var i = gid.x;
-    let i_dst3 = i / (params.ne2 * params.n_rows);
-
-    i = i % (params.ne2 * params.n_rows);
-    let i_dst2 = i / params.n_rows;
-    let i_dst1 = i % params.n_rows;
-
-    let i_idx2 = i_dst3 % params.idx2;
-    let i_idx1 = i_dst2 % params.idx1;
-    let i_idx0 = i_dst1;
-
-    let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
-
-    let idx_val = u32(idx[i_idx]);
-
-    let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
-    let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
-
-    for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
-      copy_elements(i_src_row, i_dst_row, i);
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
new file mode 100644 (file)
index 0000000..b10800e
--- /dev/null
@@ -0,0 +1,668 @@
+enable f16;
+#include "common_decls.tmpl"
+
+#ifdef F32_VEC
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
+}
+#endif
+
+#ifdef F32
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    dst[dst_base + offset] = src[src_base + offset];
+}
+#endif
+
+#ifdef F16
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    dst[dst_base + offset] = f32(src[src_base + offset]);
+}
+#endif
+
+#ifdef I32
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    dst[dst_base + offset] = src[src_base + offset];
+}
+#endif
+
+#ifdef Q4_0
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block_q4_0 = src[src_base + offset];
+    let d = f32(block_q4_0.d);
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
+            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
+            let dst_offset = dst_base + offset * 32 + j * 4 + k;
+            dst[dst_offset] = q_lo;
+            dst[dst_offset + 16] = q_hi;
+        }
+    }
+}
+#endif
+
+#ifdef Q4_1
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block_q4_1 = src[src_base + offset];
+    let d = f32(block_q4_1.d);
+    let m = f32(block_q4_1.m);
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = block_q4_1.qs[j];
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+            let q_lo = f32(q_byte & 0xF) * d + m;
+            let dst_offset = dst_base + offset * 32 + j * 4 + k;
+            dst[dst_offset] = q_lo;
+            dst[dst_offset + 16] = q_hi;
+        }
+    }
+}
+#endif
+
+#ifdef Q5_0
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block_q5_0 = src[src_base + offset];
+    let d = f32(block_q5_0.d);
+    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
+            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
+            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+            let dst_offset = dst_base + offset * 32 + j * 4 + k;
+            dst[dst_offset] = q_lo;
+            dst[dst_offset + 16] = q_hi;
+        }
+    }
+}
+#endif
+
+#ifdef Q5_1
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block_q5_1 = src[src_base + offset];
+    let d = f32(block_q5_1.d);
+    let m = f32(block_q5_1.m);
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = block_q5_1.qs[j];
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
+            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
+            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
+            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
+            let dst_offset = dst_base + offset * 32 + j * 4 + k;
+            dst[dst_offset] = q_lo;
+            dst[dst_offset + 16] = q_hi;
+        }
+    }
+}
+#endif
+
+#ifdef Q8_0
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block_q8_0 = src[src_base + offset];
+    let d = f32(block_q8_0.d);
+    for (var j: u32 = 0; j < 8; j++) {
+        let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte_i32(q_packed, k);
+            let q_val = f32(q_byte) * d;
+            let dst_offset = dst_base + offset * 32 + j * 4 + k;
+            dst[dst_offset] = q_val;
+        }
+    }
+}
+#endif
+
+#ifdef Q2_K
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    let m = f32(block.dmin);
+    var dst_i = dst_base + offset * 256;
+    var is: u32 = 0;
+    // 2 halves of the block (128 elements each)
+    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+        // 4 groups (each group has 2 blocks of 16 elements)
+        for (var shift: u32 = 0; shift < 8; shift += 2) {
+            // 2 blocks
+            for (var k: u32 = 0; k < 32; k += 16) {
+                let sc = get_byte(block.scales[is / 4], is % 4);
+                is++;
+                let dl = d * f32(sc & 0xF);
+                let ml = m * f32(sc >> 4);
+                for (var l: u32 = 0u; l < 16; l++) {
+                    let q_idx = q_b_idx + k + l;
+                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+                    let qs_val = (q_byte >> shift) & 3;
+                    dst[dst_i] = (f32(qs_val) * dl - ml);
+                    dst_i++;
+                }
+            }
+        }
+    }
+}
+#endif
+
+#ifdef Q3_K
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+
+    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
+    // and 2-bits from the last 4 bytes
+    let kmask1: u32 = 0x03030303;
+    let kmask2: u32 = 0x0f0f0f0f;
+    var scale_vals: array<u32, 4>;
+    for (var i: u32 = 0; i < 4; i++) {
+        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+    }
+    var tmp: u32 = scale_vals[2];
+    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
+    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+    // convert arrays of f16 -> u32
+    var hmask_vals: array<u32, 8>;
+    for (var i: u32 = 0; i < 8; i++) {
+        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
+    }
+    var qs_vals: array<u32, 16>;
+    for (var i: u32 = 0; i < 16; i++) {
+        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
+    }
+
+    var dst_i = dst_base + offset * 256;
+    var is: u32 = 0;
+    var m: u32 = 1;
+    // 2 halves of the block (128 elements each)
+    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+        // 4 groups (each group has 2 blocks of 16 elements)
+        for (var shift: u32 = 0; shift < 8; shift += 2) {
+            // 2 blocks
+            for (var k: u32 = 0; k < 32; k += 16) {
+                let sc = get_byte(scale_vals[is / 4], is % 4);
+                is++;
+                let dl = d * (f32(sc) - 32.0);
+                for (var l: u32 = 0u; l < 16u; l++) {
+                    let q_idx = q_b_idx + k + l;
+                    let hm_idx = k + l;
+                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
+                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
+                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+                    let qs_val = (q_byte >> shift) & 3;
+                    dst[dst_i] = (f32(qs_val) - hm) * dl;
+                    dst_i++;
+                }
+            }
+            m <<= 1;
+        }
+    }
+}
+#endif
+
+#ifdef Q4_K
+// 8 blocks of 32 elements each
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    let m = f32(block.dmin);
+    var dst_i = dst_base + offset * 256;
+    var is: u32 = 0;
+    // 2 blocks each iteration
+    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+        for (var shift: u32 = 0; shift < 8; shift += 4) {
+            let scale_min = get_scale_min(is, block.scales);
+            is++;
+            let dl = d * scale_min.x;
+            let ml = m * scale_min.y;
+            for (var l: u32 = 0; l < 32; l++) {
+                let q_idx = q_b_idx + l;
+                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+                let qs_val = (q_byte >> shift) & 0xF;
+                dst[dst_i] = (f32(qs_val) * dl - ml);
+                dst_i++;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef Q5_K
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    let m = f32(block.dmin);
+    var dst_i = dst_base + offset * 256;
+    var is: u32 = 0;
+    var u: u32 = 1;
+    // 2 blocks each iteration
+    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+        for (var shift: u32 = 0; shift < 8; shift += 4) {
+            let scale_min = get_scale_min(is, block.scales);
+            is++;
+            let dl = d * scale_min.x;
+            let ml = m * scale_min.y;
+            for (var l: u32 = 0; l < 32; l++) {
+                let q_idx = q_b_idx + l;
+                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+                let qh_byte = get_byte(block.qh[l / 4], l % 4);
+                let qs_val = (q_byte >> shift) & 0xF;
+                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+                dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml;
+                dst_i++;
+            }
+            u <<= 1;
+        }
+    }
+}
+#endif
+
+#ifdef Q6_K
+// 16 blocks of 16 elements each
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+
+    // convert arrays of f16 -> u32
+    var ql_vals: array<u32, 32>;
+    for (var i: u32 = 0; i < 32; i++) {
+        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
+    }
+    var qh_vals: array<u32, 16>;
+    for (var i: u32 = 0; i < 16; i++) {
+        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
+    }
+    var scale_vals: array<u32, 4>;
+    for (var i: u32 = 0; i < 4; i++) {
+        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+    }
+
+    var dst_i = dst_base + offset * 256;
+    var qh_b_idx: u32 = 0;
+    var sc_b_idx: u32 = 0;
+    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
+        for (var l: u32 = 0; l < 32; l++) {
+            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
+            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
+            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
+
+            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
+            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
+            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
+            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
+
+            let is = l/16;
+            let is1 = sc_b_idx + is;
+            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
+            let is2 = sc_b_idx + is + 2;
+            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
+            let is3 = sc_b_idx + is + 4;
+            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
+            let is4 = sc_b_idx + is + 6;
+            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
+
+            dst[dst_i + l] = (q1 * f32(sc1)) * d;
+            dst[dst_i + l + 32] = (q2 * f32(sc2)) * d;
+            dst[dst_i + l + 64] = (q3 * f32(sc3)) * d;
+            dst[dst_i + l + 96] = (q4 * f32(sc4)) * d;
+        }
+        dst_i += 128;
+        qh_b_idx += 32;
+        sc_b_idx += 8;
+    }
+}
+#endif
+
+#ifdef IQ2_XXS
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 256;
+    for (var ib: u32 = 0; ib < 32; ib += 4) {
+        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
+        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
+        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
+        for (var l: u32 = 0; l < 4; l++) {
+            let ig = get_byte(aux0, l) * 8;
+            let is = (aux1 >> (7 * l)) & 127;
+            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+            for (var j: u32 = 0; j < 8; j++) {
+                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
+                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+                dst[dst_i] = db * f32(g) * m;
+                dst_i++;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef IQ2_XS
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 256;
+    var scale_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+    );
+    for (var ib: u32 = 0; ib < 32; ib += 4) {
+        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
+        let db = array<f32, 2>(
+            d * (0.5 + f32(s & 0xF)) * 0.25,
+            d * (0.5 + f32(s >> 4)) * 0.25
+        );
+        for (var l: u32 = 0; l < 4; l++) {
+            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
+            let ig = (qs_val & 511) * 8;
+            let is = qs_val >> 9;
+            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+            let dl = db[l/2];
+            for (var j: u32 = 0; j < 8; j++) {
+                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
+                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+                dst[dst_i] = dl * f32(g) * m;
+                dst_i++;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef IQ2_S
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 256;
+    var qs_vals : array<u32, 16>;
+    for (var i: u32 = 0; i < 16; i++) {
+        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+    }
+    var qh_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+    );
+    var scale_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+    );
+    for (var ib: u32 = 0; ib < 8; ib ++) {
+        let s = get_byte(scale_vals[ib / 4], ib % 4);
+        let db = array<f32, 2>(
+            d * (0.5 + f32(s & 0xF)) * 0.25,
+            d * (0.5 + f32(s >> 4)) * 0.25
+        );
+        let qs_w = qs_vals[ib];
+        for (var l: u32 = 0; l < 4; l++) {
+            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
+            let ig = (get_byte(qs_w, l) | qh_b) * 8;
+            let signs = get_byte(qs_vals[ib + 8], l);
+            let dl = db[l/2];
+            for (var j: u32 = 0; j < 8; j++) {
+                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
+                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+                dst[dst_i] = dl * f32(g) * m;
+                dst_i++;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef IQ3_XXS
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 256;
+    for (var ib: u32 = 0; ib < 16; ib += 2) {
+        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
+        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
+        for (var l: u32 = 0; l < 4; l++) {
+            let is = (sc_sign >> (7 * l)) & 127;
+            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
+            let ig1 = get_byte(ig_val, 0);
+            let ig2 = get_byte(ig_val, 1);
+            for (var j: u32 = 0; j < 4; j++) {
+                let g1 = get_byte(iq3xxs_grid[ig1], j);
+                let g2 = get_byte(iq3xxs_grid[ig2], j);
+                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+                dst[dst_i] = db * f32(g1) * m1;
+                dst[dst_i + 4] = db * f32(g2) * m2;
+                dst_i++;
+            }
+            dst_i += 4;
+        }
+    }
+}
+#endif
+
+#ifdef IQ3_S
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 256;
+    var qh_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+    );
+    var sign_vals: array<u32, 8>;
+    for (var i: u32 = 0; i < 8; i++) {
+        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
+    }
+    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
+    for (var ib: u32 = 0; ib < 4; ib++) {
+        let s = get_byte(scale_vals, ib);
+        let db = array<f32, 2>(
+            d * (1.0 + 2.0 * f32(s & 0xF)),
+            d * (1.0 + 2.0 * f32(s >> 4))
+        );
+        for (var k: u32 = 0; k < 2; k++) {
+            let dl = db[k];
+            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
+            let sign_w = sign_vals[ib * 2 + k];
+            for (var l: u32 = 0; l < 4; l++) {
+                let signs = get_byte(sign_w, l);
+                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
+                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
+                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
+                for (var j: u32 = 0; j < 4; j++) {
+                    let g1 = get_byte(iq3s_grid[ig1], j);
+                    let g2 = get_byte(iq3s_grid[ig2], j);
+                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+                    dst[dst_i] = dl * f32(g1) * m1;
+                    dst[dst_i + 4] = dl * f32(g2) * m2;
+                    dst_i++;
+                }
+                dst_i += 4;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef IQ1_S
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 256;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
+        let dl = d * (2 * f32((qh >> 12) & 7) + 1);
+        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
+        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
+        for (var l: u32 = 0; l < 4; l++) {
+            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
+            for (var j: u32 = 0; j < 8; j++) {
+                let gw = iq1_grid[(ig + j) / 16];
+                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+                let gs = bitcast<i32>(g << 30) >> 30;
+                dst[dst_i] = dl * (f32(gs) + delta);
+                dst_i++;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef IQ1_M
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+
+    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
+    let d = f32(bitcast<vec2<f16>>(scale).x);
+    var dst_i = dst_base + offset * 256;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
+        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
+        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
+        var dl = array<f32, 2>(
+            d * f32(2 * s1 + 1),
+            d * f32(2 * s2 + 1)
+        );
+
+        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
+        var idx = array<u32, 4>(
+            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
+            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
+            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
+            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
+        );
+        var delta = array<f32, 4>(
+            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
+            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
+            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
+            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
+        );
+        for (var l: u32 = 0; l < 4; l++) {
+            let ig = idx[l] * 8;
+            for (var j: u32 = 0; j < 8; j++) {
+                let gw = iq1_grid[(ig + j) / 16];
+                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+                let gs = bitcast<i32>(g << 30) >> 30;
+                dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]);
+                dst_i++;
+            }
+        }
+    }
+}
+#endif
+
+#ifdef IQ4_NL
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    var dst_i = dst_base + offset * 32;
+    var qs: array<u32, 4>;
+    for (var i: u32 = 0; i < 4; i++) {
+        qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+    }
+    for (var j: u32 = 0; j < 16; j++) {
+        let qsb = get_byte(qs[j / 4], j % 4);
+        dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]);
+        dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]);
+        dst_i++;
+    }
+}
+#endif
+
+#ifdef IQ4_XS
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+    let block = src[src_base + offset];
+    let d = f32(block.d);
+    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
+    var dst_i = dst_base + offset * 256;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
+        let dl = d * (f32(ls) - 32.0);
+        for (var j: u32 = 0; j < 16; j++) {
+            let iqs = ib * 16 + j;
+            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
+            dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]);
+            dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]);
+            dst_i++;
+        }
+        dst_i += 16;
+    }
+}
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<SRC_TYPE>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx: array<i32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DST_TYPE>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_idx: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_idx0: u32,
+    stride_idx1: u32,
+    stride_idx2: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Shape of dst
+    ne0: u32,
+    n_rows: u32,
+    ne2: u32,
+    ne3: u32,
+
+    // Shape of idx
+    idx1: u32,
+    idx2: u32,
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
+        return;
+    }
+    var i = gid.x;
+    let i_dst3 = i / (params.ne2 * params.n_rows);
+
+    i = i % (params.ne2 * params.n_rows);
+    let i_dst2 = i / params.n_rows;
+    let i_dst1 = i % params.n_rows;
+
+    let i_idx2 = i_dst3 % params.idx2;
+    let i_idx1 = i_dst2 % params.idx1;
+    let i_idx0 = i_dst1;
+
+    let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+
+    let idx_val = u32(idx[i_idx]);
+
+    let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
+    let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
+
+    for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
+      copy_elements(i_src_row, i_dst_row, i);
+    }
+}
+
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
deleted file mode 100644 (file)
index 0f8e6e5..0000000
+++ /dev/null
@@ -1,907 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "BLOCK_SIZE" : 1
-    },
-    "DECLS" : ["FLOAT"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "BLOCK_SIZE" : 1
-    },
-    "DECLS" : ["FLOAT"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "BLOCK_SIZE" : 1
-    },
-    "DECLS" : ["FLOAT"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q4_0",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q4_1",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q5_0",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q5_1",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q8_0",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q2_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q3_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q4_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q5_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q6_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq2_xxs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq2_xs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq2_s",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq3_xxs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq3_s",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq1_s",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq1_m",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq4_nl",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq4_xs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(FLOAT)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
-}
-#enddecl(FLOAT)
-
-#decl(Q4_0)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block_q4_0 = src0[src0_idx_base + offset];
-    let d = f32(block_q4_0.d);
-    var sum: f32 = 0.0;
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
-            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
-            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
-            sum += q_lo * f32(src1[src1_offset]);
-            sum += q_hi * f32(src1[src1_offset + 16]);
-        }
-    }
-    return sum;
-}
-#enddecl(Q4_0)
-
-#decl(Q4_1)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block_q4_1 = src0[src0_idx_base + offset];
-    let d = f32(block_q4_1.d);
-    let m = f32(block_q4_1.m);
-    var sum: f32 = 0.0;
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = block_q4_1.qs[j];
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
-            let q_lo = f32(q_byte & 0xF) * d + m;
-            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
-            sum += q_lo * f32(src1[src1_offset]);
-            sum += q_hi * f32(src1[src1_offset + 16]);
-        }
-    }
-    return sum;
-}
-#enddecl(Q4_1)
-
-#decl(Q5_0)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block_q5_0 = src0[src0_idx_base + offset];
-    let d = f32(block_q5_0.d);
-    var sum: f32 = 0.0;
-    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
-            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
-            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
-            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
-            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
-            sum += q_lo * f32(src1[src1_offset]);
-            sum += q_hi * f32(src1[src1_offset + 16]);
-        }
-    }
-    return sum;
-}
-#enddecl(Q5_0)
-
-#decl(Q5_1)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block_q5_1 = src0[src0_idx_base + offset];
-    let d = f32(block_q5_1.d);
-    let m = f32(block_q5_1.m);
-    var sum: f32 = 0.0;
-    for (var j: u32 = 0; j < 4; j++) {
-        let q_packed = block_q5_1.qs[j];
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte(q_packed, k);
-            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
-            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
-            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
-            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
-            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
-            sum += q_lo * f32(src1[src1_offset]);
-            sum += q_hi * f32(src1[src1_offset + 16]);
-        }
-    }
-    return sum;
-}
-#enddecl(Q5_1)
-
-#decl(Q8_0)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block_q8_0 = src0[src0_idx_base + offset];
-    let d = f32(block_q8_0.d);
-    var sum: f32 = 0.0;
-    for (var j: u32 = 0; j < 8; j++) {
-        let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte_i32(q_packed, k);
-            let q_val = f32(q_byte) * d;
-            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
-            sum += q_val * f32(src1[src1_offset]);
-        }
-    }
-    return sum;
-}
-#enddecl(Q8_0)
-
-#decl(Q8_1)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block_q8_1 = src0[src0_idx_base + offset];
-    let d = f32(block_q8_1.d);
-    let m = f32(block_q8_1.m);
-    var sum: f32 = 0.0;
-    for (var j: u32 = 0; j < 8; j++) {
-        let q_packed = block_q8_1.qs[j];
-        for (var k: u32 = 0; k < 4; k++) {
-            let q_byte = get_byte_i32(q_packed, k);
-            let q_val = f32(q_byte) * d + m;
-            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
-            sum += q_val * f32(src1[src1_offset]);
-        }
-    }
-    return sum;
-}
-#enddecl(Q8_1)
-
-#decl(Q2_K)
-// 16 blocks of 16 elements each
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    let m = f32(block.dmin);
-    var sum = 0.0;
-    var src1_i = src1_idx_base + offset * 256;
-    var is: u32 = 0;
-    // 2 halves of the block (128 elements each)
-    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
-        // 4 groups (each group has 2 blocks of 16 elements)
-        for (var shift: u32 = 0; shift < 8; shift += 2) {
-            // 2 blocks
-            for (var k: u32 = 0; k < 32; k += 16) {
-                let sc = get_byte(block.scales[is / 4], is % 4);
-                is++;
-                let dl = d * f32(sc & 0xF);
-                let ml = m * f32(sc >> 4);
-                for (var l: u32 = 0u; l < 16; l++) {
-                    let q_idx = q_b_idx + k + l;
-                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
-                    let qs_val = (q_byte >> shift) & 3;
-                    sum += (f32(qs_val) * dl - ml) * src1[src1_i];
-                    src1_i++;
-                }
-            }
-        }
-    }
-    return sum;
-}
-
-#enddecl(Q2_K)
-
-#decl(Q3_K)
-// 16 blocks of 16 elements each
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-
-    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
-    // and 2-bits from the last 4 bytes
-    let kmask1: u32 = 0x03030303;
-    let kmask2: u32 = 0x0f0f0f0f;
-    var scale_vals: array<u32, 4>;
-    for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
-    }
-    var tmp: u32 = scale_vals[2];
-    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
-    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-
-    // convert arrays of f16 -> u32
-    var hmask_vals: array<u32, 8>;
-    for (var i: u32 = 0; i < 8; i++) {
-        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
-    }
-    var qs_vals: array<u32, 16>;
-    for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
-    }
-
-    var sum = 0.0;
-    var src1_i = src1_idx_base + offset * 256;
-    var is: u32 = 0;
-    var m: u32 = 1;
-    // 2 halves of the block (128 elements each)
-    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
-        // 4 groups (each group has 2 blocks of 16 elements)
-        for (var shift: u32 = 0; shift < 8; shift += 2) {
-            // 2 blocks
-            for (var k: u32 = 0; k < 32; k += 16) {
-                let sc = get_byte(scale_vals[is / 4], is % 4);
-                is++;
-                let dl = d * (f32(sc) - 32.0);
-                for (var l: u32 = 0u; l < 16u; l++) {
-                    let q_idx = q_b_idx + k + l;
-                    let hm_idx = k + l;
-                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
-                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
-                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
-                    let qs_val = (q_byte >> shift) & 3;
-                    sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
-                    src1_i++;
-                }
-            }
-            m <<= 1;
-        }
-    }
-    return sum;
-}
-
-#enddecl(Q3_K)
-
-#decl(Q4_K)
-// 8 blocks of 32 elements each
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    let m = f32(block.dmin);
-    var sum = 0.0;
-    var src1_i = src1_idx_base + offset * 256;
-    var is: u32 = 0;
-    // 2 blocks each iteration
-    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
-        for (var shift: u32 = 0; shift < 8; shift += 4) {
-            let scale_min = get_scale_min(is, block.scales);
-            is++;
-            let dl = d * scale_min.x;
-            let ml = m * scale_min.y;
-            for (var l: u32 = 0; l < 32; l++) {
-                let q_idx = q_b_idx + l;
-                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
-                let qs_val = (q_byte >> shift) & 0xF;
-                sum += (f32(qs_val) * dl - ml) * src1[src1_i];
-                src1_i++;
-            }
-        }
-    }
-    return sum;
-}
-
-#enddecl(Q4_K)
-
-#decl(Q5_K)
-// 8 blocks of 32 elements each
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    let m = f32(block.dmin);
-    var sum = 0.0;
-    var src1_i = src1_idx_base + offset * 256;
-    var is: u32 = 0;
-    var u: u32 = 1;
-    // 2 blocks each iteration
-    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
-        for (var shift: u32 = 0; shift < 8; shift += 4) {
-            let scale_min = get_scale_min(is, block.scales);
-            is++;
-            let dl = d * scale_min.x;
-            let ml = m * scale_min.y;
-            for (var l: u32 = 0; l < 32; l++) {
-                let q_idx = q_b_idx + l;
-                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
-                let qh_byte = get_byte(block.qh[l / 4], l % 4);
-                let qs_val = (q_byte >> shift) & 0xF;
-                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
-                sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
-               src1_i++;
-            }
-            u <<= 1;
-        }
-    }
-    return sum;
-}
-
-#enddecl(Q5_K)
-
-#decl(Q6_K)
-// 16 blocks of 16 elements each
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-
-    // convert arrays of f16 -> u32
-    var ql_vals: array<u32, 32>;
-    for (var i: u32 = 0; i < 32; i++) {
-        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
-    }
-    var qh_vals: array<u32, 16>;
-    for (var i: u32 = 0; i < 16; i++) {
-        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
-    }
-    var scale_vals: array<u32, 4>;
-    for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
-    }
-
-    var sum = 0.0;
-    var src1_i = src1_idx_base + offset * 256;
-    var qh_b_idx: u32 = 0;
-    var sc_b_idx: u32 = 0;
-    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
-        for (var l: u32 = 0; l < 32; l++) {
-            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
-            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
-            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
-
-            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
-            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
-            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
-            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
-
-            let is = l/16;
-            let is1 = sc_b_idx + is;
-            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
-            let is2 = sc_b_idx + is + 2;
-            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
-            let is3 = sc_b_idx + is + 4;
-            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
-            let is4 = sc_b_idx + is + 6;
-            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
-
-            sum += d * f32(sc1) * q1 * src1[src1_i + l];
-            sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
-            sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
-            sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
-        }
-        src1_i += 128;
-        qh_b_idx += 32;
-        sc_b_idx += 8;
-    }
-    return sum;
-}
-
-#enddecl(Q6_K)
-
-#decl(IQ2_XXS)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 256;
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 32; ib += 4) {
-        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
-        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
-        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
-        for (var l: u32 = 0; l < 4; l++) {
-            let ig = get_byte(aux0, l) * 8;
-            let is = (aux1 >> (7 * l)) & 127;
-            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            for (var j: u32 = 0; j < 8; j++) {
-                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
-                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
-                sum += db * f32(g) * m * src1[src1_i];
-                src1_i++;
-            }
-        }
-    }
-    return sum;
-}
-
-#enddecl(IQ2_XXS)
-
-#decl(IQ2_XS)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 256;
-    var scale_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
-        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
-    );
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 32; ib += 4) {
-        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
-        let db = array<f32, 2>(
-            d * (0.5 + f32(s & 0xF)) * 0.25,
-            d * (0.5 + f32(s >> 4)) * 0.25
-        );
-        for (var l: u32 = 0; l < 4; l++) {
-            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
-            let ig = (qs_val & 511) * 8;
-            let is = qs_val >> 9;
-            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let dl = db[l/2];
-            for (var j: u32 = 0; j < 8; j++) {
-                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
-                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
-                sum += dl * f32(g) * m * src1[src1_i];
-                src1_i++;
-            }
-        }
-    }
-    return sum;
-}
-
-#enddecl(IQ2_XS)
-
-#decl(IQ2_S)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 256;
-    var qs_vals : array<u32, 16>;
-    for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
-    }
-    var qh_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
-        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
-    );
-    var scale_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
-        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
-    );
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 8; ib ++) {
-        let s = get_byte(scale_vals[ib / 4], ib % 4);
-        let db = array<f32, 2>(
-            d * (0.5 + f32(s & 0xF)) * 0.25,
-            d * (0.5 + f32(s >> 4)) * 0.25
-        );
-        let qs_w = qs_vals[ib];
-        for (var l: u32 = 0; l < 4; l++) {
-            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
-            let ig = (get_byte(qs_w, l) | qh_b) * 8;
-            let signs = get_byte(qs_vals[ib + 8], l);
-            let dl = db[l/2];
-            for (var j: u32 = 0; j < 8; j++) {
-                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
-                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
-                sum += dl * f32(g) * m * src1[src1_i];
-                src1_i++;
-            }
-        }
-    }
-    return sum;
-}
-
-
-#enddecl(IQ2_S)
-
-#decl(IQ3_XSS)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 256;
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 16; ib += 2) {
-        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
-        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
-        for (var l: u32 = 0; l < 4; l++) {
-            let is = (sc_sign >> (7 * l)) & 127;
-            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
-            let ig1 = get_byte(ig_val, 0);
-            let ig2 = get_byte(ig_val, 1);
-            for (var j: u32 = 0; j < 4; j++) {
-                let g1 = get_byte(iq3xxs_grid[ig1], j);
-                let g2 = get_byte(iq3xxs_grid[ig2], j);
-                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
-                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
-                sum += db * f32(g1) * m1 * src1[src1_i];
-                sum += db * f32(g2) * m2 * src1[src1_i + 4];
-                src1_i++;
-            }
-            src1_i += 4;
-        }
-    }
-    return sum;
-}
-
-#enddecl(IQ3_XSS)
-
-#decl(IQ3_S)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 256;
-    var qh_vals = array<u32, 2>(
-        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
-        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
-    );
-    var sign_vals: array<u32, 8>;
-    for (var i: u32 = 0; i < 8; i++) {
-        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
-    }
-    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 4; ib++) {
-        let s = get_byte(scale_vals, ib);
-        let db = array<f32, 2>(
-            d * (1.0 + 2.0 * f32(s & 0xF)),
-            d * (1.0 + 2.0 * f32(s >> 4))
-        );
-        for (var k: u32 = 0; k < 2; k++) {
-            let dl = db[k];
-            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
-            let sign_w = sign_vals[ib * 2 + k];
-            for (var l: u32 = 0; l < 4; l++) {
-                let signs = get_byte(sign_w, l);
-                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
-                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
-                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
-                for (var j: u32 = 0; j < 4; j++) {
-                    let g1 = get_byte(iq3s_grid[ig1], j);
-                    let g2 = get_byte(iq3s_grid[ig2], j);
-                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
-                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
-                    sum += dl * f32(g1) * m1 * src1[src1_i];
-                    sum += dl * f32(g2) * m2 * src1[src1_i + 4];
-                    src1_i++;
-                }
-                src1_i += 4;
-            }
-        }
-    }
-    return sum;
-}
-#enddecl(IQ3_S)
-
-#decl(IQ1_S)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 256;
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 8; ib++) {
-        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
-        let dl = d * (2 * f32((qh >> 12) & 7) + 1);
-        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
-        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
-        for (var l: u32 = 0; l < 4; l++) {
-            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
-            for (var j: u32 = 0; j < 8; j++) {
-                let gw = iq1_grid[(ig + j) / 16];
-                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
-                let gs = bitcast<i32>(g << 30) >> 30;
-                sum += dl * (f32(gs) + delta) * src1[src1_i];
-                src1_i++;
-            }
-        }
-    }
-    return sum;
-}
-
-#enddecl(IQ1_S)
-
-#decl(IQ1_M)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-
-    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
-    let d = f32(bitcast<vec2<f16>>(scale).x);
-    var src1_i = src1_idx_base + offset * 256;
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 8; ib++) {
-        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
-        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
-        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
-        var dl = array<f32, 2>(
-            d * f32(2 * s1 + 1),
-            d * f32(2 * s2 + 1)
-        );
-
-        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
-        var idx = array<u32, 4>(
-            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
-            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
-            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
-            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
-        );
-        var delta = array<f32, 4>(
-            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
-            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
-            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
-            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
-        );
-        for (var l: u32 = 0; l < 4; l++) {
-            let ig = idx[l] * 8;
-            for (var j: u32 = 0; j < 8; j++) {
-                let gw = iq1_grid[(ig + j) / 16];
-                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
-                let gs = bitcast<i32>(g << 30) >> 30;
-                sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
-                src1_i++;
-            }
-        }
-    }
-    return sum;
-}
-
-#enddecl(IQ1_M)
-
-#decl(IQ4_NL)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    var src1_i = src1_idx_base + offset * 32;
-    var sum = 0.0;
-    var qs: array<u32, 4>;
-    for (var i: u32 = 0; i < 4; i++) {
-        qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
-    }
-    for (var j: u32 = 0; j < 16; j++) {
-        let qsb = get_byte(qs[j / 4], j % 4);
-        sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
-        sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
-        src1_i++;
-    }
-    return sum;
-}
-
-#enddecl(IQ4_NL)
-
-#decl(IQ4_XS)
-fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
-    let block = src0[src0_idx_base + offset];
-    let d = f32(block.d);
-    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
-    var src1_i = src1_idx_base + offset * 256;
-    var sum = 0.0;
-    for (var ib: u32 = 0; ib < 8; ib++) {
-        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
-        let dl = d * (f32(ls) - 32.0);
-        for (var j: u32 = 0; j < 16; j++) {
-            let iqs = ib * 16 + j;
-            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
-            sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
-            sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
-            src1_i++;
-        }
-        src1_i += 16;
-    }
-    return sum;
-}
-
-#enddecl(IQ4_XS)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-DECLS
-
-struct MulMatParams {
-    offset_src0: u32, // in elements/blocks
-    offset_src1: u32, // in elements/blocks
-    offset_dst: u32, // in elements/blocks
-    m: u32,
-    n: u32,
-    k: u32,
-    // all strides are in elements/blocks
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
-@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
-
-@group(0) @binding(3) var<uniform> params: MulMatParams;
-
-@compute @workgroup_size(256)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
-    let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    if (global_id.x >= total) {
-        return;
-    }
-
-    let dst2_stride = params.m * params.n;
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-
-    let dst3_idx = global_id.x / dst3_stride;
-    let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
-    let src13_idx = dst3_idx; // src1 is not broadcast
-    let dst3_rem = global_id.x % dst3_stride;
-
-    let dst2_idx = dst3_rem / dst2_stride;
-    let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
-    let src12_idx = dst2_idx; // src1 is not broadcast
-
-    let dst2_rem = dst3_rem % dst2_stride;
-
-    let row = dst2_rem / params.m; // output row
-    let col = dst2_rem % params.m; // output column
-
-    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
-    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
-
-    var sum = 0.0;
-    for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
-        sum += multiply_add(src0_idx_base, src1_idx_base, i);
-    }
-    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
new file mode 100644 (file)
index 0000000..6aba473
--- /dev/null
@@ -0,0 +1,713 @@
+enable f16;
+
+#include "common_decls.tmpl"
+
+#ifdef FLOAT
+const BLOCK_SIZE = 1u;
+
+#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
+const BLOCK_SIZE = 32u;
+
+#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
+const BLOCK_SIZE = 256u;
+#endif
+
+#ifdef FLOAT
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
+}
+#endif
+
+#ifdef Q4_0
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block_q4_0 = src0[src0_idx_base + offset];
+    let d = f32(block_q4_0.d);
+    var sum: f32 = 0.0;
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
+            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
+            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+            sum += q_lo * f32(src1[src1_offset]);
+            sum += q_hi * f32(src1[src1_offset + 16]);
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q4_1
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block_q4_1 = src0[src0_idx_base + offset];
+    let d = f32(block_q4_1.d);
+    let m = f32(block_q4_1.m);
+    var sum: f32 = 0.0;
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = block_q4_1.qs[j];
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+            let q_lo = f32(q_byte & 0xF) * d + m;
+            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+            sum += q_lo * f32(src1[src1_offset]);
+            sum += q_hi * f32(src1[src1_offset + 16]);
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q5_0
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block_q5_0 = src0[src0_idx_base + offset];
+    let d = f32(block_q5_0.d);
+    var sum: f32 = 0.0;
+    let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
+            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
+            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+            sum += q_lo * f32(src1[src1_offset]);
+            sum += q_hi * f32(src1[src1_offset + 16]);
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q5_1
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block_q5_1 = src0[src0_idx_base + offset];
+    let d = f32(block_q5_1.d);
+    let m = f32(block_q5_1.m);
+    var sum: f32 = 0.0;
+    for (var j: u32 = 0; j < 4; j++) {
+        let q_packed = block_q5_1.qs[j];
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte(q_packed, k);
+            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
+            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
+            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
+            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
+            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+            sum += q_lo * f32(src1[src1_offset]);
+            sum += q_hi * f32(src1[src1_offset + 16]);
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q8_0
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block_q8_0 = src0[src0_idx_base + offset];
+    let d = f32(block_q8_0.d);
+    var sum: f32 = 0.0;
+    for (var j: u32 = 0; j < 8; j++) {
+        let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte_i32(q_packed, k);
+            let q_val = f32(q_byte) * d;
+            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+            sum += q_val * f32(src1[src1_offset]);
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q8_1
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block_q8_1 = src0[src0_idx_base + offset];
+    let d = f32(block_q8_1.d);
+    let m = f32(block_q8_1.m);
+    var sum: f32 = 0.0;
+    for (var j: u32 = 0; j < 8; j++) {
+        let q_packed = block_q8_1.qs[j];
+        for (var k: u32 = 0; k < 4; k++) {
+            let q_byte = get_byte_i32(q_packed, k);
+            let q_val = f32(q_byte) * d + m;
+            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+            sum += q_val * f32(src1[src1_offset]);
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q2_K
+// 16 blocks of 16 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    let m = f32(block.dmin);
+    var sum = 0.0;
+    var src1_i = src1_idx_base + offset * 256;
+    var is: u32 = 0;
+    // 2 halves of the block (128 elements each)
+    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+        // 4 groups (each group has 2 blocks of 16 elements)
+        for (var shift: u32 = 0; shift < 8; shift += 2) {
+            // 2 blocks
+            for (var k: u32 = 0; k < 32; k += 16) {
+                let sc = get_byte(block.scales[is / 4], is % 4);
+                is++;
+                let dl = d * f32(sc & 0xF);
+                let ml = m * f32(sc >> 4);
+                for (var l: u32 = 0u; l < 16; l++) {
+                    let q_idx = q_b_idx + k + l;
+                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+                    let qs_val = (q_byte >> shift) & 3;
+                    sum += (f32(qs_val) * dl - ml) * src1[src1_i];
+                    src1_i++;
+                }
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q3_K
+// 16 blocks of 16 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+
+    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
+    // and 2-bits from the last 4 bytes
+    let kmask1: u32 = 0x03030303;
+    let kmask2: u32 = 0x0f0f0f0f;
+    var scale_vals: array<u32, 4>;
+    for (var i: u32 = 0; i < 4; i++) {
+        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+    }
+    var tmp: u32 = scale_vals[2];
+    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
+    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+    // convert arrays of f16 -> u32
+    var hmask_vals: array<u32, 8>;
+    for (var i: u32 = 0; i < 8; i++) {
+        hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
+    }
+    var qs_vals: array<u32, 16>;
+    for (var i: u32 = 0; i < 16; i++) {
+        qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
+    }
+
+    var sum = 0.0;
+    var src1_i = src1_idx_base + offset * 256;
+    var is: u32 = 0;
+    var m: u32 = 1;
+    // 2 halves of the block (128 elements each)
+    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+        // 4 groups (each group has 2 blocks of 16 elements)
+        for (var shift: u32 = 0; shift < 8; shift += 2) {
+            // 2 blocks
+            for (var k: u32 = 0; k < 32; k += 16) {
+                let sc = get_byte(scale_vals[is / 4], is % 4);
+                is++;
+                let dl = d * (f32(sc) - 32.0);
+                for (var l: u32 = 0u; l < 16u; l++) {
+                    let q_idx = q_b_idx + k + l;
+                    let hm_idx = k + l;
+                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
+                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
+                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+                    let qs_val = (q_byte >> shift) & 3;
+                    sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
+                    src1_i++;
+                }
+            }
+            m <<= 1;
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q4_K
+// 8 blocks of 32 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    let m = f32(block.dmin);
+    var sum = 0.0;
+    var src1_i = src1_idx_base + offset * 256;
+    var is: u32 = 0;
+    // 2 blocks each iteration
+    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+        for (var shift: u32 = 0; shift < 8; shift += 4) {
+            let scale_min = get_scale_min(is, block.scales);
+            is++;
+            let dl = d * scale_min.x;
+            let ml = m * scale_min.y;
+            for (var l: u32 = 0; l < 32; l++) {
+                let q_idx = q_b_idx + l;
+                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+                let qs_val = (q_byte >> shift) & 0xF;
+                sum += (f32(qs_val) * dl - ml) * src1[src1_i];
+                src1_i++;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q5_K
+// 8 blocks of 32 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    let m = f32(block.dmin);
+    var sum = 0.0;
+    var src1_i = src1_idx_base + offset * 256;
+    var is: u32 = 0;
+    var u: u32 = 1;
+    // 2 blocks each iteration
+    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+        for (var shift: u32 = 0; shift < 8; shift += 4) {
+            let scale_min = get_scale_min(is, block.scales);
+            is++;
+            let dl = d * scale_min.x;
+            let ml = m * scale_min.y;
+            for (var l: u32 = 0; l < 32; l++) {
+                let q_idx = q_b_idx + l;
+                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+                let qh_byte = get_byte(block.qh[l / 4], l % 4);
+                let qs_val = (q_byte >> shift) & 0xF;
+                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+                sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
+               src1_i++;
+            }
+            u <<= 1;
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef Q6_K
+// 16 blocks of 16 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+
+    // convert arrays of f16 -> u32
+    var ql_vals: array<u32, 32>;
+    for (var i: u32 = 0; i < 32; i++) {
+        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
+    }
+    var qh_vals: array<u32, 16>;
+    for (var i: u32 = 0; i < 16; i++) {
+        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
+    }
+    var scale_vals: array<u32, 4>;
+    for (var i: u32 = 0; i < 4; i++) {
+        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+    }
+
+    var sum = 0.0;
+    var src1_i = src1_idx_base + offset * 256;
+    var qh_b_idx: u32 = 0;
+    var sc_b_idx: u32 = 0;
+    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
+        for (var l: u32 = 0; l < 32; l++) {
+            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
+            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
+            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
+
+            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
+            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
+            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
+            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
+
+            let is = l/16;
+            let is1 = sc_b_idx + is;
+            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
+            let is2 = sc_b_idx + is + 2;
+            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
+            let is3 = sc_b_idx + is + 4;
+            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
+            let is4 = sc_b_idx + is + 6;
+            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
+
+            sum += d * f32(sc1) * q1 * src1[src1_i + l];
+            sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
+            sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
+            sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
+        }
+        src1_i += 128;
+        qh_b_idx += 32;
+        sc_b_idx += 8;
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ2_XXS
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 256;
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 32; ib += 4) {
+        let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
+        let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
+        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
+        for (var l: u32 = 0; l < 4; l++) {
+            let ig = get_byte(aux0, l) * 8;
+            let is = (aux1 >> (7 * l)) & 127;
+            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+            for (var j: u32 = 0; j < 8; j++) {
+                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
+                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+                sum += db * f32(g) * m * src1[src1_i];
+                src1_i++;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ2_XS
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 256;
+    var scale_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+    );
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 32; ib += 4) {
+        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
+        let db = array<f32, 2>(
+            d * (0.5 + f32(s & 0xF)) * 0.25,
+            d * (0.5 + f32(s >> 4)) * 0.25
+        );
+        for (var l: u32 = 0; l < 4; l++) {
+            let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
+            let ig = (qs_val & 511) * 8;
+            let is = qs_val >> 9;
+            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+            let dl = db[l/2];
+            for (var j: u32 = 0; j < 8; j++) {
+                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
+                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+                sum += dl * f32(g) * m * src1[src1_i];
+                src1_i++;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ2_S
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 256;
+    var qs_vals : array<u32, 16>;
+    for (var i: u32 = 0; i < 16; i++) {
+        qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+    }
+    var qh_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+    );
+    var scale_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+        bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+    );
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 8; ib ++) {
+        let s = get_byte(scale_vals[ib / 4], ib % 4);
+        let db = array<f32, 2>(
+            d * (0.5 + f32(s & 0xF)) * 0.25,
+            d * (0.5 + f32(s >> 4)) * 0.25
+        );
+        let qs_w = qs_vals[ib];
+        for (var l: u32 = 0; l < 4; l++) {
+            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
+            let ig = (get_byte(qs_w, l) | qh_b) * 8;
+            let signs = get_byte(qs_vals[ib + 8], l);
+            let dl = db[l/2];
+            for (var j: u32 = 0; j < 8; j++) {
+                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
+                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+                sum += dl * f32(g) * m * src1[src1_i];
+                src1_i++;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ3_XXS
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 256;
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 16; ib += 2) {
+        let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
+        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
+        for (var l: u32 = 0; l < 4; l++) {
+            let is = (sc_sign >> (7 * l)) & 127;
+            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+            let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
+            let ig1 = get_byte(ig_val, 0);
+            let ig2 = get_byte(ig_val, 1);
+            for (var j: u32 = 0; j < 4; j++) {
+                let g1 = get_byte(iq3xxs_grid[ig1], j);
+                let g2 = get_byte(iq3xxs_grid[ig2], j);
+                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+                sum += db * f32(g1) * m1 * src1[src1_i];
+                sum += db * f32(g2) * m2 * src1[src1_i + 4];
+                src1_i++;
+            }
+            src1_i += 4;
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ3_S
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 256;
+    var qh_vals = array<u32, 2>(
+        bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+        bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+    );
+    var sign_vals: array<u32, 8>;
+    for (var i: u32 = 0; i < 8; i++) {
+        sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
+    }
+    var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 4; ib++) {
+        let s = get_byte(scale_vals, ib);
+        let db = array<f32, 2>(
+            d * (1.0 + 2.0 * f32(s & 0xF)),
+            d * (1.0 + 2.0 * f32(s >> 4))
+        );
+        for (var k: u32 = 0; k < 2; k++) {
+            let dl = db[k];
+            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
+            let sign_w = sign_vals[ib * 2 + k];
+            for (var l: u32 = 0; l < 4; l++) {
+                let signs = get_byte(sign_w, l);
+                let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
+                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
+                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
+                for (var j: u32 = 0; j < 4; j++) {
+                    let g1 = get_byte(iq3s_grid[ig1], j);
+                    let g2 = get_byte(iq3s_grid[ig2], j);
+                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+                    sum += dl * f32(g1) * m1 * src1[src1_i];
+                    sum += dl * f32(g2) * m2 * src1[src1_i + 4];
+                    src1_i++;
+                }
+                src1_i += 4;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ1_S
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 256;
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+        let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
+        let dl = d * (2 * f32((qh >> 12) & 7) + 1);
+        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
+        let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
+        for (var l: u32 = 0; l < 4; l++) {
+            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
+            for (var j: u32 = 0; j < 8; j++) {
+                let gw = iq1_grid[(ig + j) / 16];
+                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+                let gs = bitcast<i32>(g << 30) >> 30;
+                sum += dl * (f32(gs) + delta) * src1[src1_i];
+                src1_i++;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+
+#ifdef IQ1_M
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+
+    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
+    let d = f32(bitcast<vec2<f16>>(scale).x);
+    var src1_i = src1_idx_base + offset * 256;
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
+        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
+        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
+        var dl = array<f32, 2>(
+            d * f32(2 * s1 + 1),
+            d * f32(2 * s2 + 1)
+        );
+
+        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
+        var idx = array<u32, 4>(
+            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
+            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
+            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
+            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
+        );
+        var delta = array<f32, 4>(
+            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
+            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
+            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
+            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
+        );
+        for (var l: u32 = 0; l < 4; l++) {
+            let ig = idx[l] * 8;
+            for (var j: u32 = 0; j < 8; j++) {
+                let gw = iq1_grid[(ig + j) / 16];
+                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+                let gs = bitcast<i32>(g << 30) >> 30;
+                sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
+                src1_i++;
+            }
+        }
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ4_NL
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    var src1_i = src1_idx_base + offset * 32;
+    var sum = 0.0;
+    var qs: array<u32, 4>;
+    for (var i: u32 = 0; i < 4; i++) {
+        qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+    }
+    for (var j: u32 = 0; j < 16; j++) {
+        let qsb = get_byte(qs[j / 4], j % 4);
+        sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
+        sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
+        src1_i++;
+    }
+    return sum;
+}
+#endif
+
+#ifdef IQ4_XS
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+    let block = src0[src0_idx_base + offset];
+    let d = f32(block.d);
+    let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
+    var src1_i = src1_idx_base + offset * 256;
+    var sum = 0.0;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
+        let dl = d * (f32(ls) - 32.0);
+        for (var j: u32 = 0; j < 16; j++) {
+            let iqs = ib * 16 + j;
+            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
+            sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
+            sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
+            src1_i++;
+        }
+        src1_i += 16;
+    }
+    return sum;
+}
+#endif
+
+struct MulMatParams {
+    offset_src0: u32, // in elements/blocks
+    offset_src1: u32, // in elements/blocks
+    offset_dst: u32, // in elements/blocks
+    m: u32,
+    n: u32,
+    k: u32,
+    // all strides are in elements/blocks
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+    let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    if (global_id.x >= total) {
+        return;
+    }
+
+    let dst2_stride = params.m * params.n;
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+    let dst3_idx = global_id.x / dst3_stride;
+    let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
+    let src13_idx = dst3_idx; // src1 is not broadcast
+    let dst3_rem = global_id.x % dst3_stride;
+
+    let dst2_idx = dst3_rem / dst2_stride;
+    let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
+    let src12_idx = dst2_idx; // src1 is not broadcast
+
+    let dst2_rem = dst3_rem % dst2_stride;
+
+    let row = dst2_rem / params.m; // output row
+    let col = dst2_rem % params.m; // output column
+
+    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
+    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
+
+    var sum = 0.0;
+    for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
+        sum += multiply_add(src0_idx_base, src1_idx_base, i);
+    }
+    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
+}
index 109ff8d6159e1ac040b9b6302fe72134bc39badc..5c1074ebc10afea4945240d96f645a4dfff61c25 100644 (file)
@@ -1,58 +1,65 @@
-#decl(SHMEM_VEC)
+#ifdef VEC
+#define VEC_SIZE 4
+#define SHMEM_TYPE vec4<f16>
+#define DST_TYPE vec4<f32>
+#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
+#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
+
 fn store_shmem(val: vec4<f16>, idx: u32) {
     shmem[idx] = val.x;
     shmem[idx + 1] = val.y;
     shmem[idx + 2] = val.z;
     shmem[idx + 3] = val.w;
 }
-#enddecl(SHMEM_VEC)
+#endif
+
+#ifdef SCALAR
+#define VEC_SIZE 1
+#define SHMEM_TYPE f16
+#define DST_TYPE f32
+#define SRC0_TYPE SRC0_INNER_TYPE
+#define SRC1_TYPE SRC1_INNER_TYPE
 
-#decl(SHMEM_SCALAR)
 fn store_shmem(val: f16, idx: u32) {
     shmem[idx] = val;
 }
-#enddecl(SHMEM_SCALAR)
-
-#decl(INIT_SRC0_SHMEM_FLOAT)
+#endif
 
+#ifdef INIT_SRC0_SHMEM_FLOAT
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
-    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
         let tile_m = elem_idx / TILE_K;
         let tile_k = elem_idx % TILE_K;
         let global_m = offset_m + tile_m;
         let global_k = k_outer + tile_k;
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let src0_val = select( // taking a slight performance hit to avoid oob
-            {{SRC0_TYPE}}(0.0),
-            src0[src0_idx/{{VEC_SIZE}}],
+            SRC0_TYPE(0.0),
+            src0[src0_idx/VEC_SIZE],
             global_m < params.m && global_k < params.k);
-        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
+        store_shmem(SHMEM_TYPE(src0_val), elem_idx);
     }
 }
+#endif
 
-#enddecl(INIT_SRC0_SHMEM_FLOAT)
-
-#decl(INIT_SRC1_SHMEM)
-
+#ifdef INIT_SRC1_SHMEM_FLOAT
 fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
-    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
         let tile_n = elem_idx / TILE_K;
         let tile_k = elem_idx % TILE_K;
         let global_n = offset_n + tile_n;
         let global_k = k_outer + tile_k;
         let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
         let src1_val = select(
-            {{SRC1_TYPE}}(0.0),
-            src1[src1_idx/{{VEC_SIZE}}],
+            SRC1_TYPE(0.0),
+            src1[src1_idx/VEC_SIZE],
             global_n < params.n && global_k < params.k);
-        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
+        store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
     }
 }
+#endif
 
-#enddecl(INIT_SRC1_SHMEM)
-
-#decl(INIT_SRC0_SHMEM_Q4_0)
-
+#ifdef INIT_SRC0_SHMEM_Q4_0
 const BLOCK_SIZE = 32u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
@@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         }
     }
 }
-
-#enddecl(INIT_SRC0_SHMEM_Q4_0)
+#endif
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
deleted file mode 100644 (file)
index 6b1dd26..0000000
+++ /dev/null
@@ -1,247 +0,0 @@
-#define(VARIANTS)
-[
-  {
-    "SHADER_SUFFIX": "f32_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f32>",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f16>",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f16>",
-      "SRC1_TYPE" : "vec4<f16>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(VEC)
-fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
-    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
-}
-#enddecl(VEC)
-
-#decl(SCALAR)
-fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
-    return f32(acc[tm][tn]);
-}
-#enddecl(SCALAR)
-
-#end(DECLS)
-
-#define(SHADER)
-enable f16;
-
-struct MulMatParams {
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    m: u32,
-    n: u32,
-    k: u32,
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
-@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
-
-@group(0) @binding(3) var<uniform> params: MulMatParams;
-
-DECLS
-
-fn get_local_n(thread_id: u32) -> u32 {
-    return thread_id / WORKGROUP_SIZE_M;
-}
-fn get_local_m(thread_id: u32) -> u32 {
-    return thread_id % WORKGROUP_SIZE_M;
-}
-
-// TILE_M must be multiple of 4 for vec4 loads
-const TILE_M = {{WEBGPU_TILE_M}}u;
-const TILE_N = {{WEBGPU_TILE_N}}u;
-
-override WORKGROUP_SIZE_M: u32;
-override WORKGROUP_SIZE_N: u32;
-override TILE_K: u32;
-
-override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
-override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
-override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
-
-var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
-
-@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
-fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
-        @builtin(local_invocation_id) local_id: vec3<u32>) {
-
-    let thread_id = local_id.x;
-    let local_m = get_local_m(thread_id);
-    let local_n = get_local_n(thread_id);
-
-    let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
-    let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
-    let wg_per_matrix = wg_m_count * wg_n_count;
-
-    let batch_idx = wg_id.x / wg_per_matrix;
-
-    let wg_in_batch = wg_id.x % wg_per_matrix;
-    let wg_m = wg_in_batch % wg_m_count;
-    let wg_n = wg_in_batch / wg_m_count;
-
-    let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
-    let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;
-
-    let dst2_stride = params.m * params.n;
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-
-    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
-    let src03_idx = dst3_idx / params.broadcast3;
-    let src13_idx = dst3_idx;
-    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
-    let src02_idx = dst2_idx / params.broadcast2;
-    let src12_idx = dst2_idx;
-
-    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
-    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-
-    let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
-    let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
-
-    var acc: array<array<f16, TILE_N>, TILE_M>;
-
-    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
-
-        // see mul_mat_decls.tmpl
-        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
-        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
-
-        workgroupBarrier();
-
-        let k_end = min(TILE_K, params.k - k_outer);
-
-        for (var k_inner = 0u; k_inner < k_end; k_inner++) {
-            var src0_tile: array<f16, TILE_M>;
-            for (var tm = 0u; tm < TILE_M; tm++) {
-                let src0_m = local_m * TILE_M + tm;
-                let src0_idx = k_inner + src0_m * TILE_K;
-                src0_tile[tm] = shmem[src0_idx];
-            }
-            for (var tn = 0u; tn < TILE_N; tn++) {
-                let src1_n = local_n * TILE_N + tn;
-                let src1_idx = src1_n * TILE_K + k_inner;
-                let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
-                for (var tm = 0u; tm < TILE_M; tm++) {
-                      acc[tm][tn] += src0_tile[tm] * src1_val;
-                }
-            }
-        }
-
-        workgroupBarrier();
-    }
-
-    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
-
-    for (var tn = 0u; tn < TILE_N; tn++) {
-        let global_col = output_col_base + tn;
-        if (global_col < params.n) {
-            for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
-                let global_row = output_row_base + tm;
-                if (global_row < params.m) {
-                    let dst_idx = dst_batch_offset + global_col * params.m + global_row;
-                    dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
-                }
-            }
-        }
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
new file mode 100644 (file)
index 0000000..771e5cd
--- /dev/null
@@ -0,0 +1,138 @@
+enable f16;
+
+#include "common_decls.tmpl"
+#include "mul_mat_decls.tmpl"
+
+#ifdef VEC
+fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
+    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
+}
+#endif
+
+#ifdef SCALAR
+fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
+    return f32(acc[tm][tn]);
+}
+#endif
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+fn get_local_n(thread_id: u32) -> u32 {
+    return thread_id / WORKGROUP_SIZE_M;
+}
+fn get_local_m(thread_id: u32) -> u32 {
+    return thread_id % WORKGROUP_SIZE_M;
+}
+
+const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
+const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
+const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
+
+@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>) {
+
+    let thread_id = local_id.x;
+    let local_m = get_local_m(thread_id);
+    let local_n = get_local_n(thread_id);
+
+    let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
+    let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
+    let wg_per_matrix = wg_m_count * wg_n_count;
+
+    let batch_idx = wg_id.x / wg_per_matrix;
+
+    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let wg_m = wg_in_batch % wg_m_count;
+    let wg_n = wg_in_batch / wg_m_count;
+
+    let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
+    let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;
+
+    let dst2_stride = params.m * params.n;
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+
+    let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
+    let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
+
+    var acc: array<array<f16, TILE_N>, TILE_M>;
+
+    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
+
+        // see mul_mat_decls.tmpl
+        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
+        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
+
+        workgroupBarrier();
+
+        let k_end = min(TILE_K, params.k - k_outer);
+
+        for (var k_inner = 0u; k_inner < k_end; k_inner++) {
+            var src0_tile: array<f16, TILE_M>;
+            for (var tm = 0u; tm < TILE_M; tm++) {
+                let src0_m = local_m * TILE_M + tm;
+                let src0_idx = k_inner + src0_m * TILE_K;
+                src0_tile[tm] = shmem[src0_idx];
+            }
+            for (var tn = 0u; tn < TILE_N; tn++) {
+                let src1_n = local_n * TILE_N + tn;
+                let src1_idx = src1_n * TILE_K + k_inner;
+                let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
+                for (var tm = 0u; tm < TILE_M; tm++) {
+                      acc[tm][tn] += src0_tile[tm] * src1_val;
+                }
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
+
+    for (var tn = 0u; tn < TILE_N; tn++) {
+        let global_col = output_col_base + tn;
+        if (global_col < params.n) {
+            for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
+                let global_row = output_row_base + tm;
+                if (global_row < params.m) {
+                    let dst_idx = dst_batch_offset + global_col * params.m + global_row;
+                    dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
+                }
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
deleted file mode 100644 (file)
index 47c8ce3..0000000
+++ /dev/null
@@ -1,302 +0,0 @@
-#define(VARIANTS)
-[
-  {
-    "SHADER_SUFFIX": "f32_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f32>",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f16>",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f16>",
-      "SRC1_TYPE" : "vec4<f16>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE" : "vec4<f32>",
-      "SHMEM_TYPE" : "vec4<f16>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(VEC)
-fn store_dst(shmem_idx: u32, dst_idx: u32) {
-    dst[dst_idx] = vec4<f32>(
-        f32(shmem[shmem_idx]),
-        f32(shmem[shmem_idx + 1]),
-        f32(shmem[shmem_idx + 2]),
-        f32(shmem[shmem_idx + 3])
-    );
-}
-#enddecl(VEC)
-
-#decl(SCALAR)
-fn store_dst(shmem_idx: u32, dst_idx: u32) {
-    dst[dst_idx] = f32(shmem[shmem_idx]);
-}
-#enddecl(SCALAR)
-
-#end(DECLS)
-
-#define(SHADER)
-diagnostic(off, chromium.subgroup_matrix_uniformity);
-enable f16;
-enable subgroups;
-enable chromium_experimental_subgroup_matrix;
-
-struct MulMatParams {
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    m: u32,
-    n: u32,
-    k: u32,
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
-@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
-
-@group(0) @binding(3) var<uniform> params: MulMatParams;
-
-DECLS
-
-// Note: These are string interpolated at build time, cannot use override constants due to limitations in
-// current Dawn version type definitions/matrix load requirements for constant memory sizes.
-const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
-const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
-// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
-// runtime subgroup size is smaller.
-const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
-
-const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
-
-const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
-const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
-const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
-
-const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
-const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
-
-const TILE_K = {{WEBGPU_TILE_K}}u;
-
-const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
-const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
-
-const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
-const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
-const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
-
-const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
-
-// We reuse shmem for accumulation matrices
-const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
-
-var<workgroup> shmem: array<f16, SHMEM_SIZE>;
-
-@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
-fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
-        @builtin(local_invocation_id) local_id: vec3<u32>,
-        @builtin(subgroup_id) subgroup_id: u32) {
-
-    let thread_id = local_id.x;
-    let subgroup_m = subgroup_id % SUBGROUP_M;
-    let subgroup_n = subgroup_id / SUBGROUP_M;
-
-    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
-    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
-    let wg_per_matrix = wg_m_count * wg_n_count;
-
-    let batch_idx = wg_id.x / wg_per_matrix;
-
-    let wg_in_batch = wg_id.x % wg_per_matrix;
-    let wg_m = wg_in_batch % wg_m_count;
-    let wg_n = wg_in_batch / wg_m_count;
-
-    let dst2_stride = params.m * params.n;
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-
-    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
-    let src03_idx = dst3_idx / params.broadcast3;
-    let src13_idx = dst3_idx;
-    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
-    let src02_idx = dst2_idx / params.broadcast2;
-    let src12_idx = dst2_idx;
-
-    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
-    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-
-    let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
-    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
-
-    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;
-
-    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
-
-        // see mul_mat_decls.tmpl
-        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
-        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
-
-        workgroupBarrier();
-
-        if (subgroup_id < EXPECTED_SUBGROUPS) {
-
-            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
-
-                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
-                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
-                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
-                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
-                        &shmem,
-                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
-                        false,
-                        TILE_K
-                    );
-                }
-
-                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
-                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
-                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
-                        &shmem,
-                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
-                        true,
-                        TILE_K
-                    );
-                    for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
-                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
-                    }
-                }
-            }
-        }
-
-        workgroupBarrier();
-    }
-
-    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
-
-    // Stage the subgroup matrix tiles into shared memory
-    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
-    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
-    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
-    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
-
-    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
-        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
-            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
-                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
-                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
-                let out_base = local_row * WG_TILE_STRIDE + local_col;
-                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
-            }
-        }
-    }
-
-    workgroupBarrier();
-
-    // Cooperative write: iterate over the entire workgroup tile
-    let tile_rows = WG_N_SG_TILE_SIZE;
-    let tile_cols = WG_M_SG_TILE_SIZE;
-    let total_tile_elems = tile_rows * tile_cols;
-    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
-    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
-
-    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
-        let local_row = idx % WG_TILE_STRIDE;
-        let local_col = idx / WG_TILE_STRIDE;
-
-        let global_row = tile_dst_row_base + local_row;
-        let global_col = tile_dst_col_base + local_col;
-
-        if (global_col < params.n && global_row < params.m) {
-            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
-            store_dst(idx, dst_idx/{{VEC_SIZE}});
-        }
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
new file mode 100644 (file)
index 0000000..64529e0
--- /dev/null
@@ -0,0 +1,188 @@
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+#include "common_decls.tmpl"
+#include "mul_mat_decls.tmpl"
+
+#ifdef VEC
+fn store_dst(shmem_idx: u32, dst_idx: u32) {
+    dst[dst_idx] = vec4<f32>(
+        f32(shmem[shmem_idx]),
+        f32(shmem[shmem_idx + 1]),
+        f32(shmem[shmem_idx + 2]),
+        f32(shmem[shmem_idx + 3])
+    );
+}
+#endif
+
+#ifdef SCALAR
+fn store_dst(shmem_idx: u32, dst_idx: u32) {
+    dst[dst_idx] = f32(shmem[shmem_idx]);
+}
+#endif
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
+// runtime subgroup size is smaller.
+const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
+const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
+const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
+
+// We reuse shmem for accumulation matrices
+const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
+
+var<workgroup> shmem: array<f16, SHMEM_SIZE>;
+
+@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32) {
+
+    let thread_id = local_id.x;
+    let subgroup_m = subgroup_id % SUBGROUP_M;
+    let subgroup_n = subgroup_id / SUBGROUP_M;
+
+    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
+    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
+    let wg_per_matrix = wg_m_count * wg_n_count;
+
+    let batch_idx = wg_id.x / wg_per_matrix;
+
+    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let wg_m = wg_in_batch % wg_m_count;
+    let wg_n = wg_in_batch / wg_m_count;
+
+    let dst2_stride = params.m * params.n;
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+
+    let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;
+
+    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
+
+        // see mul_mat_decls.tmpl
+        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
+        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
+
+        workgroupBarrier();
+
+        if (subgroup_id < EXPECTED_SUBGROUPS) {
+
+            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
+
+                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
+                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
+                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
+                        &shmem,
+                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
+                        false,
+                        TILE_K
+                    );
+                }
+
+                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
+                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
+                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
+                        &shmem,
+                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
+                        true,
+                        TILE_K
+                    );
+                    for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
+                    }
+                }
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
+
+    // Stage the subgroup matrix tiles into shared memory
+    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
+    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
+    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+
+    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
+        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
+            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
+                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
+                let out_base = local_row * WG_TILE_STRIDE + local_col;
+                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
+            }
+        }
+    }
+
+    workgroupBarrier();
+
+    // Cooperative write: iterate over the entire workgroup tile
+    let tile_rows = WG_N_SG_TILE_SIZE;
+    let tile_cols = WG_M_SG_TILE_SIZE;
+    let total_tile_elems = tile_rows * tile_cols;
+    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+    for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
+        let local_row = idx % WG_TILE_STRIDE;
+        let local_col = idx / WG_TILE_STRIDE;
+
+        let global_row = tile_dst_row_base + local_row;
+        let global_col = tile_dst_col_base + local_col;
+
+        if (global_col < params.n && global_row < params.m) {
+            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
+            store_dst(idx, dst_idx/VEC_SIZE);
+        }
+    }
+}
+
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
deleted file mode 100644 (file)
index ffbb640..0000000
+++ /dev/null
@@ -1,267 +0,0 @@
-#define(VARIANTS)
-[
-  {
-    "SHADER_SUFFIX": "f32_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f32>",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE": "vec4<f32>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f16>",
-      "SRC1_TYPE" : "vec4<f32>",
-      "DST_TYPE": "vec4<f32>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4<f16>",
-      "SRC1_TYPE" : "vec4<f16>",
-      "DST_TYPE": "vec4<f32>",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(VEC)
-fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
-    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
-}
-
-fn store_val(group_base: u32) -> vec4<f32> {
-    return vec4<f32>(partial_sums[group_base],
-                     partial_sums[group_base + THREADS_PER_OUTPUT],
-                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
-                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
-}
-#enddecl(VEC)
-
-#decl(SCALAR)
-fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
-    return f32(src0_val) * f32(src1_val);
-}
-
-fn store_val(group_base: u32) -> f32 {
-    return partial_sums[group_base];
-}
-#enddecl(SCALAR)
-
-#decl(MUL_ACC_FLOAT)
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
-        let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
-        let b = shared_vector[i / {{VEC_SIZE}}];
-        local_sum += inner_dot(a, b);
-    }
-    return local_sum;
-}
-
-#enddecl(MUL_ACC_FLOAT)
-
-#decl(MUL_ACC_Q4_0)
-
-const BLOCK_SIZE = 32;
-const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
-const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
-        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte(q_packed, k);
-                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
-                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
-                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
-                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
-            }
-        }
-    }
-    return local_sum;
-}
-
-#enddecl(MUL_ACC_Q4_0)
-
-#end(DECLS)
-
-#define(SHADER)
-enable f16;
-
-DECLS
-
-struct MulMatParams {
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    m: u32,
-    n: u32,
-    k: u32,
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
-@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>;  // Result vector (transposed)
-
-@group(0) @binding(3) var<uniform> params: MulMatParams;
-
-override WORKGROUP_SIZE: u32;
-override TILE_K: u32;
-override OUTPUTS_PER_WG: u32;
-override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
-
-// Shared memory for collaborative loading and reduction
-var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>;  // Cache vector tile
-var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>;   // For reduction
-
-@compute @workgroup_size(WORKGROUP_SIZE)
-fn main(
-    @builtin(local_invocation_id) local_id: vec3<u32>,
-    @builtin(workgroup_id) wg_id: vec3<u32>,
-    @builtin(num_workgroups) num_wg: vec3<u32>) {
-    let thread_id = local_id.x;
-
-    // Handle batch dimensions
-    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
-    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
-    let batch_idx = wg_linear / output_groups;
-    if (batch_idx >= total_batches) {
-        return;
-    }
-
-    // Which of the outputs does this thread belong to?
-    let thread_group = thread_id / THREADS_PER_OUTPUT;
-    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
-
-    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
-    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
-
-    let dst2_stride = params.m * params.n;
-    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
-    let src03_idx = dst3_idx / params.broadcast3;
-    let src13_idx = dst3_idx;
-    let src02_idx = dst2_idx / params.broadcast2;
-    let src12_idx = dst2_idx;
-
-    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
-    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
-
-    var local_sum = 0.0;
-
-    // Each thread processes multiple K elements and accumulates
-    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
-        let tile_size = min(TILE_K, params.k - k_tile);
-
-        // Cooperatively load vector tile into shared memory (all threads)
-        for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
-            shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
-        }
-
-        workgroupBarrier();
-
-        if (output_row < params.m) {
-            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
-        }
-
-        workgroupBarrier();
-    }
-
-    // Store partial sums and reduce within each partition
-    partial_sums[thread_id] = local_sum;
-    workgroupBarrier();
-    let group_base = thread_group * THREADS_PER_OUTPUT;
-    let thread_base = group_base + thread_in_group;
-    var offset = THREADS_PER_OUTPUT / 2;
-    while (offset > 0) {
-        if (thread_in_group < offset) {
-            partial_sums[thread_base] += partial_sums[thread_base + offset];
-        }
-        offset = offset / 2;
-        workgroupBarrier();
-    }
-
-    // Store back to global memory
-    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
-        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
-    }
-}
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
new file mode 100644 (file)
index 0000000..f9ea95e
--- /dev/null
@@ -0,0 +1,194 @@
+
+enable f16;
+
+#include "common_decls.tmpl"
+
+#ifdef VEC
+
+#define VEC_SIZE 4
+#define DST_TYPE vec4<f32>
+#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
+#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
+
+fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
+    return f32(dot(SRC1_TYPE(src0_val), src1_val));
+}
+
+fn store_val(group_base: u32) -> vec4<f32> {
+    return vec4<f32>(partial_sums[group_base],
+                     partial_sums[group_base + THREADS_PER_OUTPUT],
+                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
+                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
+}
+#endif
+
+#ifdef SCALAR
+
+#define VEC_SIZE 1
+#define DST_TYPE f32
+#define SRC0_TYPE SRC0_INNER_TYPE
+#define SRC1_TYPE SRC1_INNER_TYPE
+
+fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
+    return f32(src0_val) * f32(src1_val);
+}
+
+fn store_val(group_base: u32) -> f32 {
+    return partial_sums[group_base];
+}
+#endif
+
+#ifdef MUL_ACC_FLOAT
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
+        let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
+        let b = shared_vector[i / VEC_SIZE];
+        local_sum += inner_dot(a, b);
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q4_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
+                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
+                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
+
+// Shared memory for collaborative loading and reduction
+var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>;  // Cache vector tile
+var<workgroup> partial_sums: array<f32, WG_SIZE>;   // For reduction
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(num_workgroups) num_wg: vec3<u32>) {
+    let thread_id = local_id.x;
+
+    // Handle batch dimensions
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
+    let batch_idx = wg_linear / output_groups;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    // Which of the outputs does this thread belong to?
+    let thread_group = thread_id / THREADS_PER_OUTPUT;
+    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
+
+    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
+    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
+
+    let dst2_stride = params.m * params.n;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
+    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
+
+    var local_sum = 0.0;
+
+    // Each thread processes multiple K elements and accumulates
+    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
+        let tile_size = min(TILE_K, params.k - k_tile);
+
+        // Cooperatively load vector tile into shared memory (all threads)
+        for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
+            shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
+        }
+
+        workgroupBarrier();
+
+        if (output_row < params.m) {
+            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
+        }
+
+        workgroupBarrier();
+    }
+
+    // Store partial sums and reduce within each partition
+    partial_sums[thread_id] = local_sum;
+    workgroupBarrier();
+    let group_base = thread_group * THREADS_PER_OUTPUT;
+    let thread_base = group_base + thread_in_group;
+    var offset: u32 = THREADS_PER_OUTPUT / 2;
+    while (offset > 0) {
+        if (thread_in_group < offset) {
+            partial_sums[thread_base] += partial_sums[thread_base + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+
+    // Store back to global memory
+    if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
+        dst[dst_idx / VEC_SIZE] = store_val(group_base);
+    }
+}
+
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
deleted file mode 100644 (file)
index 040e80d..0000000
+++ /dev/null
@@ -1,90 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "SHADER_NAME": "scale_f32",
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "scale_f32_inplace",
-    "DECLS": ["INPLACE"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(NOT_INPLACE)
-@group(0) @binding(1)
-var<storage, read_write> dst: array<f32>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-fn store_scale(val: f32, offset: u32) {
-    dst[offset] = val;
-}
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-@group(0) @binding(1)
-var<uniform> params: Params;
-
-fn store_scale(val: f32, offset: u32) {
-    src[offset] = val;
-}
-#enddecl(INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
-struct Params {
-    offset_src: u32,
-    offset_dst: u32,
-
-    // Strides (in elements)
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    ne: u32,
-    ne0: u32,
-    ne1: u32,
-    ne2: u32,
-
-    scale: f32,
-    bias: f32
-};
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
-
-DECLS
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne) {
-        return;
-    }
-
-    var i = gid.x;
-    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
-    i = i % (params.ne2 * params.ne1 * params.ne0);
-    let i2 = i / (params.ne1 * params.ne0);
-    i = i % (params.ne1 * params.ne0);
-    let i1 = i / params.ne0;
-    let i0 = i % params.ne0;
-
-    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
-    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
-
-    store_scale(src[i_src] * params.scale + params.bias, i_dst);
-}
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl
new file mode 100644 (file)
index 0000000..3b70a87
--- /dev/null
@@ -0,0 +1,63 @@
+#ifdef INPLACE
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    src[offset] = val;
+}
+#else
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    dst[offset] = val;
+}
+#endif
+
+struct Params {
+    offset_src: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    scale: f32,
+    bias: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    store_scale(src[i_src] * params.scale + params.bias, i_dst);
+}