From: Ruben Ortlam
Date: Thu, 19 Feb 2026 13:59:16 +0000 (+0100)
Subject: vulkan: fix MMQ shader push constants and multi-dispatch (llama/19732)
X-Git-Tag: v0.9.8~118
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=defde0e7c706e2e2607150ccb7be738e2f24d0bb;p=pkg%2Fggml%2Fsources%2Fggml

vulkan: fix MMQ shader push constants and multi-dispatch (llama/19732)
---

diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 335d7f6a..aae1c2e8 100644
--- a/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -108,7 +110,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -118,7 +120,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -276,7 +278,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {