]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Use fewer rows for scalar FA when HS is not a multiple of 16 (llama/17455)
authorJeff Bolz <redacted>
Tue, 25 Nov 2025 06:11:27 +0000 (00:11 -0600)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:07 +0000 (17:53 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp

index d78c727e53bd2b989d5d40ee44b06e7a0b387e53..6cf15b43bb33c98b3932dcc34b0ed9aa3b4891ba 100644 (file)
@@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
 
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
     if (hsv >= 192) {
         return 2;
+    } else if ((hsv | hsk) & 8) {
+        return 4;
     } else {
         return 8;
     }
@@ -2535,9 +2537,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
             if ((hsv | hsk) & 8) {
                 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
                 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
-                return {get_fa_scalar_num_large_rows(hsv), 64};
+                return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
             } else {
-                return {get_fa_scalar_num_large_rows(hsv), 32};
+                return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
             }
         }
     }
@@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
     const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
+    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
     const uint32_t Bc = scalar_flash_attention_Bc;
 
     const uint32_t tmpsh = wg_size * sizeof(float);
@@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     case FA_SCALAR:
     case FA_COOPMAT1:
         // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = get_fa_scalar_num_large_rows(HSV);
+        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
         break;
     case FA_COOPMAT2:
         max_gqa = get_fa_num_small_rows(FA_COOPMAT2);