]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CANN: add operator fusion support for ADD + RMS_NORM (#17512)
authorChenguang Li <redacted>
Mon, 5 Jan 2026 07:38:18 +0000 (15:38 +0800)
committerGitHub <redacted>
Mon, 5 Jan 2026 07:38:18 +0000 (15:38 +0800)
This commit implements operator fusion for ADD + RMS_NORM operations
in the CANN backend to reduce memory access overhead and improve
performance. The fusion is controlled by the GGML_CANN_OPERATOR_FUSION
environment variable (default: false).

Changes:
- Implement ggml_cann_op_add_rms_norm_fused() using ACLNN AddRmsNorm
- Add ggml_cann_can_fuse() to check fusion eligibility
- Integrate fusion logic into computation graph evaluation
- Add test cases for ADD + RMS_NORM fusion
- Update documentation with new environment variable

The fusion combines ADD and RMS_NORM into a single kernel call,
which is more efficient than executing them separately.

docs/backend/CANN.md
ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/aclnn_ops.h
ggml/src/ggml-cann/ggml-cann.cpp
tests/test-backend-ops.cpp

index 37dcfaef9a84d22fe40336cece90afadd0bb3c68..b03c2a122cb3eed31d386143ede8735ec60e86eb 100755 (executable)
@@ -327,3 +327,7 @@ Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. Whe
 ### GGML_CANN_PREFILL_USE_GRAPH
 
 Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.
+
+### GGML_CANN_OPERATOR_FUSION
+
+Enable operator fusion during computation, default is false. This option fuses compatible operators (e.g., ADD + RMS_NORM) to reduce overhead and improve performance.
index 2180a06fd00123fcafe82ea5515cc428fa6ec7e9..50b6bd00e4cfe6ea02701de8b539d6034f7d8b6d 100644 (file)
@@ -26,6 +26,7 @@
 #include "ggml.h"
 
 #include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_add_rms_norm.h>
 #include <aclnnop/aclnn_addcdiv.h>
 #include <aclnnop/aclnn_argmax.h>
 #include <aclnnop/aclnn_avgpool2d.h>
@@ -3805,3 +3806,57 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
                             cubeMathType);
 }
 
+
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
+                                     ggml_tensor *               add_node,
+                                     ggml_tensor *               rms_norm_node) {
+    // Get the two input tensors for ADD operation
+    ggml_tensor * x1 = add_node->src[0];
+    ggml_tensor * x2 = add_node->src[1];
+
+    // Create ACL tensors for the two ADD inputs
+    acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1);
+    acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);
+
+    // Get epsilon parameter from rms_norm_tensor
+    float eps;
+    memcpy(&eps, rms_norm_node->op_params, sizeof(float));
+
+    // Build gamma tensor (RMS normalization scaling factor)
+    // Gamma should match the normalized dimensions (last dimension of x1)
+    size_t acl_gamma_nb[GGML_MAX_DIMS];
+    acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];
+    }
+    acl_tensor_ptr acl_gamma =
+        get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,
+                             acl_gamma_nb, rms_norm_node->type,
+                             1,    // dims - only the last dimension
+                             1.0f  // value
+        );
+
+    // Build rstdOut tensor (output for normalized standard deviation)
+    // Shape should be the dimensions that are NOT normalized
+    int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };
+    size_t  acl_rstd_nb[GGML_MAX_DIMS - 1];
+    acl_rstd_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
+    }
+    acl_tensor_ptr acl_rstd =
+        get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,
+                             acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,
+                             0.0f  // value
+        );
+
+    acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);
+
+    // Create yOut tensor (final output after RMS normalization)
+    acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);
+
+    // Call fused ADD + RMS_NORM operator
+    GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),
+                            eps,  // double type
+                            acl_yout.get(), acl_rstd.get(), acl_xout.get());
+}
index a6ea016c542780c80f65e039e5dfb3e82f28b99e..08ee7b1fbdf8bb149d7d7a764b37378c67474bed 100644 (file)
@@ -935,6 +935,20 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
  */
 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Performs fused ADD + RMS_NORM operation using the CANN backend.
+ *
+ * This function fuses the ADD and RMS_NORM operations into a single kernel call
+ * for better performance. It first adds two input tensors (x1 + x2), then applies
+ * RMS normalization to the result.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The ADD operation node, contains the two input tensors to be added.
+ * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
+ *                        and epsilon parameter.
+ */
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
+
 /**
  * @brief   Check whether a tensor is a weight tensor for matrix multiplication.
  *
index ef23ec78da69340e9baef1cc9ece73626bf3ed7e..7f6214e4fb40a23f745830f3a96db4ab46423bb5 100644 (file)
@@ -1888,6 +1888,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
             break;
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cann_ssm_conv(ctx, dst);
             break;
@@ -2077,6 +2078,40 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 
+/**
+ * @brief Check if CANN backend can fuse the specified operation sequence
+ *
+ * This function determines whether an operation sequence starting from the specified node
+ * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
+ * memory access overhead and improve computational efficiency.
+ *
+ * @param cgraph Pointer to the computation graph
+ * @param node_idx Index of the starting node in the computation graph
+ * @param ops Sequence of operation types to check for fusion
+ * @return true if the operations can be fused
+ * @return false if the operations cannot be fused
+ */
+static bool ggml_cann_can_fuse(const struct ggml_cgraph *          cgraph,
+                               int                                 node_idx,
+                               std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    // CANN backend supports fusing ADD + RMS_NORM operations
+    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
+        ggml_tensor * add_node = cgraph->nodes[node_idx];
+        // TODO: support broadcast for ADD + RMS_NORM
+        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
+            add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
+            return false;
+        }
+        return true;
+    }
+
+    return false;
+}
+
 /**
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
  *
@@ -2101,9 +2136,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
 #endif  // USE_ACL_GRAPH
     // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
     // With the use of CANN graphs, the execution will be performed by the graph launch.
+    static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
+
     if (!use_cann_graph || cann_graph_capture_required) {
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
+            if (opt_fusion) {
+                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
+                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
+                    i++;
+                    continue;
+                }
+            }
 
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                 node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
index 76abfdaf0afbaff2978c7c38726dd6b907280ddb..fa6e80e3fcd78d38e577aa62591a19140a5b81e9 100644 (file)
@@ -3431,6 +3431,65 @@ struct test_rms_norm_mul_add : public test_case {
     }
 };
 
+// GGML_OP_ADD + GGML_OP_RMS_NORM (fused operation)
+struct test_add_rms_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+    const bool broadcast;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "ADD_RMS_NORM";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, eps, broadcast);
+    }
+
+    test_add_rms_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            float eps = 1e-6f, bool broadcast = false)
+        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
+
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+        ggml_set_param(b);
+        ggml_set_name(b, "b");
+
+        // ADD operation followed by RMS_NORM
+        ggml_tensor * add_result = ggml_add(ctx, a, b);
+        ggml_set_name(add_result, "add_result");
+
+        ggml_tensor * out = ggml_rms_norm(ctx, add_result, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 // GGML_OP_SSM_CONV
 struct test_ssm_conv : public test_case {
     const ggml_type type;
@@ -7393,11 +7452,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
         test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
         test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
     }
     for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
         for (bool multi_add : {false, true}) {
             test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
         }
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false));
     }
 
     for (auto multi_add : {false, true}) {