git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Revert "CUDA: add expert reduce kernel (ggml/16857)" (llama/17100)
author    Aman Gupta <redacted>
          Sat, 8 Nov 2025 13:05:19 +0000 (21:05 +0800)
committer Georgi Gerganov <redacted>
          Sun, 9 Nov 2025 21:38:03 +0000 (23:38 +0200)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/moe-expert-reduce.cu [deleted file]
ggml/src/ggml-cuda/moe-expert-reduce.cuh [deleted file]

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2d4314fba4fdc97a020c0c1d169d47f43acffd6a..68dc57843e44b63b1f16bbb15939fa582882823a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -27,7 +27,6 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
-#include "ggml-cuda/moe-expert-reduce.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
 #include "ggml-cuda/opt-step-sgd.cuh"
@@ -3197,31 +3196,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         continue;
                     }
 
-                    if (node->op == GGML_OP_MUL) {
-                        int current_node = i + 1;
-                        int num_views    = 0;
-                        int num_adds     = 0;
-                        while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
-                            num_views++;
-                            current_node++;
-                        }
-
-                        while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
-                                num_adds < num_views - 1) {
-                            num_adds++;
-                            current_node++;
-                        }
-
-                        if (num_adds == num_views - 1 && num_views > 0) {
-                            ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
-                            if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
-                                ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
-                                i += num_views + num_adds;
-                                continue;
-                            }
-                        }
-                    }
-
                     if (node->op == GGML_OP_ADD) {
                         int n_fuse = 0;
                         ggml_op ops[8];
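
For context, the hunk above removes the graph fast path: when a GGML_OP_MUL node was followed by a run of GGML_OP_VIEW nodes and exactly num_views - 1 GGML_OP_ADD nodes, the evaluator dispatched the fused kernel and skipped past those nodes. A minimal sketch of the shape of that scan, written over simplified stand-in types (hypothetical, not ggml's real structs):

    #include <vector>

    // Simplified stand-ins for ggml's graph nodes (illustration only).
    enum Op { OP_MUL, OP_VIEW, OP_ADD, OP_OTHER };
    struct Node { Op op; };

    // Count the extra nodes the fused path would consume after a MUL at index i:
    // a run of views, then num_views - 1 adds. Returns 0 if the pattern is absent.
    int match_expert_reduce(const std::vector<Node> & nodes, int i) {
        if (nodes[i].op != OP_MUL) return 0;
        int j = i + 1, num_views = 0, num_adds = 0;
        while (j < (int) nodes.size() && nodes[j].op == OP_VIEW) { num_views++; j++; }
        while (j < (int) nodes.size() && nodes[j].op == OP_ADD && num_adds < num_views - 1) {
            num_adds++; j++;
        }
        return (num_views > 0 && num_adds == num_views - 1) ? num_views + num_adds : 0;
    }
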
diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cu b/ggml/src/ggml-cuda/moe-expert-reduce.cu
deleted file mode 100644
index a97c5d5..0000000
--- a/ggml/src/ggml-cuda/moe-expert-reduce.cu
+++ /dev/null
@@ -1,168 +0,0 @@
-#include "moe-expert-reduce.cuh"
-
-// This kernel is a fusion of the expert weight reduce, common in MoE models
-
-template <int n_expert_used_template>
-__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
-                                       const float * __restrict__ weights,
-                                       float * __restrict__ dst,
-                                       const int n_expert_used,
-                                       const int n_cols) {
-    const int row = blockIdx.x;
-    const int col = blockIdx.y * blockDim.x + threadIdx.x;
-    if (col >= n_cols) {
-        return;
-    }
-
-    experts += row * n_cols * n_expert_used;
-    weights += row * n_expert_used;
-    dst += row * n_cols;
-
-    float acc = 0.f;
-    if constexpr (n_expert_used_template == 0) {
-        for (int expert = 0; expert < n_expert_used; ++expert) {
-            ggml_cuda_mad(acc, experts[col], weights[expert]);
-            experts += n_cols;
-        }
-        dst[col] = acc;
-    } else {
-#pragma unroll
-        for (int i = 0; i < n_expert_used_template; ++i) {
-            ggml_cuda_mad(acc, experts[col], weights[i]);
-            experts += n_cols;
-        }
-        dst[col] = acc;
-    }
-}
-
-static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
-                                     const float *               experts,
-                                     const float *               weights,
-                                     float *                     dst,
-                                     const int                   n_expert_used,
-                                     const int                   n_cols,
-                                     const int                   n_rows) {
-    const int block_size = 32;
-
-    const int n_blocks_x = n_rows;
-    const int n_blocks_y = (n_cols + block_size - 1) / block_size;
-
-    dim3 block_dims(block_size);
-    dim3 grid_dims(n_blocks_x, n_blocks_y);
-
-    cudaStream_t stream = ctx.stream();
-    switch (n_expert_used) {
-        case 1:
-            moe_expert_reduce_cuda<1>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 2:
-            moe_expert_reduce_cuda<2>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 4:
-            moe_expert_reduce_cuda<4>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 6:
-            moe_expert_reduce_cuda<6>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 8:
-            moe_expert_reduce_cuda<8>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 16:
-            moe_expert_reduce_cuda<16>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 32:
-            moe_expert_reduce_cuda<32>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 64:
-            moe_expert_reduce_cuda<64>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        case 128:
-            moe_expert_reduce_cuda<128>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-        default:
-            moe_expert_reduce_cuda<0>
-                <<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
-            break;
-    }
-}
-
-bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
-    const ggml_tensor * mul = cgraph->nodes[start_index];
-
-    if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) {
-        return false;
-    }
-
-    int    current_node   = start_index + 1;
-    size_t current_offset = 0;
-
-    std::vector<const ggml_tensor *> view_nodes;
-    //check if all are views of the expert in increasing order
-    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
-        const ggml_tensor * node = cgraph->nodes[current_node];
-        if (node->view_src != mul) {
-            return false;
-        }
-        if (node->view_offs < current_offset) {
-            return false;
-        }
-        current_offset = node->view_offs;
-        current_node++;
-        view_nodes.push_back(node);
-    }
-
-    //check if all the adds are in increasing order
-    const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0];
-    int                 num_adds     = 0;
-    int                 num_views    = view_nodes.size();
-    while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
-        const ggml_tensor * add_node = cgraph->nodes[current_node];
-
-        bool is_first_op_ok  = num_views > num_adds ? add_node->src[0] == prev_add_src : false;
-        bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false;
-
-        if (!is_first_op_ok || !is_second_op_ok) {
-            return false;
-        }
-        prev_add_src = add_node;
-
-        num_adds++;
-        current_node++;
-    }
-
-    if (num_views != num_adds + 1) {
-        return false;
-    }
-
-    return true;
-}
-
-void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
-                                    const ggml_tensor *         experts,
-                                    const ggml_tensor *         weights,
-                                    ggml_tensor *               dst) {
-    const int n_rows        = experts->ne[2];
-    const int n_expert_used = experts->ne[1];
-    const int n_cols        = experts->ne[0];
-
-    GGML_ASSERT(experts->type == GGML_TYPE_F32);
-    GGML_ASSERT(weights->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(experts));
-    GGML_ASSERT(ggml_is_contiguous(weights));
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    const float * experts_d = (const float *) experts->data;
-    const float * weights_d = (const float *) weights->data;
-    float *       dst_d     = (float *) dst->data;
-
-    launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows);
-}
diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cuh b/ggml/src/ggml-cuda/moe-expert-reduce.cuh
deleted file mode 100644
index cafc50e..0000000
--- a/ggml/src/ggml-cuda/moe-expert-reduce.cuh
+++ /dev/null
@@ -1,11 +0,0 @@
-#include "common.cuh"
-#include "ggml.h"
-
-#include <initializer_list>
-
-void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
-                                    const ggml_tensor *         experts,
-                                    const ggml_tensor *         weights,
-                                    ggml_tensor *               dst);
-
-bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);
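
Taken together, a self-contained CUDA sketch of the reverted reduction with the same launch geometry (one grid.x block per row, 32-thread column tiles across grid.y). This is a hypothetical standalone program with no ggml dependencies; error checking is omitted for brevity:

    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    __global__ void expert_reduce(const float * __restrict__ experts,
                                  const float * __restrict__ weights,
                                  float * __restrict__ dst,
                                  int n_expert_used, int n_cols) {
        const int row = blockIdx.x;
        const int col = blockIdx.y * blockDim.x + threadIdx.x;
        if (col >= n_cols) return;

        const float * e = experts + (size_t) row * n_cols * n_expert_used;
        const float * w = weights + (size_t) row * n_expert_used;

        float acc = 0.0f;
        for (int x = 0; x < n_expert_used; ++x) {
            acc += w[x] * e[(size_t) x * n_cols + col];
        }
        dst[(size_t) row * n_cols + col] = acc;
    }

    int main() {
        const int n_rows = 2, n_expert_used = 4, n_cols = 8;
        std::vector<float> h_e(n_rows * n_expert_used * n_cols, 1.0f); // expert outputs
        std::vector<float> h_w(n_rows * n_expert_used, 0.25f);         // routing weights
        std::vector<float> h_d(n_rows * n_cols);

        float *d_e, *d_w, *d_d;
        cudaMalloc(&d_e, h_e.size() * sizeof(float));
        cudaMalloc(&d_w, h_w.size() * sizeof(float));
        cudaMalloc(&d_d, h_d.size() * sizeof(float));
        cudaMemcpy(d_e, h_e.data(), h_e.size() * sizeof(float), cudaMemcpyHostToDevice);
        cudaMemcpy(d_w, h_w.data(), h_w.size() * sizeof(float), cudaMemcpyHostToDevice);

        const int block_size = 32;
        dim3 block(block_size);
        dim3 grid(n_rows, (n_cols + block_size - 1) / block_size);
        expert_reduce<<<grid, block>>>(d_e, d_w, d_d, n_expert_used, n_cols);

        cudaMemcpy(h_d.data(), d_d, h_d.size() * sizeof(float), cudaMemcpyDeviceToHost);
        printf("dst[0] = %.3f (expected 1.000 = 4 experts * 1.0 * 0.25)\n", h_d[0]);

        cudaFree(d_e); cudaFree(d_w); cudaFree(d_d);
        return 0;
    }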