]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-webgpu: improve flastAttention performance by software pipelining (llama/19151)
authorZheyuan Chen <redacted>
Thu, 29 Jan 2026 22:05:30 +0000 (14:05 -0800)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
* webgpu : pipeline flash_attn Q/K loads in WGSL

* ggml-webgpu: unroll Q*K accumlation inner loop

* ggml-webgpu: vectorization

* ggml-webgpu: unrolling

* ggml-webgpu: remove redundant unrolling

* ggml-webgpu: restore the config

* ggml-webgpu: remove redundant comments

* ggml-webgpu: formatting

* ggml-webgpu: formatting and remove vectorization

* ggml-webgpu: remove unnecessary constants

* ggml-webgpu: change QKV buffer to read_write to pass validation

* ggml-webgpu: add explanation for the additional bracket around Q K accumulate

* Indentation and for -> if for tail

* Kick off CI on wgsl only commits

---------

Co-authored-by: Reese Levine <redacted>
src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl

index de7c132a624e71b10899c31743e0ed1c7374825b..b6822161464c6a7d8d7e494288de48bb268d3252 100644 (file)
@@ -114,7 +114,7 @@ struct Params {
 #define PARAMS_BINDING 4
 #endif
 
-@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
 @group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
 
 // Just a very small float value.
@@ -160,14 +160,21 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
     return v;
 }
 
+fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {
+    return (*buf)[scalar_index >> 2u];
+}
+
+fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
+    return (*buf)[scalar_index >> 2u];
+}
 
 @compute @workgroup_size(WG_SIZE)
 fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
-        @builtin(local_invocation_id) local_id: vec3<u32>,
-        @builtin(subgroup_id) subgroup_id: u32,
-        @builtin(subgroup_size) subgroup_size: u32,
-        @builtin(num_subgroups) num_subgroups: u32,
-        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_size) subgroup_size: u32,
+    @builtin(num_subgroups) num_subgroups: u32,
+    @builtin(subgroup_invocation_id) sg_inv_id: u32) {
 
     // initialize row max for online softmax
     for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
@@ -231,9 +238,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 
     for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
       // clear inter_shmem to ensure zero-initialized accumulators
-      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
-          inter_shmem[elem_idx] = 0.0;
-      }
+        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+            inter_shmem[elem_idx] = 0.0;
+        }
 
       // load k tile into shared memory
 #if defined(KV_Q4_0)
@@ -309,48 +316,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 
       // accumulate q block * k block into registers across the entire KV tile
       // TODO: this loop seems to be the current largest bottleneck
-      for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
-          let inter_offset = kv_block * SG_MAT_N;
-          var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
-              subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
+      // this bracket exists to scope the lifetime of variables, reducing register pressure
+      {
 #ifdef KV_DIRECT
-          let k_block_row = kv_tile + kv_block * SG_MAT_N;
-          let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
+          let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
+          var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
 #else
-          let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
+          var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
 #endif
-          for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
-              // load q submatrix from shared memory
-              var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
-                  &q_shmem,
-                  head_dim_block,
-                  false,
-                  HEAD_DIM_QK
-              );
+          for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
+              let inter_offset = kv_block * SG_MAT_N;
+              var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
+
+              var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
 
-              // load k submatrix from device or shared memory
 #ifdef KV_DIRECT
-              var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
-                  &K,
-                  k_global_offset + head_dim_block,
-                  true,
-                  params.stride_k1
-              );
+              var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
 #else
-              var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
-                  &kv_shmem,
-                  k_block_offset + head_dim_block,
-                  true,
-                  HEAD_DIM_QK
-              );
+              var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
 #endif
-              acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
-          }
 
-          // store acc to shared memory for softmax (S matrix from paper)
-          subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
+              var t: u32 = 1u;
+              for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
+                  let h0 = t * SG_MAT_K;
+                  var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+                  var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
+#else
+                  var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
+#endif
+                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+                  q_cur = q0;
+                  k_cur = k0;
+
+                  let h1 = (t + 1u) * SG_MAT_K;
+                  var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+                  var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
+#else
+                  var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
+#endif
+                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+                  q_cur = q1g;
+                  k_cur = k1g;
+              }
+
+              // handle odd tail
+              if (t < HEAD_DIM_QK / SG_MAT_K) {
+                  let h = t * SG_MAT_K;
+                  var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+                  var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
+#else
+                  var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
+#endif
+                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+                  q_cur = qn;
+                  k_cur = kn;
+              }
+
+              acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+
+#ifdef KV_DIRECT
+              k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
+#else
+              k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
+#endif
+              subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
+          }
       }
 
+
 #ifdef MASK
       // load mask tile into shared memory for this KV block
       // TODO: optimize and skip if mask is -INF for the entire tile
@@ -495,7 +531,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
                   false,
                   HEAD_DIM_V
               );
-
               for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
                   let p_offset = kv_block * SG_MAT_N;
                   var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
@@ -527,11 +562,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
                   // O += P * V
                   o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
               }
-
               // store O back to shared memory
               subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
       }
-
       workgroupBarrier();
     }
 
@@ -566,26 +599,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
                 o_shmem[idx] = f16(val);
             }
     }
-
     workgroupBarrier();
 #endif
-
-    // write output back to global memory
     for (var q_tile_row = subgroup_id;
-         q_tile_row < Q_TILE;
-         q_tile_row += num_subgroups) {
-            let global_q_row = q_row_start + q_tile_row;
-            if (global_q_row >= params.seq_len_q) {
-                break;
-            }
+        q_tile_row < Q_TILE;
+        q_tile_row += num_subgroups) {
 
-            let exp_sum = exp_sum_shmem[q_tile_row];
-            let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
+        let global_q_row = q_row_start + q_tile_row;
+        if (global_q_row >= params.seq_len_q) { break; }
 
-            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
-                let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
-                let scaled = f32(o_val) * scale;
-                dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
-            }
+        let exp_sum = exp_sum_shmem[q_tile_row];
+        let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+
+        let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
+
+        for (var elem_base = sg_inv_id * 4u;
+            elem_base < HEAD_DIM_V;
+            elem_base += subgroup_size * 4u) {
+
+            let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+            let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+            let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+            let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+            let v = vec4<f32>(
+                f32(o_shmem[i0]) * scale,
+                f32(o_shmem[i1]) * scale,
+                f32(o_shmem[i2]) * scale,
+                f32(o_shmem[i3]) * scale
+            );
+
+            let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+            dst[dst_vec_index] = v;
+        }
     }
 }