]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: refactor topk-moe to enable more models (GLM 4.7, Nemotron etc.) (#19126)
authorAman Gupta <redacted>
Thu, 29 Jan 2026 02:31:28 +0000 (10:31 +0800)
committerGitHub <redacted>
Thu, 29 Jan 2026 02:31:28 +0000 (10:31 +0800)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/topk-moe.cu
ggml/src/ggml-cuda/topk-moe.cuh

index e9df0ea4a7c4398b5173fd12b4a0ece838270589..76d0f12550e1e306f6a790c32ab800b585efec4d 100644 (file)
@@ -3080,63 +3080,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
     return true;
 }
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
-#ifndef NDEBUG
-    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
-    GGML_ASSERT(unary_ops.size() == num_unary);
-#endif
-
-    //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops =
-        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
+static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
+    args.sigmoid         = false;
+    args.softmax         = false;
+    args.delayed_softmax = false;
+    args.prob_bias       = false;
+    args.norm            = false;
 
-    const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
-                             const std::initializer_list<enum ggml_op> & list2) {
-        return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
-    };
+    const int      n_nodes = cgraph->n_nodes;
+    ggml_tensor ** nodes   = cgraph->nodes;
 
-    if (is_equal(topk_moe_ops_with_norm, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 9];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
-        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+    if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
+        args.softmax = true;
+    }
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
+    if (nodes[node_idx]->op == GGML_OP_UNARY) {
+        if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
+            return false;
         }
+        args.sigmoid = true;
     }
 
-    if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
-        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+    if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
+        args.delayed_softmax = true;
+    }
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
+    node_idx++;
+
+    if (args.sigmoid || args.softmax) {
+        // SOFTMAX -> RESHAPE
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
+                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx];
+        node_idx++;
+
+        if (node_idx >= n_nodes) {
+            return false;
+        }
+
+        // src of bias add is the unreshaped probs (-2 instead of -1)
+        if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
+            args.prob_bias = true;
+            node_idx++;
+        }
+        // RESHAPE/ADD -> ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
+            return false;
+        }
+
+        if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
+            return false;
+        }
+
+        node_idx++;
+
+        // ARGSORT-> VIEW
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
+            return false;
+        }
+
+        // GET_ROWS
+        if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+    } else if (args.delayed_softmax) {
+        if (node_idx - 2 < 0) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx - 2];
+
+        // VIEW->ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        // GET_ROWS
+        if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+                nodes[node_idx]->src[0] != probs_reshaped) {
+            return false;
         }
+        node_idx++;
+
+        static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+        for (const ggml_op op : remaining_ops) {
+            if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+                return false;
+            }
+            node_idx++;
+        }
+    }
+
+    // At this point we can check for norm + scale. Everything is now at least valid till the norm
+    if (node_idx >= n_nodes) {
+        return true;
     }
 
-    if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
-        ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
-        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+    if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
+        //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
+        static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
+        args.norm = true;
+        for (const ggml_op op : norm_ops) {
+            if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+                node_idx++;
+            } else {
+                args.norm = false;
+                return true;
+            }
+        }
+
+        // DIV <- CLAMP, RESHAPE
+        if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
+            args.norm = false;
+            return true;
+        }
+        node_idx++;
+
+        if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            args.norm = false;
             return true;
         }
+
+        node_idx++;
     }
 
+    if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+        args.scale = true;
+    }
+
+    return true;
+}
+
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph *                cgraph,
+                               int                                       node_idx,
+                               std::initializer_list<enum ggml_op>       ops,
+                               std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
+    const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
+                             const std::initializer_list<enum ggml_op> & list2) {
+        return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
+    };
+
     std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
     std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
 
@@ -3398,35 +3501,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                 // start of fusion operations
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
+                    ggml_cuda_topk_moe_args args;
+
+                    if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
+                        cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
+                        const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
+
+                        std::vector<ggml_op> ops;
+
+                        if (can_fuse) {
+                            const ggml_tensor * logits  = node->src[0];
+                            ggml_tensor *       weights = nullptr;
+                            ggml_tensor *       ids     = nullptr;
+                            const ggml_tensor * bias    = nullptr;
+                            const ggml_tensor * clamp   = nullptr;
+                            const ggml_tensor * scale   = nullptr;
+
+                            if (!args.delayed_softmax) {
+                                ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
+                                int     out_nodes[2];  // nodes which can't be elided
+
+                                if (args.prob_bias) {
+                                    bias = cgraph->nodes[i + 2]->src[1];
+                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
+                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS });
+                                    out_nodes[0] = i + 4;
+                                    ids          = cgraph->nodes[i + 4];
+                                } else {
+                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                            GGML_OP_GET_ROWS });
+                                    out_nodes[0] = i + 3;
+                                    ids          = cgraph->nodes[i + 3];
+                                }
 
-                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                        ggml_tensor * weights          = cgraph->nodes[i + 9];
-                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                        ggml_tensor * clamp            = cgraph->nodes[i + 7];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                              /*delayed softmax*/ false, clamp);
-                        i += 9;
-                        continue;
-                    }
-
-                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                        ggml_tensor * weights          = cgraph->nodes[i + 4];
-                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
-                                              /*delayed softmax*/ false);
-                        i += 4;
-                        continue;
-                    }
+                                if (args.norm) {
+                                    ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                            GGML_OP_DIV, GGML_OP_RESHAPE });
+                                    clamp = cgraph->nodes[i + ops.size() - 3];
+                                }
+                                if (args.scale) {
+                                    ops.insert(ops.end(), { GGML_OP_SCALE });
+                                    scale = cgraph->nodes[i + ops.size() - 1];
+                                }
 
-                    if (ggml_cuda_can_fuse(cgraph, i,
-                                           ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
-                        ggml_tensor * weights = cgraph->nodes[i + 5];
-                        ggml_tensor * ids     = cgraph->nodes[i + 1];
+                                weights      = cgraph->nodes[i + ops.size() - 1];
+                                out_nodes[1] = i + ops.size() - 1;
 
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
-                                              /*delayed_softmax*/ true);
-                        i += 5;
-                        continue;
+                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                                        ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
+                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                                    i += ops.size() - 1;
+                                    continue;
+                                }
+                            } else if (!args.norm && !args.prob_bias) {
+                                //special case gpt-oss, no norm, no bias.
+                                ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                        GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
+                                weights                     = cgraph->nodes[i + 5];
+                                ids                         = cgraph->nodes[i + 1];
+                                const ggml_tensor * softmax = cgraph->nodes[i + 4];
+
+                                int out_nodes[2] = { i + 1, i + 5 };
+                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                                        ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
+                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                                    i += ops.size() - 1;
+                                    continue;
+                                }
+                            }
+                        }
                     }
 
                     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
index 48e569efa0d989fd5a2c0b6707d8b9054fa5b9fe..08a88990dde0f378a93210d77b613799c2455a32 100644 (file)
@@ -5,6 +5,13 @@
 #include <cmath>
 #include <initializer_list>
 
+// Kernel config struct - passed by value to CUDA kernel
+struct topk_moe_config {
+    bool use_sigmoid;
+    bool with_norm;
+    bool delayed_softmax;
+};
+
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
 template <int experts_per_thread, bool use_limit>
 __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
     }
 }
 
+template <int experts_per_thread, bool use_limit>
+__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        vals[i]           = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
+    }
+}
+
 /*
     This kernel does the following:
     1. optionally softmax over the logits per token [n_experts, n_tokens]
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm, bool delayed_softmax = false>
-__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
-                                                                  float *       weights,
-                                                                  int32_t *     ids,
-                                                                  const int     n_rows,
-                                                                  const int     n_expert_used,
-                                                                  const float   clamp_val) {
+template <int n_experts, bool has_bias>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *         logits,
+                                                                  float *               weights,
+                                                                  int32_t *             ids,
+                                                                  float *               bias,
+                                                                  const int             n_rows,
+                                                                  const int             n_expert_used,
+                                                                  const float           clamp_val,
+                                                                  const float           scale_val,
+                                                                  const topk_moe_config config) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt[experts_per_thread];
 
+    // Initialize all slots to -INFINITY
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = -INFINITY;
+    }
+
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert  = i + threadIdx.x;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    if constexpr (!delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+    if (!config.delayed_softmax) {
+        if (config.use_sigmoid) {
+           sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+        } else {
+           softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+        }
+    }
+
+    // selection_wt is only needed when bias is present (selection uses wt + bias)
+    // when no bias, we use wt directly for both selection and weight values
+    float selection_wt[has_bias ? experts_per_thread : 1];
+
+    if constexpr (has_bias) {
+#pragma unroll
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_wt[i] = -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_experts; i += WARP_SIZE) {
+            const int expert = i + threadIdx.x;
+            selection_wt[i / WARP_SIZE] =
+                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
+        }
     }
 
     //at this point, each thread holds either a portion of the softmax distribution
@@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
 
+        if constexpr (has_bias) {
+            float max_val_s = selection_wt[0];
+
 #pragma unroll
-        for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = threadIdx.x + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
-                max_val    = wt[i];
-                max_expert = expert;
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
+                    max_val    = wt[i];
+                    max_val_s  = selection_wt[i];
+                    max_expert = expert;
+                }
+            }
+
+#pragma unroll
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const float val_s  = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+                    max_val    = val;
+                    max_val_s  = val_s;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                selection_wt[max_expert / WARP_SIZE] = -INFINITY;
+            }
+        } else {
+#pragma unroll
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                    max_val    = wt[i];
+                    max_expert = expert;
+                }
             }
-        }
 
 #pragma unroll
-        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
-            if (val > max_val || (val == max_val && expert < max_expert)) {
-                max_val    = val;
-                max_expert = expert;
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val > max_val || (val == max_val && expert < max_expert)) {
+                    max_val    = val;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                wt[max_expert / WARP_SIZE] = -INFINITY;
             }
         }
 
@@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
 
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
-            wt[max_expert / WARP_SIZE] = -INFINITY;
-
             ids[k] = max_expert;
-            if constexpr (with_norm) {
+            if (config.with_norm) {
                 wt_sum += max_val;
             }
         }
     }
 
-    if constexpr (with_norm) {
+    if (config.with_norm) {
         wt_sum              = warp_reduce_sum(wt_sum);
         wt_sum              = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
@@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
-    if constexpr (delayed_softmax) {
+    if (config.delayed_softmax) {
         softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
     }
 
@@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
         if (idx < n_expert_used) {
-            weights[idx] = output_weights[i];
+            weights[idx] = output_weights[i] * scale_val;
         }
     }
-
-    if (!with_norm) {
-        GGML_UNUSED(clamp_val);
-    }
 }
 
-template <bool with_norm, bool delayed_softmax = false>
+template<bool has_bias>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
                                  int32_t *                   ids,
+                                 float *                     bias,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used,
-                                 const float                 clamp_val) {
-    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+                                 const float                 clamp_val,
+                                 const float                 scale_val,
+                                 const topk_moe_config       config) {
+    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
+                "delayed softmax is not supported with weight normalization");
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
@@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
+            break;
+        case 576:
+            topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     }
 }
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax,
-                           ggml_tensor *               clamp) {
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
+    float *       bias_d    = bias ? (float *) bias->data : nullptr;
+
+    float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
 
     GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
     const int n_expert_used = weights->ne[1];
 
+    const bool with_norm = clamp != nullptr;
+
     float clamp_val = -INFINITY;
-    if (with_norm) {
-        if (clamp) {
-            clamp_val = ggml_get_op_params_f32(clamp, 0);
-        }
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
+    if (clamp) {
+        clamp_val = ggml_get_op_params_f32(clamp, 0);
+    }
+
+    topk_moe_config config;
+    config.use_sigmoid     = args.sigmoid;
+    config.with_norm       = with_norm;
+    config.delayed_softmax = args.delayed_softmax;
+
+    if (bias) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+                             scale_val, config);
     } else {
-        GGML_ASSERT(clamp == nullptr);
-        if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                              clamp_val);
-        } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                               clamp_val);
-        }
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+                             scale_val, config);
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                    const ggml_tensor * weights,
-                                   const ggml_tensor * get_rows,
-                                   const ggml_tensor * argsort,
-                                   const ggml_tensor * clamp,
-                                   int n_expert) {
-    ggml_tensor * probs = get_rows->src[0];
-    if (probs->op != GGML_OP_RESHAPE) {
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids) {
+    const int n_expert = ids->nb[1] / ids->nb[0];
+    if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
         return false;
     }
-    probs = probs->src[0];
-    ggml_tensor * selection_probs = argsort->src[0];
 
-    if (probs != selection_probs) {
+    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
         return false;
     }
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    if (gating_op->op == GGML_OP_SOFT_MAX) {
+        const ggml_tensor * softmax  = gating_op;
+        float               scale    = 1.0f;
+        float               max_bias = 0.0f;
 
-    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+        memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
 
-    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
-        return false;
-    }
-
-    if (scale != 1.0f || max_bias != 0.0f) {
-        return false;
-    }
-
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
-        return false;
-    }
+        if (!ggml_is_contiguous(softmax->src[0])) {
+            return false;
+        }
 
-    // n_expert must be a power of 2
-    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
-        return false;
-    }
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
 
-    if (clamp) {
-        if (clamp->op != GGML_OP_CLAMP) {
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
             return false;
         }
-        float max_val = ggml_get_op_params_f32(clamp, 1);
+    } else if (gating_op->op == GGML_OP_UNARY) {
+        ggml_unary_op op = ggml_get_unary_op(gating_op);
 
-        if (max_val != INFINITY) {
+        if (op != GGML_UNARY_OP_SIGMOID) {
             return false;
         }
     }
 
-
     return true;
 }
-
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
-    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
-                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
-                                                            GGML_OP_RESHAPE };
-
-    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
-                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
-
-    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
-                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
-
-    GGML_ASSERT(!norm || !delayed_softmax);
-
-    if (delayed_softmax) {
-        return delayed_softmax_ops;
-    }
-
-    if (norm) {
-        return norm_ops;
-    }
-
-    return no_norm_ops;
-}
index 6b6c13c58706930409a4aa46fcc8589315eb56e1..243dc2f1c41b5ccbbd24dae44a3dd00a9b7f0747 100644 (file)
@@ -3,19 +3,25 @@
 
 #include <initializer_list>
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax = false,
-                           ggml_tensor *               weight_clamp    = nullptr);
+struct ggml_cuda_topk_moe_args {
+    bool sigmoid{};
+    bool softmax{};
+    bool delayed_softmax{};
+    bool prob_bias{};
+    bool norm{};
+    bool scale{};
+};
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
-                                   const ggml_tensor * weights,
-                                   const ggml_tensor * get_rows,
-                                   const ggml_tensor * argsort,
-                                   const ggml_tensor * clamp,
-                                   int n_expert);
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args);
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids);