]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan/cuda: fix topk_moe with exp_probs_b (#18071)
authorJeff Bolz <redacted>
Sun, 21 Dec 2025 09:27:34 +0000 (03:27 -0600)
committerGitHub <redacted>
Sun, 21 Dec 2025 09:27:34 +0000 (10:27 +0100)
I updated test_topk_moe to more closely match llm_graph_context::build_moe_ffn
and added coverage for exp_probs_b and some other missing combinations. This
exposed a bug in both CUDA and Vulkan backends where they were assuming the
input to argsort and the input to get_rows are the same. I'd like to optimize
this graph in another change, but for now just get it functional.

CUDA also had a bug where it got n_experts from the wrong place, leading to
GGML_ASSERT failures in some of the new tests.

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/topk-moe.cu
ggml/src/ggml-cuda/topk-moe.cuh
ggml/src/ggml-vulkan/ggml-vulkan.cpp
tests/test-backend-ops.cpp

index ab0f6fe9ce92a3d59cf8142e29e2c0f57668d240..55fa2e6a7cc3473f2172ef2bae13ca9536c95af0 100644 (file)
@@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
@@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+        ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
+        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
             return true;
         }
     }
index 572379fcbf0e81ba031a8c6fe92dba6555ba850b..48e569efa0d989fd5a2c0b6707d8b9054fa5b9fe 100644 (file)
@@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }
 
-    const int n_expert = softmax->ne[0];
     // n_expert must be a power of 2
     if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
         return false;
index 2eff408b03058b82a2b27f8a95de98aad6168426..6b6c13c58706930409a4aa46fcc8589315eb56e1 100644 (file)
@@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const bool                  delayed_softmax = false,
                            ggml_tensor *               weight_clamp    = nullptr);
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * get_rows,
+                                   const ggml_tensor * argsort,
+                                   const ggml_tensor * clamp,
+                                   int n_expert);
 
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
index 1c614536362d29dc9c24ed66e6f4911ea67c5c87..a5308fa54426641624d1935da015ca78aab6a189 100644 (file)
@@ -12941,24 +12941,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
 
     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;
 
     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }
 
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
+        return false;
+    }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
+
+    if (probs != selection_probs) {
+        return false;
+    }
+
     const float * op_params = (const float *)softmax->op_params;
 
     float scale = op_params[0];
index b9a73bea620cdd737e885120f892807fa2bc5d0a..5116918e7188a2024a1aaba91206498977257558 100644 (file)
@@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
     }
 };
 
+enum MoeGatingFunc {
+    GATING_FUNC_SOFTMAX,
+    GATING_FUNC_SIGMOID,
+    GATING_FUNC_SOFTMAX_WEIGHT,
+};
+
 struct test_topk_moe : public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    const bool                   delayed_softmax;
+    const bool bias_probs;
+    const MoeGatingFunc gating_func;
+    const float scale_w;
 
     test_topk_moe(std::array<int64_t, 4> ne              = { 10, 5, 1, 1 },
                   int                    n_expert_used   = 1,
                   bool                   with_norm       = false,
-                  bool                   delayed_softmax = false) :
+                  bool                   bias_probs      = false,
+                  MoeGatingFunc          gating_func     = GATING_FUNC_SOFTMAX,
+                  float                  scale_w         = 0.0f) :
         ne(ne),
         n_expert_used(n_expert_used),
         with_norm(with_norm),
-        delayed_softmax(delayed_softmax) {
+        bias_probs(bias_probs),
+        gating_func(gating_func),
+        scale_w(scale_w) {
         GGML_ASSERT(n_expert_used <= ne[0]);
-        GGML_ASSERT(!(with_norm && delayed_softmax));
     }
 
-    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
+    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
         const int n_tokens = ne[1];
 
         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-        ggml_tensor * probs            = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
-        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_tensor * probs            =
+            (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
+            (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
+        ggml_set_name(probs, "probs");
+
+        ggml_tensor * selection_probs = probs;
+        if (bias_probs) {
+            ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+            ggml_set_name(exp_probs_b, "exp_probs_b");
+            selection_probs = ggml_add(ctx, probs, exp_probs_b);
+            ggml_set_name(selection_probs, "selection_probs");
+        }
 
-        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_set_name(selected_experts, "selected_experts");
 
-        if (delayed_softmax) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            out = ggml_soft_max(ctx, out);  // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        ggml_set_name(weights, "weights");
+
+        if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            weights = ggml_soft_max(ctx, weights);  // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }
 
         if (with_norm) {
-            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
+            ggml_set_name(weights_sum, "weights_sum");
 
             weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
-            out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
-            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+            weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
+            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
         }
 
-        ggml_set_name(out, "out");
-        return out;
+        if (scale_w) {
+            weights = ggml_scale(ctx, weights, scale_w);
+        }
+
+        ggml_set_name(weights, "weights");
+        return weights;
     }
 };
 
@@ -7991,19 +8021,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    for (bool with_norm : {false, true}) {
-        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
-        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
-        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
-        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
+    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
+        for (bool with_norm : {false, true}) {
+            for (bool bias_probs : {false, true}) {
+                for (float scale_w : {0.0f, 2.0f}) {
+                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
+                }
+            }
+        }
     }
 
-    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
-    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
-
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));