]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: optimize flash attention split_k_reduce (llama/14554)
authorJeff Bolz <redacted>
Tue, 8 Jul 2025 18:11:42 +0000 (13:11 -0500)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* vulkan: allow FA split_k with smaller KV values

* vulkan: spread split_k_reduce work across more threads

k_num can get rather large. Use the whole workgroup to reduce the M/L values.

Launch a thread for each element in the HSV dimension of the output. Helps a
lot for large HSV (like deepseek).

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

index 7b8faf17879d62a0f1a9a0f315e3a10020f5f216..2245a655498c5f147d6a2835f0d59ab4d72a0278 100644 (file)
@@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl_f32",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
 
     // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
+    if (workgroups_x == 1 && shader_core_count > 0) {
         // Try to run two workgroups per SM.
         split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
             split_k = CEIL_DIV(KV, split_kv);
             workgroups_x = split_k;
         }
@@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                         vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                                         vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                     },
-                                    pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
+                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                     {
index 599cef072e931e3a503a3512901f484357ebad1b..0a17a9df23f9fbd09baee3f85c1f836062ed81ef 100644 (file)
@@ -2,9 +2,9 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 
-#define BLOCK_SIZE 32
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
     uint k_num;
 } p;
 
+shared float tmpsh[BLOCK_SIZE];
+
 void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
@@ -32,23 +34,51 @@ void main() {
 
     // Compute the max m value for the row
     float m_max = -1.0/0.0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         m_max = max(m_max, m);
     }
 
+    // reduce across the workgroup
+    tmpsh[tid] = m_max;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            m_max = max(m_max, tmpsh[tid + s]);
+            tmpsh[tid] = m_max;
+        }
+        barrier();
+    }
+    m_max = tmpsh[0];
+
+    barrier();
+
     // Compute L based on m_max
     float L = 0;
-    [[unroll]] for (uint k = 0; k < k_num; ++k) {
-        float l = data_a[l_offset + k * lm_stride];
-        float m = data_a[m_offset + k * lm_stride];
+    for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
+        float l = data_a[l_offset + (k + tid) * lm_stride];
+        float m = data_a[m_offset + (k + tid) * lm_stride];
         L += exp(m - m_max) * l;
     }
 
+    // reduce across the workgroup
+    tmpsh[tid] = L;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            L += tmpsh[tid + s];
+            tmpsh[tid] = L;
+        }
+        barrier();
+    }
+    L = tmpsh[0];
+
     L = 1.0 / L;
 
+    // D dimension is split across workgroups in the y dimension
+    uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
     // Scale and sum the O contributions based on m_max and store the result to memory
-    for (uint d = tid; d < D; d += BLOCK_SIZE) {
+    if (d < D) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
             uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;