]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml webgpu: quantized buffers to u32 + wider browser/device support (llama/21046)
authorReese Levine <redacted>
Wed, 1 Apr 2026 05:38:24 +0000 (22:38 -0700)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs

src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
src/ggml-webgpu/ggml-webgpu.cpp
src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl

index 97863f40412d65b914235257e0dfb8e23ac8fa0a..a194ce84e2556f217db5cd1aa483c79ee3e8fca0 100644 (file)
@@ -1219,9 +1219,8 @@ class ggml_webgpu_shader_lib {
 
                     defines.push_back("BYTE_HELPERS");
                     defines.push_back("MUL_ACC_" + type_upper);
-
-                    // For fast path we always dequantize from f16 inside the shader
-                    defines.push_back("SRC0_INNER_TYPE=f16");
+                    defines.push_back("U32_DEQUANT_HELPERS");
+                    defines.push_back("SRC0_INNER_TYPE=u32");
                     break;
                 }
         }
@@ -1334,9 +1333,8 @@ class ggml_webgpu_shader_lib {
                     defines.push_back("MUL_ACC_" + type_upper);
                     defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
                     defines.push_back("INIT_SRC1_SHMEM_FLOAT");
-
-                    // Use f16 inside the shader for quantized types
-                    defines.push_back("SRC0_INNER_TYPE=f16");
+                    defines.push_back("U32_DEQUANT_HELPERS");
+                    defines.push_back("SRC0_INNER_TYPE=u32");
 
                     variant += std::string("_") + src0_name;
                     break;
index fa3c492a7a5537c7376a85831bf355ffe5b20697..1aa15b0507cc41b2beb722c5360dbf22a23512f8 100644 (file)
@@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
 
 #define WEBGPU_NUM_PARAM_BUFS                96u
 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     32u
-#define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
+#define WEBGPU_WAIT_ANY_TIMEOUT_MS           100
 // Maximum number of in-flight submissions per-thread, to avoid exhausting the
 // parameter buffer pool
 #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
@@ -171,6 +171,7 @@ struct webgpu_buf_pool {
         // Try growing the pool if no free buffers
         if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
             cur_pool_size++;
+            lock.unlock();  // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
             wgpu::Buffer dev_buf;
             ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
 
@@ -507,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context &          ctx,
 
     bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
     while (blocking_wait) {
-        auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
+        auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
         if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
 #ifdef GGML_WEBGPU_GPU_PROFILE
             ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
@@ -728,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
         ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
     std::vector<webgpu_command>    commands = { command };
     std::vector<webgpu_submission> sub      = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
-    ggml_backend_webgpu_wait(ctx, sub);
 }
 
 /** End WebGPU Actions */
@@ -2694,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
         // memset the remaining bytes
         ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
                                           total_offset + (size - remaining_size), remaining_size);
-    } else {
-        // wait for WriteBuffer to complete
-        buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
-                                                  wgpu::CallbackMode::AllowSpontaneous,
-                                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-                                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
-                                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
-                                                                         std::string(message).c_str());
-                                                      }
-                                                  }),
-                                              UINT64_MAX);
     }
     WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
 }
index 9a5b18ebc07b3263d35ec3a5138b91ebf46929ef..feb0bca3f847760efed0be4a8bfcd843783d9bcd 100644 (file)
@@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
 }
 #endif
 
+#ifdef U32_DEQUANT_HELPERS
+fn load_src0_u16_at(byte_offset: u32) -> u32 {
+    let word = src0[byte_offset / 4u];
+    let shift = (byte_offset & 2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
+}
+
+fn load_src0_u32_at(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 3u) * 8u;
+    let lo = src0[word_idx];
+    if (shift == 0u) {
+        return lo;
+    }
+    let hi = src0[word_idx + 1u];
+    return (lo >> shift) | (hi << (32u - shift));
+}
+
+fn load_src0_f16_at(byte_offset: u32) -> f16 {
+    let packed = unpack2x16float(load_src0_u16_at(byte_offset));
+    return f16(packed[0]);
+}
+#endif
+
 #ifdef Q4_0_T
 struct q4_0 {
     d: f16,
index b6822161464c6a7d8d7e494288de48bb268d3252..8b76cecba917a6d83c92e1f67c2e4c416caf76ef 100644 (file)
@@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix;
 
 #ifdef KV_F32
 #define KV_TYPE f32
+#elif defined(KV_Q4_0) || defined(KV_Q8_0)
+#define KV_TYPE u32
 #else
 #define KV_TYPE f16
 #endif
@@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix;
 #define NQ 16
 // Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
 #define F16_PER_BLOCK 9
+#define BLOCK_SIZE_BYTES 18u
 #define WEIGHTS_PER_F16 4
 #elif defined(KV_Q8_0)
 #define NQ 8
 // Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
 #define F16_PER_BLOCK 17
+#define BLOCK_SIZE_BYTES 34u
 #define WEIGHTS_PER_F16 2
 #endif
 #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
@@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
     return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
 }
 
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+fn load_k_u16_at(byte_offset: u32) -> u32 {
+    let word = K[byte_offset / 4u];
+    let shift = (byte_offset & 2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
+}
+
+fn load_k_u32_at(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 3u) * 8u;
+    let lo = K[word_idx];
+    if (shift == 0u) {
+        return lo;
+    }
+    let hi = K[word_idx + 1u];
+    return (lo >> shift) | (hi << (32u - shift));
+}
+
+fn load_v_u16_at(byte_offset: u32) -> u32 {
+    let word = V[byte_offset / 4u];
+    let shift = (byte_offset & 2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
+}
+
+fn load_v_u32_at(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 3u) * 8u;
+    let lo = V[word_idx];
+    if (shift == 0u) {
+        return lo;
+    }
+    let hi = V[word_idx + 1u];
+    return (lo >> shift) | (hi << (32u - shift));
+}
+
+fn f16_from_u16(bits: u32) -> f16 {
+    let packed = unpack2x16float(bits);
+    return f16(packed[0]);
+}
+#endif
+
 struct Params {
     offset_q: u32,
     offset_k: u32,
@@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 
           if (global_k_row < params.seq_len_kv) {
               let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
-              let base_idx = global_block_idx * F16_PER_BLOCK;
-              let d = K[base_idx]; // scale
+              let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+              let d = f16_from_u16(load_k_u16_at(block_byte_base));
               for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-                  let q_0 = K[base_idx + 1u + block_offset + j];
-                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
-                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+                  let q_packed = load_k_u32_at(q_byte_offset);
                   for (var k = 0u; k < 4u; k++) {
                       let q_byte = get_byte(q_packed, k);
                       let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 
           if (global_k_row < params.seq_len_kv) {
               let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
-              let base_idx = global_block_idx * F16_PER_BLOCK;
-              let d = K[base_idx]; // scale
+              let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+              let d = f16_from_u16(load_k_u16_at(block_byte_base));
               for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-                  let q_0 = K[base_idx + 1u + block_offset + j];
-                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
-                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+                  let q_packed = load_k_u32_at(q_byte_offset);
                   for (var k = 0u; k < 4u; k++) {
                       let q_byte = get_byte_i32(q_packed, k);
                       let q_val = f16(q_byte) * d;
@@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 
           if (global_v_row < params.seq_len_kv) {
               let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
-              let base_idx = global_block_idx * F16_PER_BLOCK;
-              let d = V[base_idx]; // scale
+              let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+              let d = f16_from_u16(load_v_u16_at(block_byte_base));
               for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-                  let q_0 = V[base_idx + 1u + block_offset + j];
-                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
-                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+                  let q_packed = load_v_u32_at(q_byte_offset);
                   for (var k = 0u; k < 4u; k++) {
                       let q_byte = get_byte(q_packed, k);
                       let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 
           if (global_v_row < params.seq_len_kv) {
               let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
-              let base_idx = global_block_idx * F16_PER_BLOCK;
-              let d = V[base_idx]; // scale
+              let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+              let d = f16_from_u16(load_v_u16_at(block_byte_base));
               for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-                  let q_0 = V[base_idx + 1u + block_offset + j];
-                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
-                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+                  let q_packed = load_v_u32_at(q_byte_offset);
                   for (var k = 0u; k < 4u; k++) {
                       let q_byte = get_byte_i32(q_packed, k);
                       let q_val = f16(q_byte) * d;
index de60ebbcf2b5712330df77532b1af2767844b6a0..eb228537bad20bdd68f517651026e8faaf7e98a8 100644 (file)
@@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q4_0
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 18u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
             let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-            let scale_idx = src0_idx * F16_PER_BLOCK;
-            let d = src0[scale_idx];
+            let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+            let d = load_src0_f16_at(block_byte_base);
 
             for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-                let q_0 = src0[scale_idx + 1u + block_offset + j];
-                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
-
-                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+                let q_packed = load_src0_u32_at(q_byte_offset);
                 for (var k = 0u; k < 4u; k++) {
                     let q_byte = get_byte(q_packed, k);
                     let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q4_1
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 20u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
             let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-            let scale_idx = src0_idx * F16_PER_BLOCK;
-            let d = src0[scale_idx];
-            let m = src0[scale_idx + 1u];
+            let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+            let d = load_src0_f16_at(block_byte_base);
+            let m = load_src0_f16_at(block_byte_base + 2u);
 
             for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-                let q_0 = src0[scale_idx + 2u + block_offset + j];
-                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-
-                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+                let q_packed = load_src0_u32_at(q_byte_offset);
                 for (var k = 0u; k < 4u; k++) {
                     let q_byte = get_byte(q_packed, k);
                     let q_lo = f16(q_byte & 0xF) * d + m;
@@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 #ifdef INIT_SRC0_SHMEM_Q5_0
 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 22u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 // tile_k is defined as 32u, so blocks_k ends up being 1 always
 override BLOCKS_K = TILE_K / BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
 
@@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
             let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
-            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-            let d  = src0[scale_idx];
-            let qh0 = src0[scale_idx + 1u];
-            let qh1 = src0[scale_idx + 2u];
-            let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+            let d  = load_src0_f16_at(block_byte_base);
+            let qh_packed = load_src0_u32_at(block_byte_base + 2u);
 
             for (var j = 0u; j < 2; j++) {
-                let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
-                let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
-
-                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
+                let q_packed = load_src0_u32_at(q_byte_offset);
 
                 let j_adjusted = j + (block_offset / 2u);
 
@@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 #ifdef INIT_SRC0_SHMEM_Q5_1
 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 24u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 // tile_k is defined as 32u, so blocks_k ends up being 1 always
 override BLOCKS_K = TILE_K / BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
 
@@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
             let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
-            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-            let d  = src0[scale_idx];
-            let m = src0[scale_idx + 1u];
-            let qh0 = src0[scale_idx + 2u];
-            let qh1 = src0[scale_idx + 3u];
-            let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+            let d  = load_src0_f16_at(block_byte_base);
+            let m = load_src0_f16_at(block_byte_base + 2u);
+            let qh_packed = load_src0_u32_at(block_byte_base + 4u);
 
             for (var j = 0u; j < 2; j++) {
 
-                let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
-                let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
-
-                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
+                let q_packed = load_src0_u32_at(q_byte_offset);
 
                 let j_adjusted = j + (block_offset / 2u);
 
@@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q8_0
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 34u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
 const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
 
@@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
             let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-            let scale_idx = src0_idx * F16_PER_BLOCK;
-            let d = src0[scale_idx];
+            let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+            let d = load_src0_f16_at(block_byte_base);
 
             for (var j = 0u; j < F16_PER_THREAD; j+=2) {
-                let q_0 = src0[scale_idx + 1u + block_offset + j];
-                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
-
-                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+                let q_packed = load_src0_u32_at(q_byte_offset);
                 for (var k = 0u; k < 4u; k++) {
                     let q_byte = get_byte_i32(q_packed, k);
 
@@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q8_1
 const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 36u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
 const NQ = 16u;
-const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
 const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
 
@@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
             let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
-            let scale_idx = src0_idx * F16_PER_BLOCK;
-            let d = src0[scale_idx];
-            let m = src0[scale_idx + 1u];
+            let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+            let d = load_src0_f16_at(block_byte_base);
+            let m = load_src0_f16_at(block_byte_base + 2u);
 
             for (var j = 0u; j < F16_PER_THREAD; j+=2) {
-                let q_0 = src0[scale_idx + 2u + block_offset + j];
-                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-
-                let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+                let q_packed = load_src0_u32_at(q_byte_offset);
                 for (var k = 0u; k < 4u; k++) {
                     let q_byte = get_byte_i32(q_packed, k);
 
@@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q2_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 42u;
+const BLOCK_SIZE_BYTES = 84u;
 
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     // Use standard thread layout instead of lane/row_group
@@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;
 
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = src0[scale_idx + 40u];
-        let dmin = src0[scale_idx + 41u];
+        let d = load_src0_f16_at(block_byte_base + 80u);
+        let dmin = load_src0_f16_at(block_byte_base + 82u);
 
         // Decode the element at position k_in_block
         let block_of_32 = k_in_block / 32u;
@@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         let is = k_in_block / 16u;
 
-        let sc_0 = src0[scale_idx + 2u * (is / 4u)];
-        let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
-        let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
+        let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
         let sc = get_byte(sc_packed, is % 4u);
 
         let dl = d * f16(sc & 0xFu);
         let ml = dmin * f16(sc >> 4u);
 
         let q_idx = q_b_idx + k + l;
-        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
-        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
-        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+        let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
         let q_byte = get_byte(q_packed, q_idx % 4u);
         let qs_val = (q_byte >> shift) & 3u;
 
@@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q3_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 55u;
+const BLOCK_SIZE_BYTES = 110u;
 
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;
 
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = src0[scale_idx + 54u];
+        let d = load_src0_f16_at(block_byte_base + 108u);
 
         // Load and unpack scales
         let kmask1: u32 = 0x03030303u;
@@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         var scale_vals: array<u32, 4>;
         for (var i: u32 = 0u; i < 4u; i++) {
-            let scale_0 = src0[scale_idx + 48u + (2u*i)];
-            let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
-            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+            scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
         }
 
         var tmp: u32 = scale_vals[2];
@@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load hmask and qs arrays
         var hmask_vals: array<u32, 8>;
         for (var i: u32 = 0u; i < 8u; i++) {
-            let hmask_0 = src0[scale_idx + (2u*i)];
-            let hmask_1 = src0[scale_idx + (2u*i) + 1u];
-            hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
+            hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
         }
 
         var qs_vals: array<u32, 16>;
         for (var i: u32 = 0u; i < 16u; i++) {
-            let qs_0 = src0[scale_idx + 16u + (2u*i)];
-            let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
-            qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
+            qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
         }
 
         let half = k_in_block / 128u;           // 0 or 1
@@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q4_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 72u;
+const BLOCK_SIZE_BYTES = 144u;
 
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;
 
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = src0[scale_idx];
-        let dmin = src0[scale_idx + 1u];
+        let d = load_src0_f16_at(block_byte_base);
+        let dmin = load_src0_f16_at(block_byte_base + 2u);
 
         // Load packed scales
         var scale_vals: array<u32, 3>;
         for (var i: u32 = 0u; i < 3u; i++) {
-            let scale_0 = src0[scale_idx + 2u + (2u*i)];
-            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
-            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+            scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
         }
 
         // Map k_in_block to loop structure:
@@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let ml = dmin * f16(mn);
 
         let q_idx = q_b_idx + l;
-        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
-        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
-        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+        let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
 
         let q_byte = get_byte(q_packed, q_idx % 4u);
         let qs_val = (q_byte >> shift) & 0xFu;
@@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q5_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 88u;
+const BLOCK_SIZE_BYTES = 176u;
 
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;
 
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = src0[scale_idx];
-        let dmin = src0[scale_idx + 1u];
+        let d = load_src0_f16_at(block_byte_base);
+        let dmin = load_src0_f16_at(block_byte_base + 2u);
 
         // Load packed scales
         var scale_vals: array<u32, 3>;
         for (var i: u32 = 0u; i < 3u; i++) {
-            let scale_0 = src0[scale_idx + 2u + (2u*i)];
-            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
-            scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+            scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
         }
 
         // The original loop processes elements in groups of 64
@@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let ml = dmin * f16(mn);
 
         let q_idx = q_b_idx + l;
-        let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
-        let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
-        let q_packed = bitcast<u32>(vec2(q_0, q_1));
+        let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
 
         let q_byte = get_byte(q_packed, q_idx % 4u);
 
-        let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
-        let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
-        let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
+        let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
 
         let qh_byte = get_byte(qh_packed, l % 4u);
 
@@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
 #ifdef INIT_SRC0_SHMEM_Q6_K
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 105u;
+const BLOCK_SIZE_BYTES = 210u;
 
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
     for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let k_in_block = global_k % BLOCK_SIZE;
 
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
-        let scale_idx = src0_idx * F16_PER_BLOCK;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
         let half = k_in_block / 128u;
         let pos_in_half = k_in_block % 128u;
@@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         // Load only ql13 word needed
         let ql13_flat = ql_b_idx + l;
-        let ql13_word = ql13_flat / 4u;
-        let ql13 = bitcast<u32>(vec2(
-            src0[scale_idx + 2u * ql13_word],
-            src0[scale_idx + 2u * ql13_word + 1u]
-        ));
-        let ql13_b = get_byte(ql13, ql13_flat % 4u);
+        let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
+        let ql13_b = get_byte(ql13, 0u);
 
         // Load only ql24 word needed
         let ql24_flat = ql_b_idx + l + 32u;
-        let ql24_word = ql24_flat / 4u;
-        let ql24 = bitcast<u32>(vec2(
-            src0[scale_idx + 2u * ql24_word],
-            src0[scale_idx + 2u * ql24_word + 1u]
-        ));
-        let ql24_b = get_byte(ql24, ql24_flat % 4u);
+        let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
+        let ql24_b = get_byte(ql24, 0u);
 
         // Load only qh word needed
         let qh_flat = qh_b_idx + l;
-        let qh_word = qh_flat / 4u;
-        let qh = bitcast<u32>(vec2(
-            src0[scale_idx + 64u + 2u * qh_word],
-            src0[scale_idx + 64u + 2u * qh_word + 1u]
-        ));
-        let qh_b = get_byte(qh, qh_flat % 4u);
+        let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
+        let qh_b = get_byte(qh, 0u);
 
         let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
         let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
@@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load only the scale word needed
         let is = l / 16u;
         let sc_idx = sc_b_idx + is + quarter * 2u;
-        let sc_word = sc_idx / 4u;
-        let sc = bitcast<u32>(vec2(
-            src0[scale_idx + 96u + 2u * sc_word],
-            src0[scale_idx + 96u + 2u * sc_word + 1u]
-        ));
-        let sc_val = get_byte_i32(sc, sc_idx % 4u);
-
-        let d = src0[scale_idx + 104u];
+        let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
+        let sc_val = get_byte_i32(sc, 0u);
+
+        let d = load_src0_f16_at(block_byte_base + 208u);
 
         var q_val: f16;
         if (quarter == 0u) {
index 94f4bae11f4a3aa8e9b3538668fa85115e5a2749..6525f23bdfc48e7d95707324e74a1652f7264780 100644 (file)
@@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q4_0
 
 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 18u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
+        let d = f32(load_src0_f16_at(block_byte_base));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q4_1
 
 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 20u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 10u;
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let m = f32(src0[scale_idx + 1u]);
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let m = f32(load_src0_f16_at(block_byte_base + 2u));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 2u + block_offset + j];
-            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q5_0
 
 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 22u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 11u;
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let qh0 = src0[scale_idx + 1u];
-        let qh1 = src0[scale_idx + 2u];
-        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let qh_packed = load_src0_u32_at(block_byte_base + 2u);
 
         for (var j = 0u; j < 2; j++) {
-            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
-            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
+            let q_packed = load_src0_u32_at(q_byte_offset);
 
             let j_adjusted = j + (block_offset / 2u);
 
@@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q5_1
 
 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 24u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 12u;
 const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let m = src0[scale_idx + 1u];
-        let qh0 = src0[scale_idx + 2u];
-        let qh1 = src0[scale_idx + 3u];
-        let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let m = load_src0_f16_at(block_byte_base + 2u);
+        let qh_packed = load_src0_u32_at(block_byte_base + 4u);
 
         for (var j = 0u; j < 2; j++) {
-            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
-            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
+            let q_packed = load_src0_u32_at(q_byte_offset);
 
             let j_adjusted = j + (block_offset / 2u);
 
@@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q8_0
 
 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 34u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 17u;
 const WEIGHTS_PER_F16 = 2u;
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
+        let d = f32(load_src0_f16_at(block_byte_base));
 
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d;
@@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q8_1
 
 const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 36u;
 const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 18u;
 const WEIGHTS_PER_F16 = 2u;
 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
 
@@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
         let blck_idx = i / BLOCK_SIZE;
         let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        let m = src0[scale_idx + 1u];
+        let d = f32(load_src0_f16_at(block_byte_base));
+        let m = load_src0_f16_at(block_byte_base + 2u);
 
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 2u + block_offset + j];
-            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+            let q_packed = load_src0_u32_at(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d + f32(m);
@@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 #ifdef MUL_ACC_Q6_K
 
 const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 105u;
-
-fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
-    let aligned = byte_offset & ~3u;
-    let idx = bbase + aligned / 2u;
-    return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
-}
+const BLOCK_SIZE_BYTES = 210u;
 
 fn byte_of(v: u32, b: u32) -> u32 {
     return (v >> (b * 8u)) & 0xFFu;
@@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     var local_sum = 0.0;
 
     for (var i = ix; i < nb; i += 2u) {
-        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+        let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
 
-        let d_raw = load_u32_at(bbase, 208u);
-        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+        let d = f32(load_src0_f16_at(bbase + 208u));
 
-        let ql1_u32  = load_u32_at(bbase, q_offset_l);
-        let ql2_u32  = load_u32_at(bbase, q_offset_l + 32u);
-        let qh_u32   = load_u32_at(bbase, 128u + q_offset_h);
-        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
-        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+        let ql1_u32  = load_src0_u32_at(bbase + q_offset_l);
+        let ql2_u32  = load_src0_u32_at(bbase + q_offset_l + 32u);
+        let qh_u32   = load_src0_u32_at(bbase + 128u + q_offset_h);
+        let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
+        let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
 
         let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
         let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);