git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: topk-moe: add optional parameter for gpt-oss (#16649)
author     Aman Gupta <redacted>
           Tue, 21 Oct 2025 14:40:38 +0000 (22:40 +0800)
committer  GitHub <redacted>
           Tue, 21 Oct 2025 14:40:38 +0000 (22:40 +0800)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/topk-moe.cu
ggml/src/ggml-cuda/topk-moe.cuh
tests/test-backend-ops.cpp

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 015b37be0708e028a93f6d0a0a0937e0b2c219ec..6e7c5aedbc55a55937575701f49a91d38105300c 100644
@@ -2818,8 +2818,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #endif
 
     //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+    std::initializer_list<enum ggml_op> topk_moe_ops =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
@@ -2840,6 +2844,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
+        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
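
The 6-op list matches the gpt-oss routing subgraph, in which top-k runs on the raw logits and the softmax only happens after the gather. The node offsets implied by the op list and by the indices used here and in the dispatch below (an illustrative annotation, not text from the commit):

    // Delayed-softmax pattern, offsets relative to node_idx:
    //   +0 GGML_OP_ARGSORT   rank the raw logits per token
    //   +1 GGML_OP_VIEW      top-k expert ids (read by the fused kernel)
    //   +2 GGML_OP_GET_ROWS  gather the selected raw logits (subgraph output)
    //   +3 GGML_OP_RESHAPE
    //   +4 GGML_OP_SOFT_MAX  softmax over only the k selected logits
    //   +5 GGML_OP_RESHAPE   final routing weights (subgraph output)
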
@@ -2933,7 +2947,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                         ggml_tensor * weights = cgraph->nodes[i+8];
                         ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                                              /*delayed softmax*/ false);
                         i += 8;
                         continue;
                     }
@@ -2941,11 +2956,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
                         ggml_tensor * weights = cgraph->nodes[i+4];
                         ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+                                              /*delayed softmax*/ false);
                         i += 4;
                         continue;
                     }
 
+                    if (ggml_cuda_can_fuse(cgraph, i,
+                                           ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
+                        ggml_tensor * weights = cgraph->nodes[i + 5];
+                        ggml_tensor * ids     = cgraph->nodes[i + 1];
+
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
+                                              /*delayed_softmax*/ true);
+                        i += 5;
+                        continue;
+                    }
+
                     if (node->op == GGML_OP_ADD) {
                         int n_fuse = 0;
                         ggml_op ops[8];
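
To pin down the semantics the fused path reproduces: with delayed_softmax, the top-k selection happens on the raw logits and the softmax is applied afterwards over only the k selected values. The selected ids are the same either way, since softmax is monotonic, but the weights differ from softmax-then-top-k whenever n_expert_used < n_experts. A minimal scalar reference for one token row (illustrative only; the names are not from the commit):

    #include <algorithm>
    #include <cmath>
    #include <numeric>
    #include <vector>

    // Reference for one token row: pick the k largest raw logits first, then
    // softmax over just those k values ("delayed" softmax, as used by gpt-oss).
    // Assumes 0 < k <= (int) logits.size().
    static void topk_then_softmax(const std::vector<float> & logits, int k,
                                  std::vector<int> & ids, std::vector<float> & weights) {
        ids.resize(logits.size());
        std::iota(ids.begin(), ids.end(), 0);
        std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                          [&](int a, int b) { return logits[a] > logits[b]; });
        ids.resize(k);

        weights.resize(k);
        const float max_val = logits[ids[0]];  // ids[0] is the argmax after partial_sort
        float       sum     = 0.0f;
        for (int i = 0; i < k; i++) {
            weights[i] = std::exp(logits[ids[i]] - max_val);
            sum += weights[i];
        }
        for (int i = 0; i < k; i++) {
            weights[i] /= sum;
        }
    }
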
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index c588da2bb9e9325b1472359bc8ec9afafe96311d..d782ad948d2547a734a28a737a0ad5275d37eea7 100644
@@ -4,16 +4,61 @@
 
 #include <initializer_list>
 
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i]         = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
 /*
     This kernel does the following:
-    1. softmax over the logits per token [n_experts, n_tokens]
+    1. optionally apply softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
-    4. optionally normalize the weights
+    4. optionally normalize the weights or apply softmax over the selected logits
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float *       weights,
                                                                   int32_t *     ids,
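
A serial sketch of what the softmax_warp_inplace helper above computes, ignoring the warp layout: with use_limit set, entries at index >= limit contribute to neither the max nor the sum and come out as exactly 0 (illustrative only, not part of the commit):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Serial equivalent of the masked, numerically stable softmax: entries at
    // index >= limit neither contribute to max/sum nor receive probability mass.
    static void softmax_masked(std::vector<float> & vals, size_t limit) {
        float max_val = -INFINITY;
        for (size_t i = 0; i < limit; i++) {
            max_val = std::max(max_val, vals[i]);
        }
        float sum = 0.0f;
        for (size_t i = 0; i < limit; i++) {
            vals[i] = std::exp(vals[i] - max_val);
            sum += vals[i];
        }
        for (size_t i = limit; i < vals.size(); i++) {
            vals[i] = 0.0f;
        }
        for (size_t i = 0; i < limit; i++) {
            vals[i] /= sum;
        }
    }
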
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert        = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        const int expert  = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val         = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
     }
 
-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i]           = expf(val - max_val);
-        tmp += wt[i];
-    }
+    //at this point, each thread holds either a portion of the softmax distribution
+    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    //the selected expert's weight as -inf to exclude it from the next iteration
 
-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;
 
-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
     }
 
-    //at this point, each thread holds a portion of softmax,
-    //we do the argmax reduce over n_expert_used, each time marking
-    //the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+    }
+
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,7 +159,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
 
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
@@ -138,6 +167,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
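
The launch geometry is unchanged: each block carries rows_per_block = 4 warps and each warp processes one token row. A compile-time sketch of the arithmetic, using the new n_rows = 22 test case below as an example (not part of the commit):

    // One warp per token row; blocks of rows_per_block = 4 warps each.
    constexpr int rows_per_block = 4;
    constexpr int n_rows         = 22;  // matches the new {8, 22, 1, 1} test case
    constexpr int grid_x         = (n_rows + rows_per_block - 1) / rows_per_block;
    static_assert(grid_x == 6, "22 rows -> 6 blocks of 4 warps");
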
@@ -145,43 +176,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm>
+            topk_moe_cuda<1, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm>
+            topk_moe_cuda<2, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm>
+            topk_moe_cuda<4, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm>
+            topk_moe_cuda<8, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm>
+            topk_moe_cuda<16, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm>
+            topk_moe_cuda<32, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm>
+            topk_moe_cuda<64, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm>
+            topk_moe_cuda<128, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm>
+            topk_moe_cuda<256, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm>
+            topk_moe_cuda<512, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
@@ -194,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
                            ggml_tensor *               ids,
-                           const bool                  with_norm) {
+                           const bool                  with_norm,
+                           const bool                  delayed_softmax) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -202,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const int n_experts = logits->ne[0];
     const int n_rows    = logits->ne[1];
 
-    const float * logits_d  = (const float *) logits->src[0]->data;
+    const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
 
@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
     }
 }
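
The two booleans select among three valid instantiations; with_norm combined with delayed_softmax is rejected at compile time by the static_assert in launch_topk_moe_cuda. A hedged usage sketch, assuming logits, weights, and ids are already set up as in the call sites above:

    // Valid combinations for ggml_cuda_op_topk_moe:
    //   with_norm = true,  delayed_softmax = false  -> softmax, top-k, renormalize
    //   with_norm = false, delayed_softmax = false  -> softmax, top-k
    //   with_norm = false, delayed_softmax = true   -> top-k on raw logits, then softmax (gpt-oss)
    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids,
                          /*with_norm=*/false, /*delayed_softmax=*/true);
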
 
@@ -246,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
@@ -254,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
     if (norm) {
         return norm_ops;
     }
+
     return no_norm_ops;
 }
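
Summarizing the three op lists and the node offsets the fusion code above pairs with each (inferred from the call sites; annotation only, not text from the commit):

    // norm_ops            (9): SOFT_MAX RESHAPE ARGSORT VIEW GET_ROWS RESHAPE SUM_ROWS DIV RESHAPE
    //                          ids = +3 (VIEW), weights = +8 (final RESHAPE)
    // no_norm_ops         (5): SOFT_MAX RESHAPE ARGSORT VIEW GET_ROWS
    //                          ids = +3 (VIEW), weights = +4 (GET_ROWS)
    // delayed_softmax_ops (6): ARGSORT VIEW GET_ROWS RESHAPE SOFT_MAX RESHAPE
    //                          ids = +1 (VIEW), weights = +5 (final RESHAPE)
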
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 6613fb56507eaedd8a4c8b6eaf96016006e6a9be..cc2fbfe9e6649e73b946a7964c3779d84ea3dace 100644
@@ -6,9 +6,10 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
-                           ggml_tensor *               top_k,
-                           const bool                  with_norm);
+                           ggml_tensor *               ids,
+                           const bool                  with_norm,
+                           const bool                  delayed_softmax = false);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
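
Since delayed_softmax is defaulted in both declarations, pre-existing call sites compile unchanged; for example (illustrative):

    std::initializer_list<enum ggml_op> a = ggml_cuda_topk_moe_ops(/*with_norm=*/true);
    std::initializer_list<enum ggml_op> b = ggml_cuda_topk_moe_ops(/*with_norm=*/true,
                                                                   /*delayed_softmax=*/false);
    // a and b name the same static op list
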
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index fa98db2982ce77112c7930142afc1f804af3a87d..991c62597962d4c520cd0b1914c21badc0374e13 100644
@@ -4669,14 +4669,21 @@ struct test_topk_moe: public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
-    : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
+    const bool                   delayed_softmax;
+
+    test_topk_moe(std::array<int64_t, 4> ne              = { 10, 5, 1, 1 },
+                  int                    n_expert_used   = 1,
+                  bool                   with_norm       = false,
+                  bool                   delayed_softmax = false) :
+        ne(ne),
+        n_expert_used(n_expert_used),
+        with_norm(with_norm),
+        delayed_softmax(delayed_softmax) {
         GGML_ASSERT(n_expert_used <= ne[0]);
+        GGML_ASSERT(!(with_norm && delayed_softmax));
     }
 
-    std::string vars() override {
-        return VARS_TO_STR3(ne, n_expert_used, with_norm);
-    }
+    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -4690,11 +4697,17 @@ struct test_topk_moe: public test_case {
         const int n_tokens = ne[1];
 
         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-        ggml_tensor * probs  = ggml_soft_max(ctx, logits);
+        ggml_tensor * probs            = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
         ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
 
         ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
 
+        if (delayed_softmax) {
+            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+            out = ggml_soft_max(ctx, out);  // [n_expert_used, n_tokens]
+            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        }
+
         if (with_norm) {
             out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
             ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
@@ -6975,6 +6988,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
     }
 
+    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
+    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
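
The two new cases can be run in isolation with the backend-ops binary, e.g. test-backend-ops test -b CUDA0 -o TOPK_MOE; the binary name, flags, and op filter follow the usual test-harness conventions and are an assumption here, not something this commit adds.
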
+
 #if 0
    // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));