]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CANN: add operator fusion support for ADD + RMS_NORM (llama/17512)
authorChenguang Li <redacted>
Mon, 5 Jan 2026 07:38:18 +0000 (15:38 +0800)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
This commit implements operator fusion for ADD + RMS_NORM operations
in the CANN backend to reduce memory access overhead and improve
performance. The fusion is controlled by the GGML_CANN_OPERATOR_FUSION
environment variable (default: false).

Changes:
- Implement ggml_cann_op_add_rms_norm_fused() using ACLNN AddRmsNorm
- Add ggml_cann_can_fuse() to check fusion eligibility
- Integrate fusion logic into computation graph evaluation
- Add test cases for ADD + RMS_NORM fusion
- Update documentation with new environment variable

The fusion combines ADD and RMS_NORM into a single kernel call,
which is more efficient than executing them separately.

ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/aclnn_ops.h
ggml/src/ggml-cann/ggml-cann.cpp

index 2180a06fd00123fcafe82ea5515cc428fa6ec7e9..50b6bd00e4cfe6ea02701de8b539d6034f7d8b6d 100644 (file)
@@ -26,6 +26,7 @@
 #include "ggml.h"
 
 #include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_add_rms_norm.h>
 #include <aclnnop/aclnn_addcdiv.h>
 #include <aclnnop/aclnn_argmax.h>
 #include <aclnnop/aclnn_avgpool2d.h>
@@ -3805,3 +3806,57 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
                             cubeMathType);
 }
 
+
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
+                                     ggml_tensor *               add_node,
+                                     ggml_tensor *               rms_norm_node) {
+    // Get the two input tensors for ADD operation
+    ggml_tensor * x1 = add_node->src[0];
+    ggml_tensor * x2 = add_node->src[1];
+
+    // Create ACL tensors for the two ADD inputs
+    acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1);
+    acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);
+
+    // Get epsilon parameter from rms_norm_tensor
+    float eps;
+    memcpy(&eps, rms_norm_node->op_params, sizeof(float));
+
+    // Build gamma tensor (RMS normalization scaling factor)
+    // Gamma should match the normalized dimensions (last dimension of x1)
+    size_t acl_gamma_nb[GGML_MAX_DIMS];
+    acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];
+    }
+    acl_tensor_ptr acl_gamma =
+        get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,
+                             acl_gamma_nb, rms_norm_node->type,
+                             1,    // dims - only the last dimension
+                             1.0f  // value
+        );
+
+    // Build rstdOut tensor (output for normalized standard deviation)
+    // Shape should be the dimensions that are NOT normalized
+    int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };
+    size_t  acl_rstd_nb[GGML_MAX_DIMS - 1];
+    acl_rstd_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
+    }
+    acl_tensor_ptr acl_rstd =
+        get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,
+                             acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,
+                             0.0f  // value
+        );
+
+    acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);
+
+    // Create yOut tensor (final output after RMS normalization)
+    acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);
+
+    // Call fused ADD + RMS_NORM operator
+    GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),
+                            eps,  // double type
+                            acl_yout.get(), acl_rstd.get(), acl_xout.get());
+}
index a6ea016c542780c80f65e039e5dfb3e82f28b99e..08ee7b1fbdf8bb149d7d7a764b37378c67474bed 100644 (file)
@@ -935,6 +935,20 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
  */
 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Performs fused ADD + RMS_NORM operation using the CANN backend.
+ *
+ * This function fuses the ADD and RMS_NORM operations into a single kernel call
+ * for better performance. It first adds two input tensors (x1 + x2), then applies
+ * RMS normalization to the result.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The ADD operation node, contains the two input tensors to be added.
+ * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
+ *                        and epsilon parameter.
+ */
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
+
 /**
  * @brief   Check whether a tensor is a weight tensor for matrix multiplication.
  *
index ef23ec78da69340e9baef1cc9ece73626bf3ed7e..7f6214e4fb40a23f745830f3a96db4ab46423bb5 100644 (file)
@@ -1888,6 +1888,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
             break;
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cann_ssm_conv(ctx, dst);
             break;
@@ -2077,6 +2078,40 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 
+/**
+ * @brief Check if CANN backend can fuse the specified operation sequence
+ *
+ * This function determines whether an operation sequence starting from the specified node
+ * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
+ * memory access overhead and improve computational efficiency.
+ *
+ * @param cgraph Pointer to the computation graph
+ * @param node_idx Index of the starting node in the computation graph
+ * @param ops Sequence of operation types to check for fusion
+ * @return true if the operations can be fused
+ * @return false if the operations cannot be fused
+ */
+static bool ggml_cann_can_fuse(const struct ggml_cgraph *          cgraph,
+                               int                                 node_idx,
+                               std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    // CANN backend supports fusing ADD + RMS_NORM operations
+    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
+        ggml_tensor * add_node = cgraph->nodes[node_idx];
+        // TODO: support broadcast for ADD + RMS_NORM
+        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
+            add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
+            return false;
+        }
+        return true;
+    }
+
+    return false;
+}
+
 /**
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
  *
@@ -2101,9 +2136,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
 #endif  // USE_ACL_GRAPH
     // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
     // With the use of CANN graphs, the execution will be performed by the graph launch.
+    static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
+
     if (!use_cann_graph || cann_graph_capture_required) {
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
+            if (opt_fusion) {
+                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
+                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
+                    i++;
+                    continue;
+                }
+            }
 
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                 node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {