]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Adjust coopmat2 tile sizes and selection heuristic (llama/12258)
authorJeff Bolz <redacted>
Mon, 17 Mar 2025 09:35:00 +0000 (04:35 -0500)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp

index ff53bdfbe171c42259036e1c2152ead4fab4fa87..e46007a52f56e74c4963b2421b42dab671c028d9 100644 (file)
@@ -1476,26 +1476,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // spec constants and tile sizes for quant matmul (non-Qi_K)
         l_warptile_mmq = { 256, 128, 256, 64 };
         m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 128, 128, 64 };
+        s_warptile_mmq = { 256, 32,  64, 128 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 32,  64,  1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128, 512, 16 };
-        m_warptile_mmq_k = { 256, 128, 256, 16 };
-        s_warptile_mmq_k = { 256, 32, 128, 64 };
-        l_mmq_wg_denoms_k = { 128, 512, 1 };
-        m_mmq_wg_denoms_k = { 128, 256, 1 };
-        s_mmq_wg_denoms_k = { 32, 128, 1 };
+        l_warptile_mmq_k = { 256, 64, 128, 64 };
+        m_warptile_mmq_k = { 256, 32,  64, 64 };
+        s_warptile_mmq_k = { 256, 32,  32, 128 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32,  64, 1 };
+        s_mmq_wg_denoms_k = { 32,  32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16 };
+        l_warptile_mmqid = { 256, 128, 64, 16 };
         m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 64, 64, 16 };
-        l_mmqid_wg_denoms = { 128, 128, 1 };
+        s_warptile_mmqid = { 256, 128, 64, 16 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = { 64, 64, 1 };
+        s_mmqid_wg_denoms = { 128, 64, 1 };
 
         l_align = 128;
         m_align =  64;
@@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
@@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
 }
 
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;