]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: add a fused top-K MoE kernel (#16130)
authorAman Gupta <redacted>
Thu, 25 Sep 2025 14:35:05 +0000 (22:35 +0800)
committerGitHub <redacted>
Thu, 25 Sep 2025 14:35:05 +0000 (16:35 +0200)
* CUDA: add a fused top-K MoE kernel

This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory

It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models

* Refactor into ggml_cuda_should_use_topk_moe

* Review: Use better coalescing pattern, use WARP_SIZE, store logits into registers before

* Review: format + micro-optimizations

* Fix bug: fix tie breakers

* Add optional norm + clean-up code

* Use smem for final write

* Add bounds check

* Use better memory pattern for writeback

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/topk-moe.cu [new file with mode: 0644]
ggml/src/ggml-cuda/topk-moe.cuh [new file with mode: 0644]
src/llama-graph.cpp
tests/test-backend-ops.cpp

index 4d85c5dc083d1d056e50594458813d8f1bfad999..8c8647b147369d27f84154e2d505c320cbfa39f5 100644 (file)
@@ -45,6 +45,7 @@
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
@@ -2825,6 +2826,44 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif
 
+    //TODO: remove special case once ggml_can_fuse can handle empty nodes
+    std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+
+    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
+
+        if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
+            return false;
+        }
+
+        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
+        }
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
+
+        if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
+            return false;
+        }
+
+        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2915,6 +2954,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
 
+                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
+                        ggml_tensor * weights = cgraph->nodes[i+8];
+                        ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                        i += 8;
+                        continue;
+                    }
+
+                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
+                        ggml_tensor * weights = cgraph->nodes[i+4];
+                        ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                        ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                        i += 4;
+                        continue;
+                    }
+
                     if (node->op == GGML_OP_ADD) {
                         int n_fuse = 0;
                         ggml_op ops[8];
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
new file mode 100644 (file)
index 0000000..039f284
--- /dev/null
@@ -0,0 +1,259 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+#include <initializer_list>
+
+/*
+    This kernel does the following:
+    1. softmax over the logits per token [n_experts, n_tokens]
+    2. argmax reduce over the top-k (n_experts_used) logits
+    3. write weights + ids to global memory
+    4. optionally normalize the weights
+
+    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
+*/
+template <size_t n_experts, bool with_norm>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float *       weights,
+                                                                  int32_t *     ids,
+                                                                  const int     n_rows,
+                                                                  const int     n_expert_used) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits += n_experts * row;
+    weights += n_expert_used * row;
+    ids += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+    float logits_r[experts_per_thread];
+
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    //at this point, each thread holds a portion of softmax,
+    //we do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    extern __shared__ float data_topk_shared[];
+    float *                 wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val    = wt[0];
+        int   max_expert = threadIdx.x;
+
+#pragma unroll
+        for (int i = 1; i < experts_per_thread; i++) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                max_val    = wt[i];
+                max_expert = expert;
+            }
+        }
+
+#pragma unroll
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+            if (val > max_val || (val == max_val && expert < max_expert)) {
+                max_val    = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            wt_shared_ptr[k] = max_val;
+            ids[k]           = max_expert;
+            if constexpr (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
+    if constexpr (with_norm) {
+        wt_sum              = warp_reduce_sum(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
+            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+        }
+    }
+
+    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
+        weights[i] = wt_shared_ptr[i];
+    }
+}
+
+template <bool with_norm>
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+                                 const float *               logits,
+                                 float *                     weights,
+                                 int32_t *                   ids,
+                                 const int                   n_rows,
+                                 const int                   n_expert,
+                                 const int                   n_expert_used) {
+    const int    rows_per_block = 4;
+    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+    dim3         block_dims(WARP_SIZE, rows_per_block, 1);
+    cudaStream_t stream = ctx.stream();
+
+    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
+
+    switch (n_expert) {
+        case 1:
+            topk_moe_cuda<1, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 2:
+            topk_moe_cuda<2, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 4:
+            topk_moe_cuda<4, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 8:
+            topk_moe_cuda<8, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 16:
+            topk_moe_cuda<16, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 32:
+            topk_moe_cuda<32, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 64:
+            topk_moe_cuda<64, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 128:
+            topk_moe_cuda<128, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 256:
+            topk_moe_cuda<256, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 512:
+            topk_moe_cuda<512, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        default:
+            GGML_ASSERT(false && "fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor *         logits,
+                           ggml_tensor *               weights,
+                           ggml_tensor *               ids,
+                           const bool                  with_norm) {
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
+    const float * logits_d  = (const float *) logits->src[0]->data;
+    float *       weights_d = (float *) weights->data;
+    int32_t *     ids_d     = (int32_t *) ids->data;
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int n_expert_used = weights->ne[1];
+
+    if (with_norm) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    } else {
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    }
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+        return false;
+    }
+
+    return true;
+}
+
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
+                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+
+    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
+
+    if (norm) {
+        return norm_ops;
+    }
+    return no_norm_ops;
+}
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
new file mode 100644 (file)
index 0000000..6613fb5
--- /dev/null
@@ -0,0 +1,14 @@
+#include "common.cuh"
+#include "ggml.h"
+
+#include <initializer_list>
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor *         logits,
+                           ggml_tensor *               weights,
+                           ggml_tensor *               top_k,
+                           const bool                  with_norm);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
index d4faa2a63935045f1fb7b621222df33eba8c1a03..ad55838b1a66b64ba065ad157c0acca33ae49c79 100644 (file)
@@ -932,6 +932,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+
     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
@@ -955,6 +956,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(weights, "ffn_moe_weights_scaled", il);
     }
 
+    //call early so that topk-moe can be used
+    ggml_build_forward_expand(gf, weights);
+
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
index fdc90ca0697dc72a70dbee447179dee53a9932ec..8918452cb68d187e4566a9ee792b5253f0e46007 100644 (file)
@@ -4418,6 +4418,49 @@ struct test_argsort : public test_case {
     }
 };
 
+struct test_topk_moe: public test_case {
+    const std::array<int64_t, 4> ne;
+    const int n_expert_used;
+    const bool with_norm;
+    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
+    : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
+        GGML_ASSERT(n_expert_used <= ne[0]);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR3(ne, n_expert_used, with_norm);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "TOPK_MOE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int n_expert = ne[0];
+        const int n_tokens = ne[1];
+
+        ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+        ggml_tensor * probs  = ggml_soft_max(ctx, logits);
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+
+        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+
+        if (with_norm) {
+            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+
+            out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
+            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        }
+
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_SUM
 struct test_sum : public test_case {
     const ggml_type type;
@@ -6588,6 +6631,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
     test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
 
+    for (bool with_norm : {false, true}) {
+        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
+        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
+    }
+
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));