const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+ // note: always reserve the padding space to avoid graph reallocations
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+ const bool has_kvpad = true;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
- const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
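+ // note: always reserve the padding space to avoid graph reallocations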
+ //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+ const bool has_kvpad = true;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ ... @@
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
// this optimization is not useful for the vector kernels
- if (is_vec) {
- return res;
- }
+ // note: always reserve the blk buffer to avoid graph reallocations
+ //if (is_vec) {
+ // return res;
+ //}
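+ // nqptg/ncpsg: number of queries per threadgroup and KV cache values per simdgroup for the selected kernel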
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ ... @@
size_t res = 0;
- if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ // note: always reserve the temp buffer to avoid graph reallocations
+ //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+ if (true) {
const int64_t nwg = 32;
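+ // note: cap ne01 since the temp buffer is now always reserved - this keeps the reservation bounded for large batches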
+ const int64_t ne01_max = std::min<int64_t>(ne01, 32);
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
- res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+ res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
}
return res;