]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add support for non-padded FA KV (llama/16148)
authorGeorgi Gerganov <redacted>
Tue, 7 Oct 2025 05:23:30 +0000 (08:23 +0300)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
* metal : pad K, V and Mask when needed

* cont : simplify

* cuda : add TODO about KV padding requirement

* metal : add comments

* metal : remove mask padding requirement

ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.cpp
ggml/src/ggml-metal/ggml-metal.metal

index d7736d36108a7b5e9a340bb7fbcd9f0a6630cd10..0c8e7b3e41904e3e89ee6fbf66b9f1088265d86d 100644 (file)
@@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     const int cc = ggml_cuda_info().devices[device].cc;
 
+    // TODO: temporary until support is extended
+    //       https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
+    if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+
     switch (K->ne[0]) {
         case  64:
         case 128:
index d9e92044275ae07ec84371f4f685783973b38f20..46cc51345969bf4180943939d1eee7af51b8a55b 100644 (file)
@@ -924,6 +924,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        bool    has_mask,
+        int32_t ncpsg) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_pad");
+
+    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
+            base,
+            has_mask,
+            ncpsg);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_PAD + 0);
+  //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
+  //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
+  //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
+
+  //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
+  //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
+  //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
+  //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
+    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const ggml_tensor * op,
@@ -931,6 +975,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool    has_sinks,
         bool    has_bias,
         bool    has_scap,
+        bool    has_kvpad,
         int32_t nsg) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
@@ -943,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
     const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
 
+    // do bounds checks for the mask?
+    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
+
     snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
             "flash_attn_ext",
             ggml_type_name(op->src[1]->type),
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
+            bc_mask,
             ns10,
             ns20,
             nsg);
@@ -970,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
     ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
+    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
+
+    ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
 
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -989,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool    has_sinks,
         bool    has_bias,
         bool    has_scap,
+        bool    has_kvpad,
         int32_t nsg,
         int32_t nwg) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1008,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
             ns10,
             ns20,
             nsg, nwg);
@@ -1029,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
     ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
     ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
+    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
 
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
index f6ebf90a00ee8a8147c0cb4d881e1a1c6cec4a42..ef049507384d85cf3778cb7a78024a57762081e2 100644 (file)
@@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        bool    has_mask,
+        int32_t ncpsg);
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
@@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool    has_sinks,
         bool    has_bias,
         bool    has_scap,
+        bool    has_kvpad,
         int32_t nsg);
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool    has_sinks,
         bool    has_bias,
         bool    has_scap,
+        bool    has_kvpad,
         int32_t nsg,
         int32_t nwg);
 
index 908e2e1cf5ccb48d6d4306a8ae265acce4bf72cb..1524b3ab518cbf2704acefd745eea1996e210eea 100644 (file)
 #define N_SG_IQ4_XS 2
 
 // function constants offsets
-#define FC_FLASH_ATTN_EXT              100
-#define FC_FLASH_ATTN_EXT_VEC          200
-#define FC_FLASH_ATTN_EXT_VEC_REDUCE   300
-#define FC_MUL_MV                      400
-#define FC_MUL_MM                      500
+#define FC_FLASH_ATTN_EXT_PAD          100
+#define FC_FLASH_ATTN_EXT              200
+#define FC_FLASH_ATTN_EXT_VEC          300
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE   400
+#define FC_MUL_MV                      500
+#define FC_MUL_MM                      600
 
 // kernel argument structs
 //
@@ -244,6 +245,24 @@ typedef struct {
     int32_t  sect_3;
 } ggml_metal_kargs_rope;
 
+typedef struct {
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ne31;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_pad;
+
 typedef struct {
     int32_t  ne01;
     int32_t  ne02;
@@ -262,6 +281,7 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne31;
     int32_t  ne32;
     int32_t  ne33;
     uint64_t nb31;
@@ -296,6 +316,7 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne31;
     int32_t  ne32;
     int32_t  ne33;
     uint64_t nb31;
index 7497d7c1da09839bbc3724a6cf22555ad4b1a4cf..125cc64dc5295f2bfb6708543bf3514b3f8550c7 100644 (file)
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
             GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
             GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+            GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+            GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+            GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+            GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
             GGML_TENSOR_LOCALS( int64_t, ne,  node,         ne);
             GGML_TENSOR_LOCALS(uint64_t, nb,  node,         nb);
 
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
                 GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                         ggml_is_contiguous(node->src[1]), node->src[1]->name);
             }
+            if (node->src[2]) {
+                GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+                        ggml_is_contiguous(node->src[2]), node->src[2]->name);
+            }
+            if (node->src[3]) {
+                GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+                        ggml_is_contiguous(node->src[3]), node->src[3]->name);
+            }
             if (node) {
                 GGML_LOG_DEBUG("%s: node  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                         node->name);
@@ -1889,20 +1901,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
     return (ne01 < 20) && (ne00 % 32 == 0);
 }
 
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const bool has_kvpad = ne11 % 32 != 0;
+
+        if (has_kvpad) {
+            res += 32*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    } else {
+        const bool has_kvpad = ne11 % 64 != 0;
+
+        if (has_kvpad) {
+            res += 64*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    }
+
+    return res;
+}
+
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
-    const int64_t nwg = 32;
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
 
-    const int64_t ne01 = op->src[0]->ne[1];
-    const int64_t ne02 = op->src[0]->ne[2];
-    const int64_t ne03 = op->src[0]->ne[3];
-    const int64_t ne20 = op->src[2]->ne[0];
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const int64_t nwg = 32;
 
-    // temp buffer for writing the results from each workgroup
-    // - ne20: the size of the Value head
-    // -  + 2: the S and M values for each intermediate result
-    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+        // temp buffer for writing the results from each workgroup
+        // - ne20: the size of the Value head
+        // -  + 2: the S and M values for each intermediate result
+        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    }
+
+    return res;
 }
 
 int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1924,8 +1985,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS( int32_t, nb,  op,         nb);
 
-    GGML_ASSERT(ne00 % 4  == 0);
-    GGML_ASSERT(ne11 % 32 == 0);
+    GGML_ASSERT(ne00 % 4 == 0);
 
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1935,8 +1995,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ne12 == ne22);
 
     GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
-    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
-            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
 
     float scale;
     float max_bias;
@@ -1963,6 +2023,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     GGML_ASSERT(ne01 < 65536);
 
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+    ggml_metal_buffer_id bid_src3 = has_mask  ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
+    ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
+
+    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_pad = bid_dst;
+    bid_pad.offs += ggml_nbytes(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_pad;
+    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
         const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
@@ -1972,6 +2046,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(nqptg  % 8  == 0);
         GGML_ASSERT(ncpsg  % 32 == 0);
 
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ne11,
+                /*.ne_12_2 =*/ne12,
+                /*.ne_12_3 =*/ne13,
+                /*.nb11    =*/nb11,
+                /*.nb12    =*/nb12,
+                /*.nb13    =*/nb13,
+                /*.nb21    =*/nb21,
+                /*.nb22    =*/nb22,
+                /*.nb23    =*/nb23,
+                /*.ne31    =*/ne31,
+                /*.ne32    =*/ne32,
+                /*.ne33    =*/ne33,
+                /*.nb31    =*/nb31,
+                /*.nb32    =*/nb32,
+                /*.nb33    =*/nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            ggml_metal_op_concurrency_reset(ctx);
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
         const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
 
         // 2*(2*ncpsg)
@@ -2021,6 +2137,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.nb21          =*/ nb21,
             /*.nb22          =*/ nb22,
             /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
             /*.ne32          =*/ ne32,
             /*.ne33          =*/ ne33,
             /*.nb31          =*/ nb31,
@@ -2037,24 +2154,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
-        if (op->src[3]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         6);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
+        ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  7);
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
@@ -2070,6 +2180,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(nqptg  % 1  == 0);
         GGML_ASSERT(ncpsg  % 32 == 0);
 
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ne11,
+                /*.ne_12_2 =*/ne12,
+                /*.ne_12_3 =*/ne13,
+                /*.nb11    =*/nb11,
+                /*.nb12    =*/nb12,
+                /*.nb13    =*/nb13,
+                /*.nb21    =*/nb21,
+                /*.nb22    =*/nb22,
+                /*.nb23    =*/nb23,
+                /*.ne31    =*/ne31,
+                /*.ne32    =*/ne32,
+                /*.ne33    =*/ne33,
+                /*.nb31    =*/nb31,
+                /*.nb32    =*/nb32,
+                /*.nb33    =*/nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            ggml_metal_op_concurrency_reset(ctx);
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2134,6 +2286,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.nb21          =*/ nb21,
             /*.nb22          =*/ nb22,
             /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
             /*.ne32          =*/ ne32,
             /*.ne33          =*/ ne33,
             /*.nb31          =*/ nb31,
@@ -2150,25 +2303,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
 
         GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
-        if (op->src[3]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
 
         const size_t smem = FATTN_SMEM(nsg);
 
@@ -2176,23 +2321,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
 
         if (nwg == 1) {
+            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
             // using 1 workgroup -> write the result directly into dst
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
         } else {
             // sanity checks
+            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
             GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
             GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
 
-            ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
-
             // write the results from each workgroup into a temp buffer
-            ggml_metal_buffer_id bid_tmp = bid_dst;
-            bid_tmp.offs += ggml_nbytes(op);
-            ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
index 8df4c72e7c8cbf2c3f166c6b4cfb374f65c1170b..6a6d8a7977a7c25b2073629b2e0f9a25ed78d14b 100644 (file)
@@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
 // return true if we should use the FA vector kernel for this op
 bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
 
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
index e11555a78fc7196bc74376cc535bad3ca1f85099..e53f37b29c1a4a48d5c85b8f03e87c70048280a8 100644 (file)
@@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
-                    res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
-                }
+                res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
+                res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
             } break;
         default:
             break;
index f454ceada37b373b55085d01e5115735a4dff84c..c52c6b48ad900a015fba9769cf6d2e8585ab1d69 100644 (file)
@@ -4349,10 +4349,83 @@ kernel void kernel_leaky_relu_f32_4(
     dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
 }
 
+constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
+
+constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
+
+// pad the last chunk of C elements of k and v into a an extra pad buffer
+kernel void kernel_flash_attn_ext_pad(
+        constant ggml_metal_kargs_flash_attn_ext_pad & args,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int32_t C = FC_flash_attn_ext_pad_ncpsg;
+
+    device char * k_pad    = dst;
+    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
+    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+    const int32_t icp = args.ne11 % C;
+    const int32_t ic0 = args.ne11 - icp;
+
+    const int32_t i1 = tgpig[0];
+    const int32_t i2 = tgpig[1];
+    const int32_t i3 = tgpig[2];
+
+    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
+        device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
+        device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
+
+        device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
+        device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
+
+        if (i1 >= icp) {
+            // here it is not important the exact value that will be used as we rely on masking out the scores in the attention
+            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
+                k_dst[i] = 0;
+            }
+            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
+                v_dst[i] = 0;
+            }
+        } else {
+            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
+                k_dst[i] = k_src[i];
+            }
+            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
+                v_dst[i] = v_src[i];
+            }
+        }
+    }
+
+    if (FC_flash_attn_ext_pad_has_mask) {
+        if (i2 < args.ne32 && i3 < args.ne33) {
+            for (int ib = i1; ib < args.ne31; ib += C) {
+                device const half * mask_src = (device const half *)(mask      + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
+                device       half * mask_dst = (device       half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
+
+                for (int i = tiitg; i < C; i += ntg.x) {
+                    if (i >= icp) {
+                        mask_dst[i] = -MAXHALF;
+                    } else {
+                        mask_dst[i] = mask_src[i];
+                    }
+                }
+            }
+        }
+    }
+}
+
 constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
 constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
 constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
 constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
+constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
+
+constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
 
 //constant float FC_flash_attn_ext_scale         [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
 //constant float FC_flash_attn_ext_max_bias      [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
@@ -4399,6 +4472,7 @@ void kernel_flash_attn_ext_impl(
         device const char * v,
         device const char * mask,
         device const char * sinks,
+        device const char * pad,
         device       char * dst,
         threadgroup  half * shmem_f16,
         uint3   tgpig,
@@ -4523,13 +4597,58 @@ void kernel_flash_attn_ext_impl(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic = 0; ic < args.ne11; ic += C) {
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
+            int ic = ic0;
+
+            // the last partial chunk uses the pad buffer as source
+            if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
+                k    = pad;
+                v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
+                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+                const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+                const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
+                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
+
+                if (!FC_flash_attn_ext_has_mask) {
+                    threadgroup half * sm = (threadgroup half *) (sm2);
+
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        const short j = jj*NSG + sgitg;
+
+                        for (short i = tiisg; i < C; i += NW) {
+                            if (ic + i >= args.ne11) {
+                                sm[2*j*SH + i] = -MAXHALF;
+                            }
+                        }
+                    }
+                } else {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        const short j = jj*NSG + sgitg;
+
+                        pm2[jj] = (device const half2 *) ((device const half *) mask +
+                                (iq1 + j)*C +
+                                (iq2%args.ne32)*(C*args.ne31) +
+                                (iq3%args.ne33)*(C*args.ne31*args.ne32));
+                    }
+                }
+
+                ic = 0;
+            }
+
             // read the mask into shared mem
             if (FC_flash_attn_ext_has_mask) {
                 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                     const short j = jj*NSG + sgitg;
 
-                    sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                    if (FC_flash_attn_ext_bc_mask) {
+                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+                    } else {
+                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                    }
+
                     pm2[jj] += NW;
                 }
 
@@ -4557,7 +4676,7 @@ void kernel_flash_attn_ext_impl(
             // this is compile-time check, so it does not have runtime overhead
             if (is_same<kd4x4_t, k4x4_t>::value) {
                 // we can read directly from global memory
-                device      const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
+                device      const k_t * pk = (device const k_t *) (k + ic*args.nb11);
                 threadgroup const q_t * pq = sq;
                 threadgroup       s_t * ps = ss;
 
@@ -4629,7 +4748,7 @@ void kernel_flash_attn_ext_impl(
                     qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 
                     for (short ii = 0; ii < DK16; ii += 4) {
-                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
+                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
 
                         if (DK16%4 == 0) {
                             // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -4751,7 +4870,7 @@ void kernel_flash_attn_ext_impl(
                     {
                         auto sst = ss;
 
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
+                        device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
 
                         pv += 8*sgitg;
 
@@ -4793,7 +4912,7 @@ void kernel_flash_attn_ext_impl(
                         simdgroup_load(vs, ss + 8*cc, SH, 0, false);
 
                         for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
 
                             if (DV16%4 == 0) {
                                 // no need for bound checks
@@ -4937,13 +5056,14 @@ kernel void kernel_flash_attn_ext(
         device const char * v,
         device const char * mask,
         device const char * sinks,
+        device const char * pad,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
 #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
     switch (FC_flash_attn_ext_nsg) {
       // note: disabled cases to reduce library load time
       //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5063,6 +5183,7 @@ constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_
 constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
 constant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
 constant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
+constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
 
 //constant float FC_flash_attn_ext_vec_scale         [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
 //constant float FC_flash_attn_ext_vec_max_bias      [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
@@ -5100,6 +5221,7 @@ void kernel_flash_attn_ext_vec_impl(
         device const char * v,
         device const char * mask,
         device const char * sinks,
+        device const char * pad,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -5206,11 +5328,37 @@ void kernel_flash_attn_ext_vec_impl(
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
-            const int ic = ic0 + C*sgitg;
+            int ic = ic0 + C*sgitg;
             if (ic >= args.ne11) {
                 break;
             }
 
+            // the last partial chunk uses the pad buffer as source
+            if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
+                k    = pad;
+                v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
+                mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+                const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+                const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+                k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
+                v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
+
+                if (!FC_flash_attn_ext_vec_has_mask) {
+                    if (ic + tiisg >= args.ne11) {
+                        sm[tiisg] = -MAXHALF;
+                    }
+                } else {
+                    pm = (device const half *) (mask) +
+                        iq1*C +
+                        (iq2%args.ne32)*(C*args.ne31) +
+                        (iq3%args.ne33)*(C*args.ne31*args.ne32);
+                }
+
+                ic = 0;
+            }
+
             if (FC_flash_attn_ext_vec_has_mask) {
                 sm[tiisg] = pm[ic + tiisg];
             }
@@ -5222,7 +5370,7 @@ void kernel_flash_attn_ext_vec_impl(
 
             // Q*K^T
             {
-                device      const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
+                device      const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
                 threadgroup const q4_t * pq4 = sq4;
 
                 pk4 += ty*NS10/4 + tx;
@@ -5237,7 +5385,7 @@ void kernel_flash_attn_ext_vec_impl(
                             mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 +  ii*NL], (float4) pq4[ii*NL]);
                         }
                     } else {
-                        device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
+                        device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
 
                         k4_t mk;
 
@@ -5335,7 +5483,7 @@ void kernel_flash_attn_ext_vec_impl(
                 }
 
                 if (is_same<vd4_t, v4_t>::value) {
-                    device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
+                    device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
 
                     pv4 += ty*NS20/4 + tx;
 
@@ -5348,7 +5496,7 @@ void kernel_flash_attn_ext_vec_impl(
                     }
                 } else {
                     FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
-                        device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
+                        device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
 
                         FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                             const short i = ii*NL + tx;
@@ -5520,13 +5668,14 @@ kernel void kernel_flash_attn_ext_vec(
         device const char * v,
         device const char * mask,
         device const char * sinks,
+        device const char * pad,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
 #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
     switch (FC_flash_attn_ext_vec_nsg) {
       // note: disabled cases to reduce library load time
         case 1:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  1>(FWD_ARGS); break;