]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml webgpu: faster matrix multiplication/matrix-vector multiplication (#17031)
authorReese Levine <redacted>
Sat, 8 Nov 2025 03:27:20 +0000 (19:27 -0800)
committerGitHub <redacted>
Sat, 8 Nov 2025 03:27:20 +0000 (19:27 -0800)
* Faster tensors (#8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

.github/workflows/build.yml
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl [new file with mode: 0644]

index 15e1133095213f7637181815872d164394af762a..36084c55078ef82d33790d838e5d2a7afaa18888 100644 (file)
@@ -161,15 +161,16 @@ jobs:
       - name: Dawn Dependency
         id: dawn-depends
         run: |
-          DAWN_VERSION="v1.0.0"
+          DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
           echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
-          curl -L -o artifact.tar.gz \
+          curl -L -o artifact.zip \
             "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar -xvf artifact.tar.gz -C dawn --strip-components=1
+          unzip artifact.zip
+          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
 
       - name: Build
         id: cmake_build
@@ -521,15 +522,16 @@ jobs:
         id: dawn-depends
         run: |
           sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
-          DAWN_VERSION="v1.0.0"
+          DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
           echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
-          curl -L -o artifact.tar.gz \
+          curl -L -o artifact.zip \
             "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar -xvf artifact.tar.gz -C dawn --strip-components=1
+          unzip artifact.zip
+          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
 
       - name: Build
         id: cmake_build
index 1a15756731580f2868543c337814a45a8d0f8ea6..9e8cbc477ed18ba67d5831b611ecef7b360cd28e 100644 (file)
@@ -15,6 +15,7 @@
 #include <condition_variable>
 #include <cstring>
 #include <iostream>
+#include <map>
 #include <mutex>
 #include <optional>
 #include <string>
 // For operations which process a row in parallel, this seems like a reasonable default
 #define WEBGPU_ROW_SPLIT_WG_SIZE 64
 
+// Matrix multiplication parameters
+
+// Register tiling parameters
+#define WEBGPU_MUL_MAT_TILE_M    8
+#define WEBGPU_MUL_MAT_TILE_N    8
+#define WEBGPU_MUL_MAT_WG_SIZE_M 8
+#define WEBGPU_MUL_MAT_WG_SIZE_N 8
+#define WEBGPU_MUL_MAT_TILE_K    32
+
+// Subgroup matrix parameters
+// The number of subgroups in the M dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_M        2
+// The number of subgroups in the N dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_N        2
+// The number of subgroup matrices each subgroup accumulates over
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
+
+// Matrix-vector multiplication parameters
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
+// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
+#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_TILE_K         256
+
 /* End Constants */
 
 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -236,6 +261,10 @@ struct webgpu_context_struct {
     wgpu::Queue    queue;
     wgpu::Limits   limits;
 
+    bool                       supports_subgroup_matrix = false;
+    uint32_t                   subgroup_size;
+    wgpu::SubgroupMatrixConfig subgroup_matrix_config;
+
     // Separate this out from limits since on some Metal systems, the limit returned by
     // querying the limits is higher than the actual allowed maximum.
     uint32_t max_wg_size_x;
@@ -247,6 +276,11 @@ struct webgpu_context_struct {
     webgpu_buf_pool set_rows_error_buf_pool;
 
     webgpu_pipeline memset_pipeline;
+
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;  // src0_type, src1_type, vectorized
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
+        mul_mat_vec_pipelines;                                                       // src0_type, src1_type, vectorized
+
     webgpu_pipeline mul_mat_pipeline[30][2];
     webgpu_pipeline set_rows_pipeline[1][2];  // dst->type, vectorized
     webgpu_pipeline get_rows_pipeline[30];
@@ -321,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context {
 
 /* WebGPU object initializations */
 
+// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
+// the corresponding values provided in `repls`.
+static std::string ggml_webgpu_process_shader_repls(const char *                               src,
+                                                    const std::map<std::string, std::string> & repls) {
+    if (!src) {
+        return std::string();
+    }
+    std::string s = src;
+    for (const auto & kv : repls) {
+        std::string token = "{{" + kv.first + "}}";
+        size_t      pos   = 0;
+        while ((pos = s.find(token, pos)) != std::string::npos) {
+            s.replace(pos, token.length(), kv.second);
+            pos += kv.second.length();
+        }
+    }
+    return s;
+}
+
 static void ggml_webgpu_create_pipeline(wgpu::Device &                           device,
                                         webgpu_pipeline &                        pipeline,
                                         const char *                             shader_code,
@@ -346,6 +399,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
     pipeline = { device.CreateComputePipeline(&pipeline_desc), label };
 }
 
+static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device &                           device,
+                                                    const char *                             shader_code,
+                                                    const char *                             label,
+                                                    const std::vector<wgpu::ConstantEntry> & constants = {}) {
+    wgpu::ShaderSourceWGSL shader_source;
+    shader_source.code = shader_code;
+
+    wgpu::ShaderModuleDescriptor shader_desc;
+    shader_desc.nextInChain = &shader_source;
+
+    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+    wgpu::ComputePipelineDescriptor pipeline_desc;
+    pipeline_desc.label              = label;
+    pipeline_desc.compute.module     = shader_module;
+    pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
+    pipeline_desc.layout             = nullptr;  // nullptr means auto layout
+    if (constants.size() > 0) {
+        pipeline_desc.compute.constants     = constants.data();
+        pipeline_desc.compute.constantCount = constants.size();
+    }
+    return { device.CreateComputePipeline(&pipeline_desc), label };
+}
+
 static void ggml_webgpu_create_buffer(wgpu::Device &    device,
                                       wgpu::Buffer &    buffer,
                                       size_t            size,
@@ -512,6 +589,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
                                                 std::vector<uint32_t>             params,
                                                 std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                 uint32_t                          wg_x,
+                                                uint32_t                          wg_y                = 1,
                                                 std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
     webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
 
@@ -557,7 +635,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
 #endif
     pass.SetPipeline(pipeline.pipeline);
     pass.SetBindGroup(0, bind_group);
-    pass.DispatchWorkgroups(wg_x, 1, 1);
+    pass.DispatchWorkgroups(wg_x, wg_y, 1);
     pass.End();
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -779,7 +857,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
 
     uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
 
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
 }
 
 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -835,8 +913,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        (uint32_t) dst->ne[1],                                  // number of rows in result (M)
-        (uint32_t) dst->ne[0],                                  // number of columns in result (N)
+        (uint32_t) dst->ne[0],                                  // number of rows in result (M, transposed)
+        (uint32_t) dst->ne[1],                                  // number of columns in result (N)
         (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
         (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
         (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
@@ -865,9 +943,67 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  },
     };
 
+    webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type];
+
     uint32_t wg_x =
         (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
-    return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
+    uint32_t wg_y = 1;
+
+    bool use_fast = false;
+    switch (src1->type) {
+        case GGML_TYPE_F16:
+            use_fast = (src0->type == GGML_TYPE_F16);
+            break;
+        case GGML_TYPE_F32:
+            switch (src0->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                    use_fast = true;
+                    break;
+                default:
+                    break;
+            }
+            break;
+        default:
+            break;
+    }
+
+    if (use_fast) {
+        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
+        if (dst->ne[1] == 1) {
+            // We don't support vectorized mul_mat_vec for quantized types
+            vectorized       = vectorized && (src0->type < 2);
+            pipeline         = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
+            uint32_t batches = dst->ne[2] * dst->ne[3];
+            uint32_t output_groups =
+                (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+            uint32_t total_wg = output_groups * batches;
+            wg_x              = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
+            wg_y              = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) /
+                   ctx->limits.maxComputeWorkgroupsPerDimension;
+        } else {
+            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
+            uint32_t wg_m;
+            uint32_t wg_n;
+            if (ctx->supports_subgroup_matrix) {
+                // The total number of subgroups/workgroups needed per matrix.
+                uint32_t wg_m_sg_tile =
+                    WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
+                wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile;
+                uint32_t wg_n_sg_tile =
+                    WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
+                wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile;
+            } else {
+                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
+                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
+                wg_m              = (dst->ne[0] + tile_m_s - 1) / tile_m_s;
+                wg_n              = (dst->ne[1] + tile_n_s - 1) / tile_n_s;
+            }
+            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+        }
+    }
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }
 
 static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
@@ -1583,12 +1719,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
-                                wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
-                                wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
-                                wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
                                 wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
@@ -1627,6 +1757,136 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
                                 wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
                                 wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
+
+    if (webgpu_ctx->supports_subgroup_matrix) {
+        std::map<std::string, std::string> sg_matrix_repls;
+        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
+        sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
+        sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
+        sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
+        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
+        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
+        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
+        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
+        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
+
+        std::string proc_mul_mat_subgroup_matrix_f32_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f16 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
+
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(),
+                                         "mul_mat_subgroup_matrix_f32_f32_vec");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(),
+                                         "mul_mat_subgroup_matrix_f16_f32_vec");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(),
+                                         "mul_mat_subgroup_matrix_f16_f16_vec");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(),
+                                         "mul_mat_subgroup_matrix_q4_0_f32_vec");
+    } else {
+        std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
+        mul_mat_reg_tile_constants[0].key   = "TILE_K";
+        mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K;
+        mul_mat_reg_tile_constants[1].key   = "WORKGROUP_SIZE_M";
+        mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M;
+        mul_mat_reg_tile_constants[2].key   = "WORKGROUP_SIZE_N";
+        mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
+
+        std::map<std::string, std::string> reg_repls;
+        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
+        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
+
+        // Process each reg-tile shader with tile replacements.
+        // Keep the processed strings in-scope so .c_str() remains valid.
+        std::string proc_mul_mat_reg_tile_f32_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
+        std::string proc_mul_mat_reg_tile_f32_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f16 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f16_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
+        std::string proc_mul_mat_reg_tile_q4_0_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
+        std::string proc_mul_mat_reg_tile_q4_0_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
+
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(),
+                                         "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(),
+                                         "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(),
+                                         "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(),
+                                         "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(),
+                                         "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(),
+                                         "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(),
+                                         "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
+            ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(),
+                                         "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants);
+    }
+
+    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
+    mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
+    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
+    mul_mat_vec_constants[1].key   = "TILE_K";
+    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
+    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
+    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
 }
 
 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
@@ -2124,7 +2384,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
 
     webgpu_context ctx = reg_ctx->webgpu_ctx;
 
-    wgpu::RequestAdapterOptions options = {};
+    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
+    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
+    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
+    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
+    adapterTogglesDesc.enabledToggleCount = 2;
+    wgpu::RequestAdapterOptions options   = {};
+    options.nextInChain                   = &adapterTogglesDesc;
     ctx->instance.WaitAny(ctx->instance.RequestAdapter(
                               &options, wgpu::CallbackMode::AllowSpontaneous,
                               [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
@@ -2140,12 +2406,46 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ctx->adapter.GetLimits(&ctx->limits);
     ctx->max_wg_size_x = 288;  // default value
 
-    wgpu::AdapterInfo info{};
+    wgpu::AdapterInfo                            info{};
+    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
+    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        info.nextInChain = &subgroup_matrix_configs;
+    }
     ctx->adapter.GetInfo(&info);
 
+    wgpu::SupportedFeatures features;
+    ctx->adapter.GetFeatures(&features);
+    // we require f16 support
+    GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
+
+    // Only support square f16 matrices of size 8 or 16 for now
+    bool valid_subgroup_matrix_config = false;
+    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
+            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
+            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
+                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+                ctx->subgroup_matrix_config  = config;
+                valid_subgroup_matrix_config = true;
+                break;
+            }
+        }
+    }
+
+    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
+    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
+    ctx->subgroup_size            = info.subgroupMaxSize;
+    ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
+
     // Initialize device
     std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
                                                          wgpu::FeatureName::ImplicitDeviceSynchronization };
+    if (ctx->supports_subgroup_matrix) {
+        required_features.push_back(wgpu::FeatureName::Subgroups);
+        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+    }
+
 #ifdef GGML_WEBGPU_GPU_PROFILE
     required_features.push_back(wgpu::FeatureName::TimestampQuery);
 #endif
index 251051eaeca0f8c74ac7db560ed392b8431ba978..ed8068d416ebfab6f81fc51c4646a8920d4b8e08 100755 (executable)
@@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile):
         except ValueError:
             decls_map = {}
 
-        with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
-            common_decls = f.read()
-        decls_map.update(parse_decls(common_decls))
+        for fname in sorted(os.listdir(input_dir)):
+            if fname.endswith(".tmpl"):
+                tmpl_path = os.path.join(input_dir, fname)
+                with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
+                    decls = f_tmpl.read()
+                    decls_map.update(parse_decls(decls))
 
         shader_template = extract_block(text, "SHADER")
         for variant in variants:
index 141db9b39d9579f5e32270afc555a240d60be02e..0f8e6e5ac3dd6fb3dc457eaaa27369683d76a115 100644 (file)
@@ -864,8 +864,8 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // N rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed)
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
 @group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
 
 @group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 
     let dst2_rem = dst3_rem % dst2_stride;
 
-    let row = dst2_rem / params.n; // output row
-    let col = dst2_rem % params.n; // output column
+    let row = dst2_rem / params.m; // output row
+    let col = dst2_rem % params.m; // output column
 
     let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
     let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
@@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
     for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
         sum += multiply_add(src0_idx_base, src1_idx_base, i);
     }
-    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
+    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
 }
 
 #end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
new file mode 100644 (file)
index 0000000..109ff8d
--- /dev/null
@@ -0,0 +1,97 @@
+#decl(SHMEM_VEC)
+fn store_shmem(val: vec4<f16>, idx: u32) {
+    shmem[idx] = val.x;
+    shmem[idx + 1] = val.y;
+    shmem[idx + 2] = val.z;
+    shmem[idx + 3] = val.w;
+}
+#enddecl(SHMEM_VEC)
+
+#decl(SHMEM_SCALAR)
+fn store_shmem(val: f16, idx: u32) {
+    shmem[idx] = val;
+}
+#enddecl(SHMEM_SCALAR)
+
+#decl(INIT_SRC0_SHMEM_FLOAT)
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+        let src0_val = select( // taking a slight performance hit to avoid oob
+            {{SRC0_TYPE}}(0.0),
+            src0[src0_idx/{{VEC_SIZE}}],
+            global_m < params.m && global_k < params.k);
+        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
+    }
+}
+
+#enddecl(INIT_SRC0_SHMEM_FLOAT)
+
+#decl(INIT_SRC1_SHMEM)
+
+fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
+    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+        let tile_n = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_n = offset_n + tile_n;
+        let global_k = k_outer + tile_k;
+        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
+        let src1_val = select(
+            {{SRC1_TYPE}}(0.0),
+            src1[src1_idx/{{VEC_SIZE}}],
+            global_n < params.n && global_k < params.k);
+        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
+    }
+}
+
+#enddecl(INIT_SRC1_SHMEM)
+
+#decl(INIT_SRC0_SHMEM_Q4_0)
+
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+
+            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                let q_0 = src0[scale_idx + 1u + block_offset + j];
+                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                    shmem[shmem_idx + j * 2 + k] = q_lo;
+                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
+                }
+            }
+        }
+    }
+}
+
+#enddecl(INIT_SRC0_SHMEM_Q4_0)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
new file mode 100644 (file)
index 0000000..6b1dd26
--- /dev/null
@@ -0,0 +1,247 @@
+#define(VARIANTS)
+[
+  {
+    "SHADER_SUFFIX": "f32_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f32>",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f32_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f32",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f16>",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f16_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f16>",
+      "SRC1_TYPE" : "vec4<f16>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f16",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f16",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "q4_0_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "q4_0_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(VEC)
+fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
+    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
+}
+#enddecl(VEC)
+
+#decl(SCALAR)
+fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
+    return f32(acc[tm][tn]);
+}
+#enddecl(SCALAR)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+DECLS
+
+fn get_local_n(thread_id: u32) -> u32 {
+    return thread_id / WORKGROUP_SIZE_M;
+}
+fn get_local_m(thread_id: u32) -> u32 {
+    return thread_id % WORKGROUP_SIZE_M;
+}
+
+// TILE_M must be multiple of 4 for vec4 loads
+const TILE_M = {{WEBGPU_TILE_M}}u;
+const TILE_N = {{WEBGPU_TILE_N}}u;
+
+override WORKGROUP_SIZE_M: u32;
+override WORKGROUP_SIZE_N: u32;
+override TILE_K: u32;
+
+override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
+override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
+override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+
+var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
+
+@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>) {
+
+    let thread_id = local_id.x;
+    let local_m = get_local_m(thread_id);
+    let local_n = get_local_n(thread_id);
+
+    let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
+    let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
+    let wg_per_matrix = wg_m_count * wg_n_count;
+
+    let batch_idx = wg_id.x / wg_per_matrix;
+
+    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let wg_m = wg_in_batch % wg_m_count;
+    let wg_n = wg_in_batch / wg_m_count;
+
+    let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
+    let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;
+
+    let dst2_stride = params.m * params.n;
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+
+    let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
+    let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
+
+    var acc: array<array<f16, TILE_N>, TILE_M>;
+
+    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
+
+        // see mul_mat_decls.tmpl
+        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
+        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
+
+        workgroupBarrier();
+
+        let k_end = min(TILE_K, params.k - k_outer);
+
+        for (var k_inner = 0u; k_inner < k_end; k_inner++) {
+            var src0_tile: array<f16, TILE_M>;
+            for (var tm = 0u; tm < TILE_M; tm++) {
+                let src0_m = local_m * TILE_M + tm;
+                let src0_idx = k_inner + src0_m * TILE_K;
+                src0_tile[tm] = shmem[src0_idx];
+            }
+            for (var tn = 0u; tn < TILE_N; tn++) {
+                let src1_n = local_n * TILE_N + tn;
+                let src1_idx = src1_n * TILE_K + k_inner;
+                let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
+                for (var tm = 0u; tm < TILE_M; tm++) {
+                      acc[tm][tn] += src0_tile[tm] * src1_val;
+                }
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
+
+    for (var tn = 0u; tn < TILE_N; tn++) {
+        let global_col = output_col_base + tn;
+        if (global_col < params.n) {
+            for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
+                let global_row = output_row_base + tm;
+                if (global_row < params.m) {
+                    let dst_idx = dst_batch_offset + global_col * params.m + global_row;
+                    dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
+                }
+            }
+        }
+    }
+}
+
+#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
new file mode 100644 (file)
index 0000000..47c8ce3
--- /dev/null
@@ -0,0 +1,302 @@
+#define(VARIANTS)
+[
+  {
+    "SHADER_SUFFIX": "f32_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f32>",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f32_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f32",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f16>",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f16_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f16>",
+      "SRC1_TYPE" : "vec4<f16>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f16",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f16",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "q4_0_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE" : "vec4<f32>",
+      "SHMEM_TYPE" : "vec4<f16>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+  },
+  {
+    "SHADER_SUFFIX": "q4_0_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE" : "f32",
+      "SHMEM_TYPE" : "f16",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(VEC)
+fn store_dst(shmem_idx: u32, dst_idx: u32) {
+    dst[dst_idx] = vec4<f32>(
+        f32(shmem[shmem_idx]),
+        f32(shmem[shmem_idx + 1]),
+        f32(shmem[shmem_idx + 2]),
+        f32(shmem[shmem_idx + 3])
+    );
+}
+#enddecl(VEC)
+
+#decl(SCALAR)
+fn store_dst(shmem_idx: u32, dst_idx: u32) {
+    dst[dst_idx] = f32(shmem[shmem_idx]);
+}
+#enddecl(SCALAR)
+
+#end(DECLS)
+
+#define(SHADER)
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+DECLS
+
+// Note: These are string interpolated at build time, cannot use override constants due to limitations in
+// current Dawn version type definitions/matrix load requirements for constant memory sizes.
+const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
+const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
+// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
+// runtime subgroup size is smaller.
+const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
+
+const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
+
+const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
+const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
+const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
+
+const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
+const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
+
+const TILE_K = {{WEBGPU_TILE_K}}u;
+
+const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
+const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
+
+// We reuse shmem for accumulation matrices
+const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
+
+var<workgroup> shmem: array<f16, SHMEM_SIZE>;
+
+@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32) {
+
+    let thread_id = local_id.x;
+    let subgroup_m = subgroup_id % SUBGROUP_M;
+    let subgroup_n = subgroup_id / SUBGROUP_M;
+
+    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
+    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
+    let wg_per_matrix = wg_m_count * wg_n_count;
+
+    let batch_idx = wg_id.x / wg_per_matrix;
+
+    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let wg_m = wg_in_batch % wg_m_count;
+    let wg_n = wg_in_batch / wg_m_count;
+
+    let dst2_stride = params.m * params.n;
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+
+    let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;
+
+    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
+
+        // see mul_mat_decls.tmpl
+        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
+        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
+
+        workgroupBarrier();
+
+        if (subgroup_id < EXPECTED_SUBGROUPS) {
+
+            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
+
+                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
+                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
+                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
+                        &shmem,
+                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
+                        false,
+                        TILE_K
+                    );
+                }
+
+                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
+                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
+                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
+                        &shmem,
+                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
+                        true,
+                        TILE_K
+                    );
+                    for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
+                    }
+                }
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
+
+    // Stage the subgroup matrix tiles into shared memory
+    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
+    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
+    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+
+    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
+        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
+            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
+                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
+                let out_base = local_row * WG_TILE_STRIDE + local_col;
+                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
+            }
+        }
+    }
+
+    workgroupBarrier();
+
+    // Cooperative write: iterate over the entire workgroup tile
+    let tile_rows = WG_N_SG_TILE_SIZE;
+    let tile_cols = WG_M_SG_TILE_SIZE;
+    let total_tile_elems = tile_rows * tile_cols;
+    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+        let local_row = idx % WG_TILE_STRIDE;
+        let local_col = idx / WG_TILE_STRIDE;
+
+        let global_row = tile_dst_row_base + local_row;
+        let global_col = tile_dst_col_base + local_col;
+
+        if (global_col < params.n && global_row < params.m) {
+            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
+            store_dst(idx, dst_idx/{{VEC_SIZE}});
+        }
+    }
+}
+
+#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
new file mode 100644 (file)
index 0000000..ffbb640
--- /dev/null
@@ -0,0 +1,267 @@
+#define(VARIANTS)
+[
+  {
+    "SHADER_SUFFIX": "f32_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f32>",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE": "vec4<f32>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
+  },
+  {
+    "SHADER_SUFFIX": "f32_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f32",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE": "f32",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f32_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f16>",
+      "SRC1_TYPE" : "vec4<f32>",
+      "DST_TYPE": "vec4<f32>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE": "f32",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f16_vec",
+    "REPLS": {
+      "SRC0_TYPE" : "vec4<f16>",
+      "SRC1_TYPE" : "vec4<f16>",
+      "DST_TYPE": "vec4<f32>",
+      "VEC_SIZE" : 4,
+    },
+    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
+  },
+  {
+    "SHADER_SUFFIX": "f16_f16",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f16",
+      "DST_TYPE": "f32",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
+  },
+  {
+    "SHADER_SUFFIX": "q4_0_f32",
+    "REPLS": {
+      "SRC0_TYPE" : "f16",
+      "SRC1_TYPE" : "f32",
+      "DST_TYPE": "f32",
+      "VEC_SIZE" : 1,
+    },
+    "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(VEC)
+fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
+    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
+}
+
+fn store_val(group_base: u32) -> vec4<f32> {
+    return vec4<f32>(partial_sums[group_base],
+                     partial_sums[group_base + THREADS_PER_OUTPUT],
+                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
+                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
+}
+#enddecl(VEC)
+
+#decl(SCALAR)
+fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
+    return f32(src0_val) * f32(src1_val);
+}
+
+fn store_val(group_base: u32) -> f32 {
+    return partial_sums[group_base];
+}
+#enddecl(SCALAR)
+
+#decl(MUL_ACC_FLOAT)
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
+        let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
+        let b = shared_vector[i / {{VEC_SIZE}}];
+        local_sum += inner_dot(a, b);
+    }
+    return local_sum;
+}
+
+#enddecl(MUL_ACC_FLOAT)
+
+#decl(MUL_ACC_Q4_0)
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
+                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
+                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+            }
+        }
+    }
+    return local_sum;
+}
+
+#enddecl(MUL_ACC_Q4_0)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+DECLS
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>;  // Result vector (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+override WORKGROUP_SIZE: u32;
+override TILE_K: u32;
+override OUTPUTS_PER_WG: u32;
+override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
+
+// Shared memory for collaborative loading and reduction
+var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>;  // Cache vector tile
+var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>;   // For reduction
+
+@compute @workgroup_size(WORKGROUP_SIZE)
+fn main(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(num_workgroups) num_wg: vec3<u32>) {
+    let thread_id = local_id.x;
+
+    // Handle batch dimensions
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
+    let batch_idx = wg_linear / output_groups;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    // Which of the outputs does this thread belong to?
+    let thread_group = thread_id / THREADS_PER_OUTPUT;
+    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
+
+    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
+    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
+
+    let dst2_stride = params.m * params.n;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
+    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
+
+    var local_sum = 0.0;
+
+    // Each thread processes multiple K elements and accumulates
+    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
+        let tile_size = min(TILE_K, params.k - k_tile);
+
+        // Cooperatively load vector tile into shared memory (all threads)
+        for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
+            shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
+        }
+
+        workgroupBarrier();
+
+        if (output_row < params.m) {
+            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
+        }
+
+        workgroupBarrier();
+    }
+
+    // Store partial sums and reduce within each partition
+    partial_sums[thread_id] = local_sum;
+    workgroupBarrier();
+    let group_base = thread_group * THREADS_PER_OUTPUT;
+    let thread_base = group_base + thread_in_group;
+    var offset = THREADS_PER_OUTPUT / 2;
+    while (offset > 0) {
+        if (thread_in_group < offset) {
+            partial_sums[thread_base] += partial_sums[thread_base + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+
+    // Store back to global memory
+    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
+        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
+    }
+}
+#end(SHADER)