]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: enable coopmat2 FA gqa and split_k optimizations more often (#12931)
authorJeff Bolz <redacted>
Wed, 16 Apr 2025 18:37:25 +0000 (13:37 -0500)
committerGitHub <redacted>
Wed, 16 Apr 2025 18:37:25 +0000 (20:37 +0200)
The grouped query attention optmization doesn't require a power of two ratio,
the only thing relying on it was the modulo operation written as bitwise &.

split_k need not depend on gqa_ratio - enable it any time there's only one
workgroup in the X dimension. The shader gets the split index from the x coord,
and multiple workgroups in the X dimension (pre-split) indicates a larger
FA operation that wouldn't need splitting.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
tests/test-backend-ops.cpp

index 783a0ff86c1c1ad4f5f2ecea53c32e7c547f4f7a..0e9b2e8135a7a07b08e0cd1821d4a45223f5d89a 100644 (file)
@@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;
 
-    if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
+    if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
@@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
-    if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
-        GGML_ASSERT(workgroups_x == 1);
+    // Try to use split_k when KV is large enough to be worth the overhead
+    if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
         split_k = ctx->device->shader_core_count * 2 / workgroups_y;
         if (split_k > 1) {
index e1baa85f9e33050b86e7282178cf5b8e878d7ff7..b926a578aded6df20c6b40442cdd06c2648b7d71 100644 (file)
@@ -131,7 +131,7 @@ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in A
 // Load the slope matrix, indexed by Q's dimension 2.
 ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
 {
-    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
 
     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
     const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
index 3a5741c8d959d2d5f01dce0999e8cc418c5a8bd1..1ee742894695bdbf70e6b26107bab170dd9a935b 100644 (file)
@@ -4532,7 +4532,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
-            test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
         }
     }