#define PARAMS_BINDING 4
#endif
-@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
// Just a very small float value.
return v;
}
+// Fetches the whole vec4<f32> of `buf` that contains the scalar at `scalar_index`.
+// Four scalars share one vector, so the vector index is scalar_index / 4; the
+// position of the scalar inside that vector is not selected here.
+fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {
+    let vec_index: u32 = scalar_index / 4u;
+    return (*buf)[vec_index];
+}
+
+// Fetches the whole vec4<KV_TYPE> of `buf` that contains the scalar at `scalar_index`.
+// Four scalars share one vector, so the vector index is scalar_index / 4; the
+// position of the scalar inside that vector is not selected here.
+fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
+    let vec_index: u32 = scalar_index / 4u;
+    return (*buf)[vec_index];
+}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
- @builtin(local_invocation_id) local_id: vec3<u32>,
- @builtin(subgroup_id) subgroup_id: u32,
- @builtin(subgroup_size) subgroup_size: u32,
- @builtin(num_subgroups) num_subgroups: u32,
- @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+ @builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(subgroup_id) subgroup_id: u32,
+ @builtin(subgroup_size) subgroup_size: u32,
+ @builtin(num_subgroups) num_subgroups: u32,
+ @builtin(subgroup_invocation_id) sg_inv_id: u32) {
// initialize row max for online softmax
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
// clear inter_shmem to ensure zero-initialized accumulators
- for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
- inter_shmem[elem_idx] = 0.0;
- }
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+ inter_shmem[elem_idx] = 0.0;
+ }
// load k tile into shared memory
#if defined(KV_Q4_0)
// accumulate q block * k block into registers across the entire KV tile
// TODO: this loop seems to be the current largest bottleneck
- for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
- let inter_offset = kv_block * SG_MAT_N;
- var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
- subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
+ // this block exists only to scope the K offset variables declared below, limiting their lifetime to reduce register pressure
+ {
#ifdef KV_DIRECT
- let k_block_row = kv_tile + kv_block * SG_MAT_N;
- let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
+ let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
+ var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
#else
- let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
+ var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
#endif
- for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
- // load q submatrix from shared memory
- var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
- &q_shmem,
- head_dim_block,
- false,
- HEAD_DIM_QK
- );
+ for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
+ let inter_offset = kv_block * SG_MAT_N;
+ var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
+
+ var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
- // load k submatrix from device or shared memory
#ifdef KV_DIRECT
- var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
- &K,
- k_global_offset + head_dim_block,
- true,
- params.stride_k1
- );
+ var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
#else
- var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
- &kv_shmem,
- k_block_offset + head_dim_block,
- true,
- HEAD_DIM_QK
- );
+ var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
#endif
- acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
- }
- // store acc to shared memory for softmax (S matrix from paper)
- subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
+ var t: u32 = 1u;
+ for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
+ let h0 = t * SG_MAT_K;
+ var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+ var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
+#else
+ var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
+#endif
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+ q_cur = q0;
+ k_cur = k0;
+
+ let h1 = (t + 1u) * SG_MAT_K;
+ var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+ var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
+#else
+ var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
+#endif
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+ q_cur = q1g;
+ k_cur = k1g;
+ }
+
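+ // handle the single leftover block: the loop above loads two blocks per iteration and can stop with one block still unloaded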
+ if (t < HEAD_DIM_QK / SG_MAT_K) {
+ let h = t * SG_MAT_K;
+ var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+ var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
+#else
+ var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
+#endif
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+ q_cur = qn;
+ k_cur = kn;
+ }
+
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+
+#ifdef KV_DIRECT
+ k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
+#else
+ k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
+#endif
+ subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
+ }
}
+
#ifdef MASK
// load mask tile into shared memory for this KV block
// TODO: optimize and skip if mask is -INF for the entire tile
false,
HEAD_DIM_V
);
-
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
let p_offset = kv_block * SG_MAT_N;
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
// O += P * V
o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
}
-
// store O back to shared memory
subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
}
-
workgroupBarrier();
}
o_shmem[idx] = f16(val);
}
}
-
workgroupBarrier();
#endif
-
- // write output back to global memory
for (var q_tile_row = subgroup_id;
- q_tile_row < Q_TILE;
- q_tile_row += num_subgroups) {
- let global_q_row = q_row_start + q_tile_row;
- if (global_q_row >= params.seq_len_q) {
- break;
- }
+ q_tile_row < Q_TILE;
+ q_tile_row += num_subgroups) {
- let exp_sum = exp_sum_shmem[q_tile_row];
- let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) { break; }
- for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
- let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
- let scaled = f32(o_val) * scale;
- dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
- }
+ let exp_sum = exp_sum_shmem[q_tile_row];
+ let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+
+ let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
+
+ for (var elem_base = sg_inv_id * 4u;
+ elem_base < HEAD_DIM_V;
+ elem_base += subgroup_size * 4u) {
+
+ let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+ let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+ let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+ let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+ let v = vec4<f32>(
+ f32(o_shmem[i0]) * scale,
+ f32(o_shmem[i1]) * scale,
+ f32(o_shmem[i2]) * scale,
+ f32(o_shmem[i3]) * scale
+ );
+
+ let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+ dst[dst_vec_index] = v;
+ }
}
}