]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : make the FA extra sizes consistent (llama/17143)
authorGeorgi Gerganov <redacted>
Fri, 14 Nov 2025 07:13:34 +0000 (09:13 +0200)
committerGeorgi Gerganov <redacted>
Mon, 17 Nov 2025 13:34:43 +0000 (15:34 +0200)
src/ggml-metal/ggml-metal-ops.cpp

index d9811e31159b1061a9024ff2fab772e62f282246..c48f7cd29fe16746a8a7f49c0ae8ff6d392356ba 100644 (file)
@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     const bool has_mask = op->src[3] != nullptr;
 
     if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
-        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+        // note: always reserve the padding space to avoid graph reallocations
+        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+        const bool has_kvpad = true;
 
         if (has_kvpad) {
             res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
                 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
         }
     } else {
-        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+        const bool has_kvpad = true;
 
         if (has_kvpad) {
             res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
     const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
 
     // this optimization is not useful for the vector kernels
-    if (is_vec) {
-        return res;
-    }
+    // note: always reserve the blk buffer to avoid graph reallocations
+    //if (is_vec) {
+    //    return res;
+    //}
 
     const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
     const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
 
     size_t res = 0;
 
-    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    // note: always reserve the temp buffer to avoid graph reallocations
+    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    if (true) {
         const int64_t nwg = 32;
+        const int64_t ne01_max = std::min(ne01, 32);
 
         // temp buffer for writing the results from each workgroup
         // - ne20: the size of the Value head
         // -  + 2: the S and M values for each intermediate result
-        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+        res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
     }
 
     return res;