// Row-count constants for flash attention. Only their values are visible in
// this fragment; the code that reads them is not shown here.
// NOTE(review): presumably the row count used by the "small rows" variant — confirm against the readers.
static constexpr uint32_t flash_attention_num_small_rows = 32;
// NOTE(review): presumably the small-row count for the scalar path (a single row) — confirm.
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
if (hsv >= 192) {
return 2;
+ } else if ((hsv | hsk) & 8) {
+ return 4;
} else {
return 8;
}
if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
- return {get_fa_scalar_num_large_rows(hsv), 64};
+ return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
} else {
- return {get_fa_scalar_num_large_rows(hsv), 32};
+ return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
}
}
}
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
- const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
- max_gqa = get_fa_scalar_num_large_rows(HSV);
+ max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);