]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: fix MMQ shader push constants and multi-dispatch (llama/19732)
authorRuben Ortlam <redacted>
Thu, 19 Feb 2026 13:59:16 +0000 (14:59 +0100)
committerGeorgi Gerganov <redacted>
Fri, 27 Feb 2026 18:57:58 +0000 (20:57 +0200)
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

index 335d7f6a68273095416e1ac9a6aa235ee57046ed..aae1c2e8ae9fc51543b9cd54c5fc89f011712fad 100644 (file)
@@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -108,7 +110,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -118,7 +120,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -276,7 +278,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {