]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: add expert reduce kernel (#16857)
authorAman Gupta <redacted>
Fri, 31 Oct 2025 12:05:07 +0000 (20:05 +0800)
committerGitHub <redacted>
Fri, 31 Oct 2025 12:05:07 +0000 (20:05 +0800)
* CUDA: add expert reduce kernel

* contigous checks, better formatting, use std::vector instead of array

* use vector empty instead of size

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/moe-expert-reduce.cu [new file with mode: 0644]
ggml/src/ggml-cuda/moe-expert-reduce.cuh [new file with mode: 0644]
tests/test-backend-ops.cpp

index fcff5d7cdc1f5cef9540b111a32afcddffbdd76f..61a8f1df87de1872c4eb825c6b07c96f5cc51ae4 100644 (file)
@@ -27,6 +27,7 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/moe-expert-reduce.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
 #include "ggml-cuda/opt-step-sgd.cuh"
@@ -3169,6 +3170,31 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         continue;
                     }
 
+                    if (node->op == GGML_OP_MUL) {
+                        int current_node = i + 1;
+                        int num_views    = 0;
+                        int num_adds     = 0;
+                        while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
+                            num_views++;
+                            current_node++;
+                        }
+
+                        while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
+                                num_adds < num_views - 1) {
+                            num_adds++;
+                            current_node++;
+                        }
+
+                        if (num_adds == num_views - 1 && num_views > 0) {
+                            ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
+                            if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
+                                ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
+                                i += num_views + num_adds;
+                                continue;
+                            }
+                        }
+                    }
+
                     if (node->op == GGML_OP_ADD) {
                         int n_fuse = 0;
                         ggml_op ops[8];
diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cu b/ggml/src/ggml-cuda/moe-expert-reduce.cu
new file mode 100644 (file)
index 0000000..a97c5d5
--- /dev/null
@@ -0,0 +1,168 @@
+#include "moe-expert-reduce.cuh"
+
+// This kernel is a fusion of the expert weight reduce, common in MoE models
+
+template <int n_expert_used_template>
+__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
+                                       const float * __restrict__ weights,
+                                       float * __restrict__ dst,
+                                       const int n_expert_used,
+                                       const int n_cols) {
+    const int row = blockIdx.x;
+    const int col = blockIdx.y * blockDim.x + threadIdx.x;
+    if (col >= n_cols) {
+        return;
+    }
+
+    experts += row * n_cols * n_expert_used;
+    weights += row * n_expert_used;
+    dst += row * n_cols;
+
+    float acc = 0.f;
+    if constexpr (n_expert_used_template == 0) {
+        for (int expert = 0; expert < n_expert_used; ++expert) {
+            ggml_cuda_mad(acc, experts[col], weights[expert]);
+            experts += n_cols;
+        }
+        dst[col] = acc;
+    } else {
+#pragma unroll
+        for (int i = 0; i < n_expert_used_template; ++i) {
+            ggml_cuda_mad(acc, experts[col], weights[i]);
+            experts += n_cols;
+        }
+        dst[col] = acc;
+    }
+}
+
+static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
+                                     const float *               experts,
+                                     const float *               weights,
+                                     float *                     dst,
+                                     const int                   n_expert_used,
+                                     const int                   n_cols,
+                                     const int                   n_rows) {
+    const int block_size = 32;
+
+    const int n_blocks_x = n_rows;
+    const int n_blocks_y = (n_cols + block_size - 1) / block_size;
+
+    dim3 block_dims(block_size);
+    dim3 grid_dims(n_blocks_x, n_blocks_y);
+
+    cudaStream_t stream = ctx.stream();
+    switch (n_expert_used) {
+        case 1:
+            moe_expert_reduce_cuda<1>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 2:
+            moe_expert_reduce_cuda<2>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 4:
+            moe_expert_reduce_cuda<4>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 6:
+            moe_expert_reduce_cuda<6>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 8:
+            moe_expert_reduce_cuda<8>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 16:
+            moe_expert_reduce_cuda<16>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 32:
+            moe_expert_reduce_cuda<32>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 64:
+            moe_expert_reduce_cuda<64>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        case 128:
+            moe_expert_reduce_cuda<128>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+        default:
+            moe_expert_reduce_cuda<0>
+                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
+            break;
+    }
+}
+
+bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
+    const ggml_tensor * mul = cgraph->nodes[start_index];
+
+    if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) {
+        return false;
+    }
+
+    int    current_node   = start_index + 1;
+    size_t current_offset = 0;
+
+    std::vector<const ggml_tensor *> view_nodes;
+    //check if all are views of the expert in increasing order
+    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
+        const ggml_tensor * node = cgraph->nodes[current_node];
+        if (node->view_src != mul) {
+            return false;
+        }
+        if (node->view_offs < current_offset) {
+            return false;
+        }
+        current_offset = node->view_offs;
+        current_node++;
+        view_nodes.push_back(node);
+    }
+
+    //check if all the adds are in increasing order
+    const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0];
+    int                 num_adds     = 0;
+    int                 num_views    = view_nodes.size();
+    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
+        const ggml_tensor * add_node = cgraph->nodes[current_node];
+
+        bool is_first_op_ok  = num_views > num_adds ? add_node->src[0] == prev_add_src : false;
+        bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false;
+
+        if (!is_first_op_ok || !is_second_op_ok) {
+            return false;
+        }
+        prev_add_src = add_node;
+
+        num_adds++;
+        current_node++;
+    }
+
+    if (num_views != num_adds + 1) {
+        return false;
+    }
+
+    return true;
+}
+
+void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
+                                    const ggml_tensor *         experts,
+                                    const ggml_tensor *         weights,
+                                    ggml_tensor *               dst) {
+    const int n_rows        = experts->ne[2];
+    const int n_expert_used = experts->ne[1];
+    const int n_cols        = experts->ne[0];
+
+    GGML_ASSERT(experts->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(experts));
+    GGML_ASSERT(ggml_is_contiguous(weights));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const float * experts_d = (const float *) experts->data;
+    const float * weights_d = (const float *) weights->data;
+    float *       dst_d     = (float *) dst->data;
+
+    launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows);
+}
diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cuh b/ggml/src/ggml-cuda/moe-expert-reduce.cuh
new file mode 100644 (file)
index 0000000..cafc50e
--- /dev/null
@@ -0,0 +1,11 @@
+#include "common.cuh"
+#include "ggml.h"
+
+#include <initializer_list>
+
+void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
+                                    const ggml_tensor *         experts,
+                                    const ggml_tensor *         weights,
+                                    ggml_tensor *               dst);
+
+bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);
index fa12c06ccddf8adee1e615ba1df635a5c611140f..04fa1b62d3b4d19a5d0779239ffa37f74d4cf509 100644 (file)
@@ -4807,6 +4807,60 @@ struct test_topk_moe: public test_case {
     }
 };
 
+struct test_moe_expert_reduce : public test_case {
+    const int64_t n_embd;
+    const int64_t n_tokens;
+    const int64_t n_expert_used;
+
+    test_moe_expert_reduce(int64_t n_embd = 64, int64_t n_tokens = 5, int64_t n_expert_used = 4)
+        : n_embd(n_embd), n_tokens(n_tokens), n_expert_used(n_expert_used) {
+        GGML_ASSERT(n_expert_used > 1);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR3(n_embd, n_tokens, n_expert_used);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "MOE_EXPERT_REDUCE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
+        ggml_set_name(experts, "experts");
+
+        ggml_tensor * weights = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, n_expert_used, n_tokens);
+        ggml_set_name(weights, "weights");
+
+        ggml_tensor * weighted = ggml_mul(ctx, experts, weights);
+        ggml_set_name(weighted, "weighted_experts");
+
+        std::vector<ggml_tensor *> expert_views(n_expert_used);
+        for (int64_t i = 0; i < n_expert_used; ++i) {
+            expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]);
+
+            std::string name = "expert_view_" + std::to_string(i);
+            ggml_set_name(expert_views[i], name.c_str());
+            ggml_build_forward_expand(gf, expert_views[i]);
+        }
+
+        ggml_tensor * moe_out = expert_views[0];
+        for (int64_t i = 1; i < n_expert_used; ++i) {
+            moe_out = ggml_add(ctx, moe_out, expert_views[i]);
+
+            std::string name = "expert_add_" + std::to_string(i - 1);
+            ggml_set_name(moe_out, name.c_str());
+        }
+
+        ggml_set_name(moe_out, "moe_out");
+
+        return moe_out;
+    }
+};
+
 struct test_mul_mat_vec_fusion : public test_case {
     const ggml_type type;
     const ggml_glu_op glu_op;
@@ -7260,6 +7314,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
     test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
 
+    test_cases.emplace_back(new test_moe_expert_reduce(1024, 5, 4));
+    test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 6));
+    test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 7));
+
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));