From: Chenguang Li Date: Mon, 5 Jan 2026 07:38:18 +0000 (+0800) Subject: CANN: add operator fusion support for ADD + RMS_NORM (llama/17512) X-Git-Tag: upstream/1.8.3~42 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=1d657effe392fc5b302a5cbada9f5e10332bc6af;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp CANN: add operator fusion support for ADD + RMS_NORM (llama/17512) This commit implements operator fusion for ADD + RMS_NORM operations in the CANN backend to reduce memory access overhead and improve performance. The fusion is controlled by the GGML_CANN_OPERATOR_FUSION environment variable (default: false). Changes: - Implement ggml_cann_op_add_rms_norm_fused() using ACLNN AddRmsNorm - Add ggml_cann_can_fuse() to check fusion eligibility - Integrate fusion logic into computation graph evaluation - Add test cases for ADD + RMS_NORM fusion - Update documentation with new environment variable The fusion combines ADD and RMS_NORM into a single kernel call, which is more efficient than executing them separately. --- diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 2180a06f..50b6bd00 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -26,6 +26,7 @@ #include "ggml.h" #include +#include #include #include #include @@ -3805,3 +3806,57 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { cubeMathType); } + +void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, + ggml_tensor * add_node, + ggml_tensor * rms_norm_node) { + // Get the two input tensors for ADD operation + ggml_tensor * x1 = add_node->src[0]; + ggml_tensor * x2 = add_node->src[1]; + + // Create ACL tensors for the two ADD inputs + acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1); + acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2); + + // Get epsilon parameter from rms_norm_tensor + float eps; + memcpy(&eps, rms_norm_node->op_params, sizeof(float)); + + // Build gamma tensor (RMS normalization scaling factor) + // Gamma should match the normalized dimensions (last dimension of x1) + size_t acl_gamma_nb[GGML_MAX_DIMS]; + acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1]; + } + acl_tensor_ptr acl_gamma = + get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne, + acl_gamma_nb, rms_norm_node->type, + 1, // dims - only the last dimension + 1.0f // value + ); + + // Build rstdOut tensor (output for normalized standard deviation) + // Shape should be the dimensions that are NOT normalized + int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] }; + size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; + acl_rstd_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { + acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1]; + } + acl_tensor_ptr acl_rstd = + get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size, + acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS, + 0.0f // value + ); + + acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node); + + // Create yOut tensor (final output after RMS normalization) + acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node); + + // Call fused ADD + RMS_NORM operator + GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(), + eps, // double type + acl_yout.get(), acl_rstd.get(), acl_xout.get()); +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a6ea016c..08ee7b1f 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -935,6 +935,20 @@ template void register_acl_resources(std::vectorstream())); } +/** + * @brief Check if CANN backend can fuse the specified operation sequence + * + * This function determines whether an operation sequence starting from the specified node + * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce + * memory access overhead and improve computational efficiency. + * + * @param cgraph Pointer to the computation graph + * @param node_idx Index of the starting node in the computation graph + * @param ops Sequence of operation types to check for fusion + * @return true if the operations can be fused + * @return false if the operations cannot be fused + */ +static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph, + int node_idx, + std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + // CANN backend supports fusing ADD + RMS_NORM operations + if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) { + ggml_tensor * add_node = cgraph->nodes[node_idx]; + // TODO: support broadcast for ADD + RMS_NORM + if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] || + add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) { + return false; + } + return true; + } + + return false; +} + /** * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. * @@ -2101,9 +2136,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. + static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or("")); + if (!use_cann_graph || cann_graph_capture_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if (opt_fusion) { + if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { + ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); + i++; + continue; + } + } if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {