]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix FA mask dim 2 and 3 (#14505)
authorGeorgi Gerganov <redacted>
Thu, 3 Jul 2025 07:46:57 +0000 (10:46 +0300)
committerGitHub <redacted>
Thu, 3 Jul 2025 07:46:57 +0000 (10:46 +0300)
* ggml : fix FA mask dim 2 and 3

ggml-ci

* backends : unsupport batched FA in CUDA and Vulkan

ggml-ci

* vulkan : disable FA for mask->ne[2] != 1

ggml/include/ggml.h
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp

index 78983bcc56ec21532c500f623b818773fcf07491..b2ad13f194d26f9bbd208869036c3ac340efc608 100644 (file)
@@ -1983,15 +1983,16 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,    ne3]
-    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
-    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, ne32,      1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
     //
     // broadcast:
     //   n_head % n_head_kv == 0
-    //   ne3    % ne32      == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
     //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
index 2f15341954214d321a4035339960d995446998a6..0fb2c08b5bcb5ffc15676a42ad6c95a354bdaca1 100644 (file)
@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
index 121e83641b23df227a731a070a45e2e72c90c005..1c04bba52e88bf2419a3d10854086147c88cf06c 100644 (file)
@@ -3390,7 +3390,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
index adb65c0e7b7defea78fcd19de0a544ae86c79788..752d55c216604a9788275a5d9e20def12be9e15b 100644 (file)
@@ -230,8 +230,10 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     int32_t  ne32;
+    int32_t  ne33;
     uint64_t nb31;
     uint64_t nb32;
+    uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
index de40430ef3565f4781e8661c2960b24fe3fa584e..5e5467e88a1ab20122b7bc0395e4a0f6a6b850e3 100644 (file)
@@ -5018,8 +5018,10 @@ static bool ggml_metal_encode_node(
                     /*.nb22          =*/ nb22,
                     /*.nb23          =*/ nb23,
                     /*.ne32          =*/ ne32,
+                    /*.ne33          =*/ ne33,
                     /*.nb31          =*/ nb31,
                     /*.nb32          =*/ nb32,
+                    /*.nb33          =*/ nb33,
                     /*.ne1           =*/ ne1,
                     /*.ne2           =*/ ne2,
                     /*.scale         =*/ scale,
index a7657e0fc3c36b976d9bff0c3209d026e426a426..ebde005f33e028013fa737abfdb99cfbe9d4256a 100644 (file)
@@ -3857,7 +3857,7 @@ kernel void kernel_flash_attn_ext(
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
                     const float m = pm[ic + tiisg];
 
@@ -4343,7 +4343,7 @@ kernel void kernel_flash_attn_ext_vec(
         const bool has_mask = mask != q;
 
         // pointer to the mask
-        device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
+        device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
         float slope = 1.0f;
 
index 25f70127a62c562d74eaa4687fc74e69456c3d4c..ccb0d47b6759b696a4bb902d194afc7a76d56dc7 100644 (file)
@@ -10265,6 +10265,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                     return false;
                 }
+                // TODO: support broadcast
+                // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+                //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+                if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+                    return false;
+                }
                 // It's straightforward to support different K/V dequant, but would
                 // significantly increase the number of pipelines
                 if (op->src[1]->type != op->src[2]->type) {
index 97da26b37849d4e3b59959858b7fb338e7ff565b..fdb57e178588498ffe20e3b6c9d062d03cc02aa3 100644 (file)
@@ -3674,7 +3674,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
         GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4704,12 +4703,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
-        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
index 5c3db854cb498bc843260aebe3e9fcd8a1c65f6f..2ab6dd06fc88e4c862d98af3422590a0ccf4e121 100644 (file)
@@ -3637,7 +3637,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
             ggml_set_name(m, "m");
         }
 
@@ -4751,7 +4751,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                 test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
 
                                 if (ne0 <= 32 && ne1 <= 32) {
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
                                     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
                                 }
                             }