]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: fix topk_moe_sigmoid_norm_bias failures in GLM-4.6 (#18582)
authorJeff Bolz <redacted>
Mon, 5 Jan 2026 10:51:39 +0000 (04:51 -0600)
committerGitHub <redacted>
Mon, 5 Jan 2026 10:51:39 +0000 (11:51 +0100)
ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
tests/test-backend-ops.cpp

index 4bf6d2bcb03eb6b86918622edc5d58f6224974c2..ef2f202ec9b6a005b7c862bbc4beba7a6f85f222 100644 (file)
@@ -101,6 +101,10 @@ void main() {
     const uint lane = gl_SubgroupInvocationID;
 
     float probs[experts_per_thread];
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        probs[i] = -INFINITY;
+    }
 
     [[unroll]]
     for (uint i = 0; i < n_experts; i += WARP_SIZE) {
@@ -112,8 +116,9 @@ void main() {
         softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
     } else if (gating_func == GATING_FUNC_SIGMOID) {
         [[unroll]]
-        for (int i = 0; i < experts_per_thread; i++) {
-            probs[i] = 1.f / (1.f + exp(-probs[i]));
+        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
         }
     }
 
@@ -150,11 +155,11 @@ void main() {
         uint   max_expert = lane;
 
         [[unroll]]
-        for (int i = 1; i < experts_per_thread; i++) {
-            const uint expert = lane + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
-                max_val    = probs[i];
-                max_val_s  = selection_probs[i];
+        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
+                max_val    = probs[i / WARP_SIZE];
+                max_val_s  = selection_probs[i / WARP_SIZE];
                 max_expert = expert;
             }
         }
index 8df994e91c2845ed0eecb6b7db5f560cdd7fbb65..15567abedcf28ad5f91b9ab2811cf1afecdb3cd2 100644 (file)
@@ -8184,6 +8184,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
                     test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
                     test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w));
                 }
             }
         }