bool use_vec_kernel = false;
+ // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
+ // for now avoiding mainly to keep the number of templates/kernels a bit lower
if (ne01 >= 4 || (ne00%128 != 0)) {
switch (src1->type) {
case GGML_TYPE_F16:
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
- const short NL = NW/4;
- const short SH = 2*C; // shared memory per simdgroup
+ const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
+ const short SH = 2*C; // shared memory per simdgroup
const short T = D + nsg*SH; // shared memory size per query in (half)
// Q*K^T
{
- // each simdgroup processes 1 query and 4 keys
+ // each simdgroup processes 1 query and 4 (NW/NL) keys
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
half, half4, half4x4, \
half4x4
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
#if defined(GGML_METAL_USE_BF16)