]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Handle updated FA dim2/3 definition (llama/14518)
authorJeff Bolz <redacted>
Sat, 5 Jul 2025 07:26:04 +0000 (02:26 -0500)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* vulkan: Handle updated FA dim2/3 definition

Pack mask boolean and n_head_log2 into a single dword to keep the push
constant block under the 128B limit.

* handle null mask for gqa

* allow gqa with dim3>1

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 22a34a433568f6a2e6a69612a5e553f224f8f28e..e8df00d4183acfa132dcf8f5135247050c5cff5d 100644 (file)
@@ -636,6 +636,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;
     uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -651,8 +652,7 @@ struct vk_flash_attn_push_constants {
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;
 
@@ -6111,6 +6111,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
+    const uint32_t nem3 = mask ? mask->ne[3] : 0;
 
     const uint32_t HSK = nek0;
     const uint32_t HSV = nev0;
@@ -6178,7 +6179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
-        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
@@ -6348,17 +6349,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
+    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
                                               (uint32_t)neq2, (uint32_t)neq3,
                                               (uint32_t)nek2, (uint32_t)nek3,
                                               (uint32_t)nev2, (uint32_t)nev3,
-                                              nem1, nem2,
+                                              nem1, nem2, nem3,
                                               q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
                                               k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                                               scale, max_bias, logit_softcap,
-                                              mask != nullptr, n_head_log2, m0, m1,
+                                              mask_n_head_log2, m0, m1,
                                               gqa_ratio, split_kv, split_k };
 
     ggml_vk_sync_buffers(subctx);
@@ -10303,12 +10306,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                     return false;
                 }
-                // TODO: support broadcast
-                // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
-                //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-                if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
-                    return false;
-                }
                 // It's straightforward to support different K/V dequant, but would
                 // significantly increase the number of pipelines
                 if (op->src[1]->type != op->src[2]->type) {
index 788a5e065d1dba662cc2d7e2326803ffd720e487..45c6e7736ace687e549d1fbd115e61bc54a4bb1b 100644 (file)
@@ -101,8 +101,8 @@ void main() {
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
     uint32_t m_offset = 0;
-    if (p.nem2 != 1) {
-        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    if (p.nem2 != 1 || p.nem3 != 1) {
+        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
     }
 
     [[dont_unroll]]
@@ -149,7 +149,7 @@ void main() {
             }
         }
 
-        if (p.mask != 0) {
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
index 6609f0bad3d85099c8ec9e3ae14167ea72ecfa62..7defe72b403b5eb2dad19fa6dd127fb7c3de963f 100644 (file)
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
     uint32_t nev3;
     uint32_t nem1;
     uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;
 
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
     uint32_t k_num;
 } p;
 
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
+
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 
 #if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 {
     const uint32_t h = iq2 + (r % p.gqa_ratio);
 
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+    uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+
+    const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+    const int      exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
 
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
index e74e2fa9346749660eed71e6412d2948b49b0345..486735fe8b0c97fbe9b26882a0b2ef8fa6c32c29 100644 (file)
@@ -126,8 +126,8 @@ void main() {
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
     uint32_t m_offset = 0;
-    if (p.nem2 != 1) {
-        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    if (p.nem2 != 1 || p.nem3 != 1) {
+        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
     }
 
     [[dont_unroll]]
@@ -182,7 +182,7 @@ void main() {
             barrier();
         }
 
-        if (p.mask != 0) {
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
index 8792d5195e45731ff2629d5ddd173066b49051ff..274f48fcabdd081abc14fa4c243cd60bcb7c44c9 100644 (file)
@@ -131,8 +131,8 @@ void main() {
     }
 
     uint32_t m_offset = 0;
-    if (p.nem2 != 1) {
-        m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+    if (p.nem2 != 1 || p.nem3 != 1) {
+        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
     }
 
     [[dont_unroll]]
@@ -153,7 +153,7 @@ void main() {
             }
         }
 
-        if (p.mask != 0) {
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);