defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
-
- // For fast path we always dequantize from f16 inside the shader
- defines.push_back("SRC0_INNER_TYPE=f16");
+ defines.push_back("U32_DEQUANT_HELPERS");
+ defines.push_back("SRC0_INNER_TYPE=u32");
break;
}
}
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
-
- // Use f16 inside the shader for quantized types
- defines.push_back("SRC0_INNER_TYPE=f16");
+ defines.push_back("U32_DEQUANT_HELPERS");
+ defines.push_back("SRC0_INNER_TYPE=u32");
variant += std::string("_") + src0_name;
break;
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
-#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
+#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100
// Maximum number of in-flight submissions per thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
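+// With the defaults above this works out to 96 / 32 = 3 in-flight submissions.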
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
+ lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
while (blocking_wait) {
- auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
+ auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
- ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
total_offset + (size - remaining_size), remaining_size);
- } else {
- // wait for WriteBuffer to complete
- buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
- wgpu::CallbackMode::AllowSpontaneous,
- [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
- if (status != wgpu::QueueWorkDoneStatus::Success) {
- GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
- std::string(message).c_str());
- }
- }),
- UINT64_MAX);
}
WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
}
}
#endif
+#ifdef U32_DEQUANT_HELPERS
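+// Byte-addressed loads from src0, which is stored as u32 words.
+// load_src0_u16_at assumes a 2-byte-aligned offset; load_src0_u32_at also handles
+// unaligned offsets by combining two adjacent words, e.g. byte_offset = 6 gives
+// word_idx = 1, shift = 16, result = (src0[1] >> 16) | (src0[2] << 16).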
+fn load_src0_u16_at(byte_offset: u32) -> u32 {
+ let word = src0[byte_offset / 4u];
+ let shift = (byte_offset & 2u) * 8u;
+ return (word >> shift) & 0xFFFFu;
+}
+
+fn load_src0_u32_at(byte_offset: u32) -> u32 {
+ let word_idx = byte_offset / 4u;
+ let shift = (byte_offset & 3u) * 8u;
+ let lo = src0[word_idx];
+ if (shift == 0u) {
+ return lo;
+ }
+ let hi = src0[word_idx + 1u];
+ return (lo >> shift) | (hi << (32u - shift));
+}
+
+fn load_src0_f16_at(byte_offset: u32) -> f16 {
+ let packed = unpack2x16float(load_src0_u16_at(byte_offset));
+ return f16(packed[0]);
+}
+#endif
+
#ifdef Q4_0_T
struct q4_0 {
d: f16,
#ifdef KV_F32
#define KV_TYPE f32
+#elif defined(KV_Q4_0) || defined(KV_Q8_0)
+#define KV_TYPE u32
#else
#define KV_TYPE f16
#endif
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
+#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
+#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
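+// In bytes: a Q4_0 block is d (f16) + 16 bytes of packed 4-bit weights = 18,
+// and a Q8_0 block is d (f16) + 32 int8 weights = 34.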
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
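+// Byte-addressed loads from the quantized K and V buffers; the u16 variants assume
+// 2-byte-aligned offsets, and the u32 variants handle unaligned offsets by combining
+// two adjacent words.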
+fn load_k_u16_at(byte_offset: u32) -> u32 {
+ let word = K[byte_offset / 4u];
+ let shift = (byte_offset & 2u) * 8u;
+ return (word >> shift) & 0xFFFFu;
+}
+
+fn load_k_u32_at(byte_offset: u32) -> u32 {
+ let word_idx = byte_offset / 4u;
+ let shift = (byte_offset & 3u) * 8u;
+ let lo = K[word_idx];
+ if (shift == 0u) {
+ return lo;
+ }
+ let hi = K[word_idx + 1u];
+ return (lo >> shift) | (hi << (32u - shift));
+}
+
+fn load_v_u16_at(byte_offset: u32) -> u32 {
+ let word = V[byte_offset / 4u];
+ let shift = (byte_offset & 2u) * 8u;
+ return (word >> shift) & 0xFFFFu;
+}
+
+fn load_v_u32_at(byte_offset: u32) -> u32 {
+ let word_idx = byte_offset / 4u;
+ let shift = (byte_offset & 3u) * 8u;
+ let lo = V[word_idx];
+ if (shift == 0u) {
+ return lo;
+ }
+ let hi = V[word_idx + 1u];
+ return (lo >> shift) | (hi << (32u - shift));
+}
+
+fn f16_from_u16(bits: u32) -> f16 {
+ let packed = unpack2x16float(bits);
+ return f16(packed[0]);
+}
+#endif
+
struct Params {
offset_q: u32,
offset_k: u32,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
- let base_idx = global_block_idx * F16_PER_BLOCK;
- let d = K[base_idx]; // scale
+ let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+ let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = K[base_idx + 1u + block_offset + j];
- let q_1 = K[base_idx + 1u + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
- let base_idx = global_block_idx * F16_PER_BLOCK;
- let d = K[base_idx]; // scale
+ let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+ let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = K[base_idx + 1u + block_offset + j];
- let q_1 = K[base_idx + 1u + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
- let base_idx = global_block_idx * F16_PER_BLOCK;
- let d = V[base_idx]; // scale
+ let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+ let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = V[base_idx + 1u + block_offset + j];
- let q_1 = V[base_idx + 1u + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
- let base_idx = global_block_idx * F16_PER_BLOCK;
- let d = V[base_idx]; // scale
+ let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
+ let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = V[base_idx + 1u + block_offset + j];
- let q_1 = V[base_idx + 1u + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 18u;
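+// Q4_0 block bytes: d (f16) at 0, 16 bytes of packed 4-bit weights at 2.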
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
- let d = src0[scale_idx];
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+ let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = src0[scale_idx + 1u + block_offset + j];
- let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
-
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 20u;
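+// Q4_1 block bytes: d (f16) at 0, m (f16) at 2, 16 bytes of packed 4-bit weights at 4.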
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
-const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
- let d = src0[scale_idx];
- let m = src0[scale_idx + 1u];
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+ let d = load_src0_f16_at(block_byte_base);
+ let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = src0[scale_idx + 2u + block_offset + j];
- let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block at 4 bits each = 128 bits = 8 f16s per block
const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 22u;
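+// Q5_0 block bytes: d (f16) at 0, 4 bytes of high bits (qh) at 2, 16 bytes of packed low 4-bit weights at 6.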
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// TILE_K is defined as 32u, so BLOCKS_K is always 1
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
-const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread; each thread handles 4 f16s * 4 weights each = 16 weights
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
- let d = src0[scale_idx];
- let qh0 = src0[scale_idx + 1u];
- let qh1 = src0[scale_idx + 2u];
- let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+ let d = load_src0_f16_at(block_byte_base);
+ let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
- let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
- let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
-
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
+ let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block at 4 bits each = 128 bits = 8 f16s per block
const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 24u;
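+// Q5_1 block bytes: d (f16) at 0, m (f16) at 2, 4 bytes of high bits (qh) at 4, 16 bytes of packed low 4-bit weights at 8.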
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// TILE_K is defined as 32u, so BLOCKS_K is always 1
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
-const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread; each thread handles 4 f16s * 4 weights each = 16 weights
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
- let d = src0[scale_idx];
- let m = src0[scale_idx + 1u];
- let qh0 = src0[scale_idx + 2u];
- let qh1 = src0[scale_idx + 3u];
- let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+ let d = load_src0_f16_at(block_byte_base);
+ let m = load_src0_f16_at(block_byte_base + 2u);
+ let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
- let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
- let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
-
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
+ let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 34u;
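+// Q8_0 block bytes: d (f16) at 0, 32 int8 weights at 2.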
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
-const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
- let d = src0[scale_idx];
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+ let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
- let q_0 = src0[scale_idx + 1u + block_offset + j];
- let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
-
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 36u;
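+// Q8_1 block bytes: d (f16) at 0, m (f16) at 2, 32 int8 weights at 4.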
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
-const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
- let d = src0[scale_idx];
- let m = src0[scale_idx + 1u];
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+ let d = load_src0_f16_at(block_byte_base);
+ let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
- let q_0 = src0[scale_idx + 2u + block_offset + j];
- let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
-
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 42u;
+const BLOCK_SIZE_BYTES = 84u;
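+// Q2_K block bytes: 16 scale bytes at 0, 64 bytes of packed 2-bit weights at 16, d (f16) at 80, dmin (f16) at 82.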
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
- let d = src0[scale_idx + 40u];
- let dmin = src0[scale_idx + 41u];
+ let d = load_src0_f16_at(block_byte_base + 80u);
+ let dmin = load_src0_f16_at(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let is = k_in_block / 16u;
- let sc_0 = src0[scale_idx + 2u * (is / 4u)];
- let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
- let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
+ let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
- let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
- let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 55u;
+const BLOCK_SIZE_BYTES = 110u;
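+// Q3_K block bytes: 32 hmask bytes at 0, 64 bytes of packed low 2-bit weights at 32, 12 scale bytes at 96, d (f16) at 108.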
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
- let d = src0[scale_idx + 54u];
+ let d = load_src0_f16_at(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
- let scale_0 = src0[scale_idx + 48u + (2u*i)];
- let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
- scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+ scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
- let hmask_0 = src0[scale_idx + (2u*i)];
- let hmask_1 = src0[scale_idx + (2u*i) + 1u];
- hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
+ hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
- let qs_0 = src0[scale_idx + 16u + (2u*i)];
- let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
- qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
+ qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 72u;
+const BLOCK_SIZE_BYTES = 144u;
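+// Q4_K block bytes: d (f16) at 0, dmin (f16) at 2, 12 scale bytes at 4, 128 bytes of packed 4-bit weights at 16.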
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
- let d = src0[scale_idx];
- let dmin = src0[scale_idx + 1u];
+ let d = load_src0_f16_at(block_byte_base);
+ let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
- let scale_0 = src0[scale_idx + 2u + (2u*i)];
- let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
- scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+ scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// Map k_in_block to loop structure:
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
- let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
- let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 88u;
+const BLOCK_SIZE_BYTES = 176u;
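+// Q5_K block bytes: d (f16) at 0, dmin (f16) at 2, 12 scale bytes at 4, 32 bytes of high bits (qh) at 16, 128 bytes of packed low 4-bit weights at 48.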
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
- let d = src0[scale_idx];
- let dmin = src0[scale_idx + 1u];
+ let d = load_src0_f16_at(block_byte_base);
+ let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
- let scale_0 = src0[scale_idx + 2u + (2u*i)];
- let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
- scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
+ scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// The original loop processes elements in groups of 64
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
- let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
- let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
- let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
- let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
- let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
+ let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 105u;
+const BLOCK_SIZE_BYTES = 210u;
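+// Q6_K block bytes: 128 bytes of low 4-bit weights (ql) at 0, 64 bytes of high 2-bit weights (qh) at 128, 16 scale bytes at 192, d (f16) at 208.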
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
- let scale_idx = src0_idx * F16_PER_BLOCK;
+ let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
- let ql13_word = ql13_flat / 4u;
- let ql13 = bitcast<u32>(vec2(
- src0[scale_idx + 2u * ql13_word],
- src0[scale_idx + 2u * ql13_word + 1u]
- ));
- let ql13_b = get_byte(ql13, ql13_flat % 4u);
+ let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
+ let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
- let ql24_word = ql24_flat / 4u;
- let ql24 = bitcast<u32>(vec2(
- src0[scale_idx + 2u * ql24_word],
- src0[scale_idx + 2u * ql24_word + 1u]
- ));
- let ql24_b = get_byte(ql24, ql24_flat % 4u);
+ let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
+ let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
- let qh_word = qh_flat / 4u;
- let qh = bitcast<u32>(vec2(
- src0[scale_idx + 64u + 2u * qh_word],
- src0[scale_idx + 64u + 2u * qh_word + 1u]
- ));
- let qh_b = get_byte(qh, qh_flat % 4u);
+ let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
+ let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
- let sc_word = sc_idx / 4u;
- let sc = bitcast<u32>(vec2(
- src0[scale_idx + 96u + 2u * sc_word],
- src0[scale_idx + 96u + 2u * sc_word + 1u]
- ));
- let sc_val = get_byte_i32(sc, sc_idx % 4u);
-
- let d = src0[scale_idx + 104u];
+ let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
+ let sc_val = get_byte_i32(sc, 0u);
+
+ let d = load_src0_f16_at(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {
#ifdef MUL_ACC_Q4_0
const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 18u;
const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
- let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
- let d = f32(src0[scale_idx]);
+ let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = src0[scale_idx + 1 + block_offset + j];
- let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 20u;
const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
- let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
- let d = f32(src0[scale_idx]);
- let m = f32(src0[scale_idx + 1u]);
+ let d = f32(load_src0_f16_at(block_byte_base));
+ let m = f32(load_src0_f16_at(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = src0[scale_idx + 2u + block_offset + j];
- let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 22u;
const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
- let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
- let d = f32(src0[scale_idx]);
- let qh0 = src0[scale_idx + 1u];
- let qh1 = src0[scale_idx + 2u];
- let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+ let d = f32(load_src0_f16_at(block_byte_base));
+ let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
- let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
- let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
+ let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 24u;
const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
- let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
- let d = f32(src0[scale_idx]);
- let m = src0[scale_idx + 1u];
- let qh0 = src0[scale_idx + 2u];
- let qh1 = src0[scale_idx + 3u];
- let qh_packed = bitcast<u32>(vec2(qh0, qh1));
+ let d = f32(load_src0_f16_at(block_byte_base));
+ let m = load_src0_f16_at(block_byte_base + 2u);
+ let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
- let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
- let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
+ let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 34u;
const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
- let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
- let d = f32(src0[scale_idx]);
+ let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = src0[scale_idx + 1 + block_offset + j];
- let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
+const BLOCK_SIZE_BYTES = 36u;
const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
- let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+ let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
- let d = f32(src0[scale_idx]);
- let m = src0[scale_idx + 1u];
+ let d = f32(load_src0_f16_at(block_byte_base));
+ let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
- let q_0 = src0[scale_idx + 2u + block_offset + j];
- let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
- let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
+ let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
-const F16_PER_BLOCK = 105u;
-
-fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
- let aligned = byte_offset & ~3u;
- let idx = bbase + aligned / 2u;
- return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
-}
+const BLOCK_SIZE_BYTES = 210u;
fn byte_of(v: u32, b: u32) -> u32 {
return (v >> (b * 8u)) & 0xFFu;
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
- let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+ let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
- let d_raw = load_u32_at(bbase, 208u);
- let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+ let d = f32(load_src0_f16_at(bbase + 208u));
- let ql1_u32 = load_u32_at(bbase, q_offset_l);
- let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
- let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
- let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
- let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+ let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
+ let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
+ let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
+ let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
+ let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);