]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: support for weight clamp in top-k norm (#16702)
authorAman Gupta <redacted>
Mon, 27 Oct 2025 01:06:16 +0000 (09:06 +0800)
committerGitHub <redacted>
Mon, 27 Oct 2025 01:06:16 +0000 (09:06 +0800)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/topk-moe.cu
ggml/src/ggml-cuda/topk-moe.cuh
tests/test-backend-ops.cpp

index 6b688bfecdedd1bdde8680ce4ecb3adf0db82374..94ab1ec0f5a908bd5d1ae7469eb2e09f8f600345 100644 (file)
@@ -2976,7 +2976,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 9];
 
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
@@ -2986,7 +2986,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
         }
@@ -3125,17 +3125,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 if (!disable_fusion) {
 
                     if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                        ggml_tensor * weights = cgraph->nodes[i+8];
-                        ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                        ggml_tensor * weights          = cgraph->nodes[i + 9];
+                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
+                        ggml_tensor * clamp            = cgraph->nodes[i + 7];
                         ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                              /*delayed softmax*/ false);
-                        i += 8;
+                                              /*delayed softmax*/ false, clamp);
+                        i += 9;
                         continue;
                     }
 
                     if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                        ggml_tensor * weights = cgraph->nodes[i+4];
-                        ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                        ggml_tensor * weights          = cgraph->nodes[i + 4];
+                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
                         ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
                                               /*delayed softmax*/ false);
                         i += 4;
index e28c810ac5df7b7fb4c7e89b10e7de4014795cf5..572379fcbf0e81ba031a8c6fe92dba6555ba850b 100644 (file)
@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "topk-moe.cuh"
 
+#include <cmath>
 #include <initializer_list>
 
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                                   float *       weights,
                                                                   int32_t *     ids,
                                                                   const int     n_rows,
-                                                                  const int     n_expert_used) {
+                                                                  const int     n_expert_used,
+                                                                  const float   clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     if constexpr (with_norm) {
         wt_sum              = warp_reduce_sum(wt_sum);
+        wt_sum              = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = 0; i < experts_per_thread; i++) {
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }
 
 template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  int32_t *                   ids,
                                  const int                   n_rows,
                                  const int                   n_expert,
-                                 const int                   n_expert_used) {
+                                 const int                   n_expert_used,
+                                 const float                 clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor *               weights,
                            ggml_tensor *               ids,
                            const bool                  with_norm,
-                           const bool                  delayed_softmax) {
+                           const bool                  delayed_softmax,
+                           ggml_tensor *               clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     const int n_expert_used = weights->ne[1];
 
+    float clamp_val = -INFINITY;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
+        GGML_ASSERT(clamp == nullptr);
         if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }
 
+    if (clamp) {
+        if (clamp->op != GGML_OP_CLAMP) {
+            return false;
+        }
+        float max_val = ggml_get_op_params_f32(clamp, 1);
+
+        if (max_val != INFINITY) {
+            return false;
+        }
+    }
+
+
     return true;
 }
 
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
+                                                            GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
index cc2fbfe9e6649e73b946a7964c3779d84ea3dace..2eff408b03058b82a2b27f8a95de98aad6168426 100644 (file)
@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor *               weights,
                            ggml_tensor *               ids,
                            const bool                  with_norm,
-                           const bool                  delayed_softmax = false);
+                           const bool                  delayed_softmax = false,
+                           ggml_tensor *               weight_clamp    = nullptr);
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
 
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
index 0ad73e944edaccde534a5299d77fcdd27f7784cb..60157037b29e4121f679940c7005bb0f5d7f3083 100644 (file)
@@ -4712,6 +4712,7 @@ struct test_topk_moe: public test_case {
             out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
             ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
 
+            weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
             out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
             out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
         }