]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml webgpu: faster normal quant and some k-quant matrix operations, better shader...
authorReese Levine <redacted>
Tue, 10 Mar 2026 16:14:27 +0000 (09:14 -0700)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* K quant speedup (llama/20)

* Basic JIT compilation for mul_mat, get_rows, and scale (llama/17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* no gibberish, all k quants added, merged

* vec memory fix

* q6_k matching metal on my machine, tests passing

* Set tile size for q6_k separately

* Separate out fast shaders

---------

Co-authored-by: neha-ha <redacted>
* Move towards writeBuffer for params

* Move away from multiple buffers for set_rows errors, remove host buffer for parameter buffers, minor cleanups

* Remove extra file

* Formatting

---------

Co-authored-by: neha-ha <redacted>
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl

index 17c5e0fb51f77fcb35d2854333fbfefa325282ca..3c38b1a230ffad526bc11c014007ca1086e80303 100644 (file)
 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
 
 // Matrix-vector multiplication parameters
-#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+
 // Must be multiple of 4 to work with vectorized paths, and must divide
 // mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_TILE_K         256
+#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K         256
+
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K         256
+
+// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
+#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
+// Requires at least two (and multiple of 2) k-quant blocks per tile
+#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K         512
 
 // default size for legacy matrix multiplication
 #define WEBGPU_MUL_MAT_WG_SIZE 256
@@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
     bool src_overlap;
 
     bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
-        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
+        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
+               src_overlap == other.src_overlap;
     }
 };
 
@@ -749,29 +759,17 @@ class ggml_webgpu_shader_lib {
         std::vector<std::string> defines;
         std::string              variant = "mul_mat_vec";
 
-        // src1 type (vector)
-        switch (context.src1->type) {
-            case GGML_TYPE_F32:
-                defines.push_back("SRC1_INNER_TYPE=f32");
-                variant += "_f32";
-                break;
-            case GGML_TYPE_F16:
-                defines.push_back("SRC1_INNER_TYPE=f16");
-                variant += "_f16";
-                break;
-            default:
-                GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
-        }
-
         // src0 type (matrix row)
         switch (context.src0->type) {
             case GGML_TYPE_F32:
                 defines.push_back("SRC0_INNER_TYPE=f32");
                 defines.push_back("MUL_ACC_FLOAT");
+                variant += "_f32";
                 break;
             case GGML_TYPE_F16:
                 defines.push_back("SRC0_INNER_TYPE=f16");
                 defines.push_back("MUL_ACC_FLOAT");
+                variant += "_f16";
                 break;
             default:
                 {
@@ -779,6 +777,7 @@ class ggml_webgpu_shader_lib {
                     const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
                     std::string                     src0_name   = src0_traits->type_name;
                     std::string                     type_upper  = src0_name;
+                    variant += "_" + src0_name;
                     std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
 
                     defines.push_back("BYTE_HELPERS");
@@ -790,12 +789,35 @@ class ggml_webgpu_shader_lib {
                 }
         }
 
+        // src1 type (vector)
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_INNER_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_INNER_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
+        }
+
         // VEC/SCALAR controls
         defines.push_back(key.vectorized ? "VEC" : "SCALAR");
 
         uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-        uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_TILE_K;
-        uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+        uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
+        uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
+
+        if (key.src0_type >= GGML_TYPE_Q2_K) {
+            tile_k         = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
+            outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
+        } else if (key.src0_type >= GGML_TYPE_Q4_0) {
+            tile_k         = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
+            outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
+        }
+
         defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
         defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
         defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
@@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
 
     webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_binary_pipeline_key key = {
-            .type    = context.dst->type,
-            .op      = context.dst->op,
-            .inplace = context.inplace,
-            .overlap = context.overlap,
+            .type        = context.dst->type,
+            .op          = context.dst->op,
+            .inplace     = context.inplace,
+            .overlap     = context.overlap,
             .src_overlap = context.src_overlap,
         };
 
index b2ef2d59010588d7fd4d433bceaec7e2ad63748e..ccc34cb153f70e17fa7c3c8b140022747e36e62f 100644 (file)
@@ -8,7 +8,6 @@
 #include "ggml-backend-impl.h"
 #include "ggml-impl.h"
 #include "ggml-webgpu-shader-lib.hpp"
-#include "pre_wgsl.hpp"
 
 #ifdef __EMSCRIPTEN__
 #    include <emscripten/emscripten.h>
 #include <condition_variable>
 #include <cstdint>
 #include <cstring>
-#include <iostream>
+#ifdef GGML_WEBGPU_GPU_PROFILE
+#    include <iomanip>
+#endif
+#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
+#    include <iostream>
+#endif
 #include <map>
 #include <memory>
 #include <mutex>
 #include <optional>
 #include <string>
+#include <utility>
 #include <vector>
 
 #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
@@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
 #endif  // GGML_WEBGPU_CPU_PROFILE
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
-#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       24
+#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       32
 #    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // e.g. enough for two timestamps
 #endif
 
 /* Constants */
 
-#define WEBGPU_NUM_PARAM_BUFS                48u
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     16u
+#define WEBGPU_NUM_PARAM_BUFS                96u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     32u
 #define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
 // Maximum number of in-flight submissions per-thread, to avoid exhausting the
 // parameter buffer pool
-#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
+#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
-#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       16
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
-#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT      4    // a storage buffer binding size must be a multiple of 4
 
 // For operations which process a row in parallel, this seems like a reasonable
 // default
@@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
                                       wgpu::BufferUsage usage,
                                       const char *      label);
 
-struct webgpu_pool_bufs {
-    wgpu::Buffer host_buf;
-    wgpu::Buffer dev_buf;
-};
-
 // Holds a pool of parameter buffers for WebGPU operations
 struct webgpu_buf_pool {
-    std::vector<webgpu_pool_bufs> free;
+    std::vector<wgpu::Buffer> free;
 
     // The pool must be synchronized because
     // 1. The memset pool is shared globally by every ggml buffer,
@@ -138,7 +137,6 @@ struct webgpu_buf_pool {
     size_t                  cur_pool_size;
     size_t                  max_pool_size;
     wgpu::Device            device;
-    wgpu::BufferUsage       host_buf_usage;
     wgpu::BufferUsage       dev_buf_usage;
     size_t                  buf_size;
     bool                    should_grow;
@@ -147,53 +145,47 @@ struct webgpu_buf_pool {
               int               num_bufs,
               size_t            buf_size,
               wgpu::BufferUsage dev_buf_usage,
-              wgpu::BufferUsage host_buf_usage,
               bool              should_grow   = false,
               size_t            max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
-        this->max_pool_size  = max_pool_size;
-        this->cur_pool_size  = num_bufs;
-        this->device         = device;
-        this->host_buf_usage = host_buf_usage;
-        this->dev_buf_usage  = dev_buf_usage;
-        this->buf_size       = buf_size;
-        this->should_grow    = should_grow;
+        this->max_pool_size = max_pool_size;
+        this->cur_pool_size = num_bufs;
+        this->device        = device;
+        this->dev_buf_usage = dev_buf_usage;
+        this->buf_size      = buf_size;
+        this->should_grow   = should_grow;
         for (int i = 0; i < num_bufs; i++) {
-            wgpu::Buffer host_buf;
             wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
             ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
-            free.push_back({ host_buf, dev_buf });
+            free.push_back(dev_buf);
         }
     }
 
-    webgpu_pool_bufs alloc_bufs() {
+    wgpu::Buffer alloc_bufs() {
         std::unique_lock<std::mutex> lock(mutex);
         if (!free.empty()) {
-            webgpu_pool_bufs bufs = free.back();
+            wgpu::Buffer buf = free.back();
             free.pop_back();
-            return bufs;
+            return buf;
         }
 
         // Try growing the pool if no free buffers
         if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
             cur_pool_size++;
-            wgpu::Buffer host_buf;
             wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
             ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
 
-            if (!(host_buf && dev_buf)) {
+            if (!dev_buf) {
                 GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
             }
-            return webgpu_pool_bufs{ host_buf, dev_buf };
+            return dev_buf;
         }
         cv.wait(lock, [this] { return !free.empty(); });
-        webgpu_pool_bufs bufs = free.back();
+        wgpu::Buffer buf = free.back();
         free.pop_back();
-        return bufs;
+        return buf;
     }
 
-    void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
+    void free_bufs(std::vector<wgpu::Buffer> bufs) {
         std::lock_guard<std::mutex> lock(mutex);
         free.insert(free.end(), bufs.begin(), bufs.end());
         cv.notify_all();
@@ -201,12 +193,9 @@ struct webgpu_buf_pool {
 
     void cleanup() {
         std::lock_guard<std::mutex> lock(mutex);
-        for (auto & bufs : free) {
-            if (bufs.host_buf) {
-                bufs.host_buf.Destroy();
-            }
-            if (bufs.dev_buf) {
-                bufs.dev_buf.Destroy();
+        for (auto & buf : free) {
+            if (buf) {
+                buf.Destroy();
             }
         }
         free.clear();
@@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
 #endif
 
 struct webgpu_command {
-    uint32_t                        num_kernels;
-    wgpu::CommandBuffer             commands;
-    std::vector<webgpu_pool_bufs>   params_bufs;
-    std::optional<webgpu_pool_bufs> set_rows_error_bufs;
+    uint32_t                  num_kernels;
+    wgpu::CommandBuffer       commands;
+    std::vector<wgpu::Buffer> params_bufs;
 #ifdef GGML_WEBGPU_GPU_PROFILE
     webgpu_gpu_profile_bufs timestamp_query_bufs;
     std::string             pipeline_name;
@@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
 
 typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
 
+struct webgpu_submission {
+    wgpu::FutureWaitInfo submit_done;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    std::vector<wgpu::FutureWaitInfo> profile_futures;
+#endif
+};
+
 // All the base objects needed to run operations on a WebGPU device
 struct webgpu_context_struct {
     // Points to global instances owned by ggml_backend_webgpu_reg_context
@@ -366,7 +361,8 @@ struct webgpu_context_struct {
     std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
 
     webgpu_buf_pool param_buf_pool;
-    webgpu_buf_pool set_rows_error_buf_pool;
+    wgpu::Buffer    set_rows_dev_error_buf;
+    wgpu::Buffer    set_rows_host_error_buf;
 
     std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
 
@@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 /** End WebGPU object initializations */
 
 /** WebGPU Actions */
-static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
+
+static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
+    switch (status) {
+        case wgpu::WaitStatus::Success:
+            return true;
+        case wgpu::WaitStatus::TimedOut:
+            if (allow_timeout) {
+                return false;
+            }
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
+            return false;
+        case wgpu::WaitStatus::Error:
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+            return false;
+        default:
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
+            return false;
+    }
+}
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
     futures.erase(std::remove_if(futures.begin(), futures.end(),
                                  [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
                   futures.end());
 }
 
-// Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_global_context &             ctx,
-                                     std::vector<wgpu::FutureWaitInfo> & futures,
-                                     bool                                block = true) {
-    // If we have too many in-flight submissions, wait on the oldest one first.
+static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context &             ctx,
+                                                     std::vector<wgpu::FutureWaitInfo> & futures,
+                                                     bool                                block) {
     if (futures.empty()) {
         return;
     }
+
     uint64_t timeout_ms = block ? UINT64_MAX : 0;
-    while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
-        auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
-        if (waitStatus == wgpu::WaitStatus::Error) {
-            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+    if (block) {
+        while (!futures.empty()) {
+            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+            if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
+                ggml_backend_webgpu_erase_completed_futures(futures);
+            }
         }
-        if (futures[0].completed) {
-            futures.erase(futures.begin());
+    } else {
+        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+        if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
+            ggml_backend_webgpu_erase_completed_futures(futures);
         }
     }
+}
+#endif
 
-    if (futures.empty()) {
+// Wait for the queue to finish processing all submitted work
+static void ggml_backend_webgpu_wait(webgpu_global_context &          ctx,
+                                     std::vector<webgpu_submission> & subs,
+                                     bool                             block = true) {
+    // If we have too many in-flight submissions, wait on the oldest one first.
+    if (subs.empty()) {
+        return;
+    }
+    while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
+        auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
+        if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
+#endif
+            subs.erase(subs.begin());
+        }
+    }
+
+    if (subs.empty()) {
         return;
     }
 
     if (block) {
-        while (!futures.empty()) {
-            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
-            switch (waitStatus) {
-                case wgpu::WaitStatus::Success:
-                    // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
-                    erase_completed(futures);
-                    break;
-                case wgpu::WaitStatus::Error:
-                    GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
-                    break;
-                default:
-                    GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
-                    break;
+        for (auto & sub : subs) {
+            while (!sub.submit_done.completed) {
+                auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
+                ggml_backend_webgpu_handle_wait_status(waitStatus);
             }
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
+#endif
         }
+        subs.clear();
     } else {
-        // Poll once and return
-        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
-        switch (waitStatus) {
-            case wgpu::WaitStatus::Success:
-                // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
-                erase_completed(futures);
-                break;
-            case wgpu::WaitStatus::TimedOut:
-                break;
-            case wgpu::WaitStatus::Error:
-                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
-                break;
-            default:
-                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
-                break;
+        // Poll each submit future once and remove completed submissions.
+        for (auto sub = subs.begin(); sub != subs.end();) {
+            auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
+            ggml_backend_webgpu_handle_wait_status(waitStatus, true);
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
+            if (sub->submit_done.completed && sub->profile_futures.empty()) {
+#else
+            if (sub->submit_done.completed) {
+#endif
+                sub = subs.erase(sub);
+            } else {
+                ++sub;
+            }
         }
     }
 }
@@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
 }
 #endif
 
-static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
-    webgpu_global_context       ctx,
-    std::vector<webgpu_command> commands,
-    webgpu_buf_pool &           param_buf_pool,
-    webgpu_buf_pool *           set_rows_error_buf_pool = nullptr) {
+static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context &       ctx,
+                                                    std::vector<webgpu_command> & commands,
+                                                    webgpu_buf_pool &             param_buf_pool) {
     std::vector<wgpu::CommandBuffer> command_buffers;
-    std::vector<webgpu_pool_bufs>    params_bufs;
-    std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
+    std::vector<wgpu::Buffer>        params_bufs;
+    webgpu_submission                submission;
 #ifdef GGML_WEBGPU_GPU_PROFILE
     std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
 #endif
@@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
     for (const auto & command : commands) {
         command_buffers.push_back(command.commands);
         params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
-        if (command.set_rows_error_bufs) {
-            set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
-        }
     }
     ctx->queue.Submit(command_buffers.size(), command_buffers.data());
 
-    std::vector<wgpu::FutureWaitInfo> futures;
-
     wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
         [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
@@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
             // Free the staged buffers
             param_buf_pool.free_bufs(params_bufs);
         });
-    futures.push_back({ p_f });
-
-    for (const auto & bufs : set_rows_error_bufs) {
-        wgpu::Future f = bufs.host_buf.MapAsync(
-            wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
-            [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
-                if (status != wgpu::MapAsyncStatus::Success) {
-                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
-                } else {
-                    const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
-                    if (*error_data) {
-                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
-                    }
-                    // We can't unmap in here due to WebGPU reentrancy limitations.
-                    if (set_rows_error_buf_pool) {
-                        set_rows_error_buf_pool->free_bufs({ bufs });
-                    }
-                }
-            });
-        futures.push_back({ f });
-    }
+    submission.submit_done = { p_f };
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
     for (const auto & command : commands) {
@@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
                     // WebGPU timestamps are in ns; convert to ms
                     double           elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
                     ctx->shader_gpu_time_ms[label] += elapsed_ms;
-                    // We can't unmap in here due to WebGPU reentrancy limitations.
-                    ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
                 }
+                // We can't unmap in here due to WebGPU reentrancy limitations.
+                ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
             });
-        futures.push_back({ f });
+        submission.profile_futures.push_back({ f });
     }
 #endif
-    return futures;
+    return submission;
 }
 
 static webgpu_command ggml_backend_webgpu_build_multi(
@@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     const std::vector<webgpu_pipeline> &                   pipelines,
     const std::vector<std::vector<uint32_t>> &             params_list,
     const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
-    const std::vector<std::pair<uint32_t, uint32_t>> &     workgroups_list,
-    const std::optional<webgpu_pool_bufs> &                set_rows_error_bufs = std::nullopt) {
+    const std::vector<std::pair<uint32_t, uint32_t>> &     workgroups_list) {
     GGML_ASSERT(pipelines.size() == params_list.size());
     GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
     GGML_ASSERT(pipelines.size() == workgroups_list.size());
 
-    std::vector<webgpu_pool_bufs> params_bufs_list;
-    std::vector<wgpu::BindGroup>  bind_groups;
+    std::vector<wgpu::Buffer>    params_bufs_list;
+    std::vector<wgpu::BindGroup> bind_groups;
 
     for (size_t i = 0; i < pipelines.size(); i++) {
-        webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
-
-        ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
-                                       params_bufs.host_buf.GetSize());
-        uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
-        for (size_t j = 0; j < params_list[i].size(); j++) {
-            _params[j] = params_list[i][j];
-        }
-        params_bufs.host_buf.Unmap();
+        wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
 
         std::vector<wgpu::BindGroupEntry> entries            = bind_group_entries_list[i];
         uint32_t                          params_binding_num = entries.size();
-        entries.push_back({ .binding = params_binding_num,
-                            .buffer  = params_bufs.dev_buf,
-                            .offset  = 0,
-                            .size    = params_bufs.dev_buf.GetSize() });
+        entries.push_back(
+            { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
 
         wgpu::BindGroupDescriptor bind_group_desc;
         bind_group_desc.layout     = pipelines[i].pipeline.GetBindGroupLayout(0);
@@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     }
 
     wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
-    for (const auto & params_bufs : params_bufs_list) {
-        encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
-    }
-
-    // If there are SET_ROWS operations in this submission, copy their error
-    // buffers to the host.
-    if (set_rows_error_bufs) {
-        encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
-                                   set_rows_error_bufs->host_buf.GetSize());
+    for (size_t i = 0; i < params_bufs_list.size(); i++) {
+        ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
     }
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     webgpu_command      result   = {};
     result.commands              = commands;
     result.params_bufs           = params_bufs_list;
-    result.set_rows_error_bufs   = set_rows_error_bufs;
     result.num_kernels           = pipelines.size();
 #ifdef GGML_WEBGPU_GPU_PROFILE
     result.timestamp_query_bufs = ts_bufs;
@@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
                                                 std::vector<uint32_t>             params,
                                                 std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                 uint32_t                          wg_x,
-                                                uint32_t                          wg_y                = 1,
-                                                std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
+                                                uint32_t                          wg_y = 1) {
     return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
                                            {
                                                pipeline
     },
-                                           { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
+                                           { std::move(params) }, { std::move(bind_group_entries) },
+                                           { { wg_x, wg_y } });
 }
 
 static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
@@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
 
     webgpu_command command =
         ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
-    auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
-    ggml_backend_webgpu_wait(ctx, futures);
+    std::vector<webgpu_command>    commands = { command };
+    std::vector<webgpu_submission> sub      = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
+    ggml_backend_webgpu_wait(ctx, sub);
 }
 
 /** End WebGPU Actions */
@@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
     std::cout << "\nggml_webgpu: gpu breakdown:\n";
     for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
         double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
-        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
+        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
+                  << pct << "%)\n";
     }
 #endif
 
@@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
 
     auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
 
-    std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
-    if (decisions->i64_idx) {
-        error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
-        if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
-            error_bufs->host_buf.Unmap();
-        }
-    }
-
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
     };
 
     if (decisions->i64_idx) {
-        entries.push_back(
-            { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
+        entries.push_back({ .binding = 3,
+                            .buffer  = ctx->set_rows_dev_error_buf,
+                            .offset  = 0,
+                            .size    = ctx->set_rows_dev_error_buf.GetSize() });
     }
 
     uint32_t threads;
@@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
         threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
     }
     uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
-                                     error_bufs);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
 }
 
 // Workgroup size is a common constant
@@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
             use_fast = (src0->type == GGML_TYPE_F16);
             break;
         case GGML_TYPE_F32:
+            // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
             switch (src0->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q8_1:
+                case GGML_TYPE_Q6_K:
                     use_fast = true;
                     break;
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                    // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
+                    use_fast = !is_vec;
+                    break;
                 default:
                     break;
             }
@@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
 
     if (use_fast && is_vec) {
-        auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
+        auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
 
         uint32_t batches       = dst->ne[2] * dst->ne[3];
         uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
         uint32_t total_wg      = output_groups * batches;
         compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
     } else if (use_fast) {
-        auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
+        auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
 
         // Fast-path tiled/subgroup calculations
-        uint32_t wg_m, wg_n;
+        uint32_t wg_m;
+        uint32_t wg_n;
         if (decisions->use_subgroup_matrix) {
             uint32_t wg_m_sg_tile =
                 decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
@@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
         compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
 
     } else {  // legacy
-        auto     decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+        auto *   decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
         uint32_t wg_size   = decisions->wg_size;
         uint32_t total_wg  = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
         compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
@@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
 }
 
 static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
-                                         ggml_tensor * src0,
-                                         ggml_tensor * src1,
-                                         ggml_tensor * dst) {
-    uint32_t ne = (uint32_t) ggml_nelements(dst);
+                                         ggml_tensor *    src0,
+                                         ggml_tensor *    src1,
+                                         ggml_tensor *    dst) {
+    uint32_t ne  = (uint32_t) ggml_nelements(dst);
     uint32_t dim = (uint32_t) dst->op_params[0];
 
     std::vector<uint32_t> params = {
@@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
         (uint32_t) dst->ne[2],
         (uint32_t) dst->ne[3],
         dim,
-        (uint32_t)src0->ne[dim]
+        (uint32_t) src0->ne[dim]
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        {
-            .binding = 0,
-            .buffer = ggml_webgpu_tensor_buf(src0),
-            .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
-            .size = ggml_webgpu_tensor_binding_size(ctx, src0)
-        },
-        {
-            .binding = 1,
-            .buffer = ggml_webgpu_tensor_buf(src1),
-            .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
-            .size = ggml_webgpu_tensor_binding_size(ctx, src1)
-        },
-        {
-            .binding = 2,
-            .buffer = ggml_webgpu_tensor_buf(dst),
-            .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
-            .size = ggml_webgpu_tensor_binding_size(ctx, dst)
-        }
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(src1),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
+        { .binding = 2,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }
     };
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
 
-    webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
-    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
@@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
     const int mode       = ((int32_t *) dst->op_params)[2];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2172,19 +2169,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_SOFT_MAX:
             return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
         case GGML_OP_UNARY:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_CLAMP:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_FILL:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_LOG:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_SQR:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_SQRT:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_SIN:
-            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_COS:
             return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_PAD:
@@ -2192,7 +2182,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_ARGMAX:
             return ggml_webgpu_argmax(ctx, src0, node);
         case GGML_OP_ARGSORT:
-            return ggml_webgpu_argsort(ctx, src0, node);
         case GGML_OP_TOP_K:
             // we reuse the same argsort implementation for top_k
             return ggml_webgpu_argsort(ctx, src0, node);
@@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    std::vector<webgpu_command>       commands;
-    std::vector<wgpu::FutureWaitInfo> futures;
-    uint32_t                          num_batched_kernels = 0;
+    std::vector<webgpu_command>    commands;
+    std::vector<webgpu_submission> subs;
+    uint32_t                       num_batched_kernels = 0;
+    bool                           contains_set_rows   = false;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
+            contains_set_rows = true;
+        }
         if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
             commands.push_back(*cmd);
             num_batched_kernels += cmd.value().num_kernels;
         }
 
         if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
-            num_batched_kernels                               = 0;
-            std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
-                ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
-            futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
+            num_batched_kernels = 0;
+            subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
             // Process events and check for completed submissions
             ctx->global_ctx->instance.ProcessEvents();
-            ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
+            ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
             commands.clear();
         }
     }
     if (!commands.empty()) {
-        auto new_futures =
-            ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
-        futures.insert(futures.end(), new_futures.begin(), new_futures.end());
+        subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
+        commands.clear();
+    }
+
+    // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
+    if (contains_set_rows) {
+        wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
+        encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
+                                   ctx->set_rows_host_error_buf.GetSize());
+        wgpu::CommandBuffer set_rows_commands = encoder.Finish();
+        ctx->global_ctx->queue.Submit(1, &set_rows_commands);
+        ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
+                                       ctx->set_rows_host_error_buf.GetSize());
+        const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
+        if (*error_data) {
+            GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
+        }
+        ctx->set_rows_host_error_buf.Unmap();
     }
 
-    ggml_backend_webgpu_wait(ctx->global_ctx, futures);
+    ggml_backend_webgpu_wait(ctx->global_ctx, subs);
     WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
     return GGML_STATUS_SUCCESS;
 }
@@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
     webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
                                     wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
                                     wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
-    webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
-                                             WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
-                                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
-                                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
+                              WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
+                              WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
 
     ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
     ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
index 5c1074ebc10afea4945240d96f645a4dfff61c25..de60ebbcf2b5712330df77532b1af2767844b6a0 100644 (file)
@@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
     shmem[idx + 2] = val.z;
     shmem[idx + 3] = val.w;
 }
-#endif
+#endif // VEC
 
 #ifdef SCALAR
 #define VEC_SIZE 1
@@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
 fn store_shmem(val: f16, idx: u32) {
     shmem[idx] = val;
 }
-#endif
+#endif // SCALAR
 
 #ifdef INIT_SRC0_SHMEM_FLOAT
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
@@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         store_shmem(SHMEM_TYPE(src0_val), elem_idx);
     }
 }
-#endif
+#endif // INIT_SRC0_SHMEM_FLOAT
 
 #ifdef INIT_SRC1_SHMEM_FLOAT
 fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
@@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
         store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
     }
 }
-#endif
+#endif // INIT_SRC1_SHMEM_FLOAT
 
 #ifdef INIT_SRC0_SHMEM_Q4_0
 const BLOCK_SIZE = 32u;
@@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         }
     }
 }
-#endif
+#endif // INIT_SRC0_SHMEM_Q4_0
+
+#ifdef INIT_SRC0_SHMEM_Q4_1
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+            let m = src0[scale_idx + 1u];
+
+            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                let q_0 = src0[scale_idx + 2u + block_offset + j];
+                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+                    let q_lo = f16(q_byte & 0xF) * d + m;
+                    let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
+                    shmem[shmem_idx + j * 2 + k] = q_lo;
+                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q4_1
+
+#ifdef INIT_SRC0_SHMEM_Q5_0
+// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+// tile_k is defined as 32u, so blocks_k ends up being 1 always
+override BLOCKS_K = TILE_K / BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx    = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx   = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m   = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k  = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+
+            let d  = src0[scale_idx];
+            let qh0 = src0[scale_idx + 1u];
+            let qh1 = src0[scale_idx + 2u];
+            let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+            for (var j = 0u; j < 2; j++) {
+                let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+                let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+
+                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+                let j_adjusted = j + (block_offset / 2u);
+
+
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+
+                    let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                    let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+                    let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                    let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+                    shmem[shmem_idx + j * 4u + k]        = q_lo; // store first weight
+                    shmem[shmem_idx + j * 4u + k + 16u]  = q_hi; // store second weight
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q5_0
+
+#ifdef INIT_SRC0_SHMEM_Q5_1
+// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+// tile_k is defined as 32u, so blocks_k ends up being 1 always
+override BLOCKS_K = TILE_K / BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx    = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx   = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m   = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k  = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+
+            let d  = src0[scale_idx];
+            let m = src0[scale_idx + 1u];
+            let qh0 = src0[scale_idx + 2u];
+            let qh1 = src0[scale_idx + 3u];
+            let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+            for (var j = 0u; j < 2; j++) {
+
+                let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+                let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+
+                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+                let j_adjusted = j + (block_offset / 2u);
+
+
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+
+                    let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                    let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
+                    let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                    let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
+
+                    shmem[shmem_idx + j * 4u + k]        = q_lo; // store first weight
+                    shmem[shmem_idx + j * 4u + k + 16u]  = q_hi; // store second weight
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q5_1
+
+#ifdef INIT_SRC0_SHMEM_Q8_0
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
+const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+
+            for (var j = 0u; j < F16_PER_THREAD; j+=2) {
+                let q_0 = src0[scale_idx + 1u + block_offset + j];
+                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte_i32(q_packed, k);
+
+                    let q_val = f16(q_byte) * d;
+                    shmem[shmem_idx + j * 2 + k] = q_val;
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q8_0
+
+#ifdef INIT_SRC0_SHMEM_Q8_1
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
+const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+            let m = src0[scale_idx + 1u];
+
+            for (var j = 0u; j < F16_PER_THREAD; j+=2) {
+                let q_0 = src0[scale_idx + 2u + block_offset + j];
+                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte_i32(q_packed, k);
+
+                    let q_val = f16(q_byte) * d + m;
+                    shmem[shmem_idx + j * 2 + k] = q_val;
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q8_1
+
+#ifdef INIT_SRC0_SHMEM_Q2_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 42u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    // Use standard thread layout instead of lane/row_group
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx + 40u];
+        let dmin = src0[scale_idx + 41u];
+
+        // Decode the element at position k_in_block
+        let block_of_32 = k_in_block / 32u;
+        let pos_in_32 = k_in_block % 32u;
+
+        let q_b_idx = (block_of_32 / 4u) * 32u;
+        let shift = (block_of_32 % 4u) * 2u;
+        let k = (pos_in_32 / 16u) * 16u;
+        let l = pos_in_32 % 16u;
+
+        let is = k_in_block / 16u;
+
+        let sc_0 = src0[scale_idx + 2u * (is / 4u)];
+        let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
+        let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
+        let sc = get_byte(sc_packed, is % 4u);
+
+        let dl = d * f16(sc & 0xFu);
+        let ml = dmin * f16(sc >> 4u);
+
+        let q_idx = q_b_idx + k + l;
+        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
+        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
+        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+        let q_byte = get_byte(q_packed, q_idx % 4u);
+        let qs_val = (q_byte >> shift) & 3u;
+
+        let q_val = f16(qs_val) * dl - ml;
+        shmem[elem_idx] = q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q2_K
+
+#ifdef INIT_SRC0_SHMEM_Q3_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 55u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx + 54u];
+
+        // Load and unpack scales
+        let kmask1: u32 = 0x03030303u;
+        let kmask2: u32 = 0x0f0f0f0fu;
+
+        var scale_vals: array<u32, 4>;
+        for (var i: u32 = 0u; i < 4u; i++) {
+            let scale_0 = src0[scale_idx + 48u + (2u*i)];
+            let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
+            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+        }
+
+        var tmp: u32 = scale_vals[2];
+        scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
+        scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
+        scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
+        scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
+
+        // Load hmask and qs arrays
+        var hmask_vals: array<u32, 8>;
+        for (var i: u32 = 0u; i < 8u; i++) {
+            let hmask_0 = src0[scale_idx + (2u*i)];
+            let hmask_1 = src0[scale_idx + (2u*i) + 1u];
+            hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
+        }
+
+        var qs_vals: array<u32, 16>;
+        for (var i: u32 = 0u; i < 16u; i++) {
+            let qs_0 = src0[scale_idx + 16u + (2u*i)];
+            let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
+            qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
+        }
+
+        let half = k_in_block / 128u;           // 0 or 1
+        let pos_in_half = k_in_block % 128u;    // 0-127
+        let shift_group = pos_in_half / 32u;    // 0-3
+        let pos_in_32 = pos_in_half % 32u;      // 0-31
+        let k_group = pos_in_32 / 16u;          // 0 or 1
+        let l = pos_in_32 % 16u;                // 0-15
+
+        let q_b_idx = half * 32u;               // 0 or 32
+        let shift = shift_group * 2u;           // 0, 2, 4, 6
+        let k = k_group * 16u;                  // 0 or 16
+        let is = k_in_block / 16u;              // 0-15
+
+        // m increments every 32 elements across entire 256 element block
+        let m_shift = k_in_block / 32u;         // 0-7
+        let m: u32 = 1u << m_shift;             // 1,2,4,8,16,32,64,128
+
+        let sc = get_byte(scale_vals[is / 4u], is % 4u);
+        let dl = d * (f16(sc) - 32.0);
+
+        let q_idx = q_b_idx + k + l;
+        let hm_idx = k + l;
+
+        let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
+        let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
+
+        let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+        let qs_val = (q_byte >> shift) & 3u;
+
+        let q_val = (f16(qs_val) - f16(hm)) * dl;
+        shmem[elem_idx] = q_val;
+    }
+}
+
+#endif // INIT_SRC0_SHMEM_Q3_K
+
+#ifdef INIT_SRC0_SHMEM_Q4_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 72u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx];
+        let dmin = src0[scale_idx + 1u];
+
+        // Load packed scales
+        var scale_vals: array<u32, 3>;
+        for (var i: u32 = 0u; i < 3u; i++) {
+            let scale_0 = src0[scale_idx + 2u + (2u*i)];
+            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
+            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+        }
+
+        // Map k_in_block to loop structure:
+        // Outer loop over 64-element groups (alternating q_b_idx)
+        // Inner loop over 2 shifts per group
+        let group_of_64 = k_in_block / 64u;  // 0-3 (maps to q_b_idx)
+        let pos_in_64 = k_in_block % 64u;    // 0-63
+        let shift_group = pos_in_64 / 32u;   // 0 or 1
+        let l = pos_in_64 % 32u;             // 0-31
+
+        let q_b_idx = group_of_64 * 32u;     // 0, 32, 64, 96
+        let shift = shift_group * 4u;        // 0 or 4
+        let is = k_in_block / 32u;           // 0-7
+
+        var sc: u32;
+        var mn: u32;
+
+        if (is < 4u) {
+            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
+            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+            sc = sc_byte & 63u;
+            mn = min_byte & 63u;
+        } else {
+            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
+            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
+            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+
+            sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
+            mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
+        }
+
+        let dl = d * f16(sc);
+        let ml = dmin * f16(mn);
+
+        let q_idx = q_b_idx + l;
+        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
+        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
+        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+        let q_byte = get_byte(q_packed, q_idx % 4u);
+        let qs_val = (q_byte >> shift) & 0xFu;
+
+        let q_val = f16(qs_val) * dl - ml;
+        shmem[elem_idx] = q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q4_K
+
+#ifdef INIT_SRC0_SHMEM_Q5_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 88u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx];
+        let dmin = src0[scale_idx + 1u];
+
+        // Load packed scales
+        var scale_vals: array<u32, 3>;
+        for (var i: u32 = 0u; i < 3u; i++) {
+            let scale_0 = src0[scale_idx + 2u + (2u*i)];
+            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
+            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+        }
+
+        // The original loop processes elements in groups of 64
+        // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
+        // But u increments EVERY 32 elements (after each l loop)
+        let group_of_64 = k_in_block / 64u;  // 0-3
+        let pos_in_64 = k_in_block % 64u;    // 0-63
+        let shift_group = pos_in_64 / 32u;   // 0 or 1
+        let l = pos_in_64 % 32u;             // 0-31
+
+        let q_b_idx = group_of_64 * 32u;     // 0, 32, 64, 96
+        let shift = shift_group * 4u;        // 0 or 4
+        let is = k_in_block / 32u;           // 0-7
+
+        // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
+        let u_shift = k_in_block / 32u;      // 0-7
+        let u: u32 = 1u << u_shift;
+
+        var sc: u32;
+        var mn: u32;
+
+        if (is < 4u) {
+            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
+            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+            sc = sc_byte & 63u;
+            mn = min_byte & 63u;
+        } else {
+            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
+            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
+            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+
+            sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
+            mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
+        }
+
+        let dl = d * f16(sc);
+        let ml = dmin * f16(mn);
+
+        let q_idx = q_b_idx + l;
+        let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
+        let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
+        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+        let q_byte = get_byte(q_packed, q_idx % 4u);
+
+        let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
+        let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
+        let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
+
+        let qh_byte = get_byte(qh_packed, l % 4u);
+
+        let qs_val = (q_byte >> shift) & 0xFu;
+        let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+
+        let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
+        shmem[elem_idx] = q_val;
+    }
+}
+
+#endif // INIT_SRC0_SHMEM_Q5_K
+
+#ifdef INIT_SRC0_SHMEM_Q6_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let half = k_in_block / 128u;
+        let pos_in_half = k_in_block % 128u;
+        let quarter = pos_in_half / 32u;
+        let l = pos_in_half % 32u;
+
+        let ql_b_idx = half * 64u;
+        let qh_b_idx = half * 32u;
+        let sc_b_idx = half * 8u;
+
+        // Load only ql13 word needed
+        let ql13_flat = ql_b_idx + l;
+        let ql13_word = ql13_flat / 4u;
+        let ql13 = bitcast<u32>(vec2(
+            src0[scale_idx + 2u * ql13_word],
+            src0[scale_idx + 2u * ql13_word + 1u]
+        ));
+        let ql13_b = get_byte(ql13, ql13_flat % 4u);
+
+        // Load only ql24 word needed
+        let ql24_flat = ql_b_idx + l + 32u;
+        let ql24_word = ql24_flat / 4u;
+        let ql24 = bitcast<u32>(vec2(
+            src0[scale_idx + 2u * ql24_word],
+            src0[scale_idx + 2u * ql24_word + 1u]
+        ));
+        let ql24_b = get_byte(ql24, ql24_flat % 4u);
+
+        // Load only qh word needed
+        let qh_flat = qh_b_idx + l;
+        let qh_word = qh_flat / 4u;
+        let qh = bitcast<u32>(vec2(
+            src0[scale_idx + 64u + 2u * qh_word],
+            src0[scale_idx + 64u + 2u * qh_word + 1u]
+        ));
+        let qh_b = get_byte(qh, qh_flat % 4u);
+
+        let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
+        let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
+        let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
+        let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
+
+        // Load only the scale word needed
+        let is = l / 16u;
+        let sc_idx = sc_b_idx + is + quarter * 2u;
+        let sc_word = sc_idx / 4u;
+        let sc = bitcast<u32>(vec2(
+            src0[scale_idx + 96u + 2u * sc_word],
+            src0[scale_idx + 96u + 2u * sc_word + 1u]
+        ));
+        let sc_val = get_byte_i32(sc, sc_idx % 4u);
+
+        let d = src0[scale_idx + 104u];
+
+        var q_val: f16;
+        if (quarter == 0u) {
+            q_val = q1;
+        } else if (quarter == 1u) {
+            q_val = q2;
+        } else if (quarter == 2u) {
+            q_val = q3;
+        } else {
+            q_val = q4;
+        }
+
+        shmem[elem_idx] = d * f16(sc_val) * q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q6_K
index 761e3017c14cbfdf66b3b698a150f61ce9a3501a..b1da421a69158890690169f0833a1aedb267f447 100644 (file)
@@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
 const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
 const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
 const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+
 var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
 
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
index f9ea95e07b9ef3bb2798a3c328727bdbe40c20dd..94f4bae11f4a3aa8e9b3538668fa85115e5a2749 100644 (file)
@@ -1,4 +1,3 @@
-
 enable f16;
 
 #include "common_decls.tmpl"
@@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 }
 #endif
 
+#ifdef MUL_ACC_Q4_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 10u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = f32(src0[scale_idx + 1u]);
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+                let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+                let q_lo = f32(q_byte & 0xF) * d + m;
+                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q5_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 11u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let qh0 = src0[scale_idx + 1u];
+        let qh1 = src0[scale_idx + 2u];
+        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+        for (var j = 0u; j < 2; j++) {
+            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+            let j_adjusted = j + (block_offset / 2u);
+
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+
+                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+            }
+
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q5_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 12u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+        let qh0 = src0[scale_idx + 2u];
+        let qh1 = src0[scale_idx + 3u];
+        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+
+        for (var j = 0u; j < 2; j++) {
+            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+
+            let j_adjusted = j + (block_offset / 2u);
+
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+
+                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
+                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
+
+                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+            }
+
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q8_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 17u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d;
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q8_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 18u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d + f32(m);
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q6_K
+
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
+fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
+    let aligned = byte_offset & ~3u;
+    let idx = bbase + aligned / 2u;
+    return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
+}
+
+fn byte_of(v: u32, b: u32) -> u32 {
+    return (v >> (b * 8u)) & 0xFFu;
+}
+
+fn sbyte_of(v: u32, b: u32) -> i32 {
+    let raw = i32((v >> (b * 8u)) & 0xFFu);
+    return select(raw, raw - 256, raw >= 128);
+}
+
+fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    let tid = tig / 2u;
+    let ix  = tig % 2u;
+    let ip  = tid / 8u;
+    let il  = tid % 8u;
+    let l0  = 4u * il;
+    let is  = 8u * ip + l0 / 16u;
+
+    let y_offset   = 128u * ip + l0;
+    let q_offset_l =  64u * ip + l0;
+    let q_offset_h =  32u * ip + l0;
+
+    let nb = tile_size / BLOCK_SIZE;
+    let k_block_start = k_outer / BLOCK_SIZE;
+
+    // Aligned scale byte position (is can be odd)
+    let sc_base_byte = 192u + (is & ~3u);
+    let sc_byte_pos  = is & 3u;
+
+    var local_sum = 0.0;
+
+    for (var i = ix; i < nb; i += 2u) {
+        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+
+        let d_raw = load_u32_at(bbase, 208u);
+        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+
+        let ql1_u32  = load_u32_at(bbase, q_offset_l);
+        let ql2_u32  = load_u32_at(bbase, q_offset_l + 32u);
+        let qh_u32   = load_u32_at(bbase, 128u + q_offset_h);
+        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
+        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+
+        let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
+        let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
+        let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
+        let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
+
+        var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+
+        for (var l = 0u; l < 4u; l++) {
+            let y_base = i * BLOCK_SIZE + y_offset + l;
+            let yl0 = f32(shared_vector[y_base]);
+            let yl1 = f32(shared_vector[y_base + 32u]);
+            let yl2 = f32(shared_vector[y_base + 64u]);
+            let yl3 = f32(shared_vector[y_base + 96u]);
+
+            let q1b = byte_of(ql1_u32, l);
+            let q2b = byte_of(ql2_u32, l);
+            let qhb = byte_of(qh_u32,  l);
+
+            let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
+            let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
+            let dq2 = f32(i32((q1b >>   4u) | ((qhb & 0x30u)       )) - 32);
+            let dq3 = f32(i32((q2b >>   4u) | ((qhb & 0xC0u) >> 2u)) - 32);
+
+            sums[0] += yl0 * dq0;
+            sums[1] += yl1 * dq1;
+            sums[2] += yl2 * dq2;
+            sums[3] += yl3 * dq3;
+        }
+
+        local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
+                          sums[2] * f32(sc4) + sums[3] * f32(sc6));
+    }
+
+    return local_sum;
+}
+#endif
+
 struct MulMatParams {
     offset_src0: u32,
     offset_src1: u32,
@@ -191,4 +478,3 @@ fn main(
         dst[dst_idx / VEC_SIZE] = store_val(group_base);
     }
 }
-