int32_t device_;
};
+#ifdef USE_ACL_GRAPH
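+/**
+ * @brief Per-node properties recorded when a CANN graph is captured, used to detect
+ * whether the captured graph can be reused for a subsequent ggml graph.
+ */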
+struct ggml_graph_node_properties {
+ void * node_address;
+ ggml_op node_op;
+ int64_t ne[GGML_MAX_DIMS];
+ size_t nb[GGML_MAX_DIMS];
+ void * src_address[GGML_MAX_SRC];
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+};
+
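+/**
+ * @brief A captured CANN runtime graph together with the per-node properties of the
+ * ggml graph it was captured from. The destructor releases the aclmdlRI handle.
+ */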
+struct ggml_cann_graph {
+ ~ggml_cann_graph() {
+ if (graph != nullptr) {
+ aclmdlRIDestroy(graph);
+ }
+ }
+
+ aclmdlRI graph = nullptr; /**< Captured CANN runtime graph handle (nullptr until the first capture). */
+
+ std::vector<ggml_graph_node_properties> ggml_graph_properties; /**< Node properties recorded at the last capture. */
+};
+#endif // USE_ACL_GRAPH
+
/**
* @brief Context for managing CANN backend operations.
*/
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+#ifdef USE_ACL_GRAPH
+ /// Cached CANN ACL graph used for executing the current ggml computation graph.
+ std::unique_ptr<ggml_cann_graph> cann_graph;
+#endif
cann_task_queue task_queue;
bool async_mode;
+ bool support_set_rows; /**< Whether LLAMA_SET_ROWS is enabled; required for CANN graph execution. */
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
+
+ support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
+ GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
+
+ if (!support_set_rows) {
+ GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
+ "Falling back to eager mode.\n", __func__);
+ }
}
/**
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
+#ifdef USE_ACL_GRAPH
+/**
+ * @brief Populate the internal CANN graph node properties from the ggml computation graph.
+ *
+ * This function copies each node's attributes (data address, operation type, dimensions, strides,
+ * source addresses, and operation parameters) into the cached CANN graph structure so that
+ * subsequent ggml graphs can be compared against the recorded state.
+ *
+ * @param cann_ctx The CANN backend context.
+ * @param cgraph The ggml computational graph.
+ */
+static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
+ ggml_tensor * node = cgraph->nodes[node_idx];
+ ggml_graph_node_properties & prop = cann_ctx->cann_graph->ggml_graph_properties[node_idx];
+
+ prop.node_address = node->data;
+ prop.node_op = node->op;
+
+ for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
+ prop.ne[dim] = node->ne[dim];
+ prop.nb[dim] = node->nb[dim];
+ }
+ for (int src = 0; src < GGML_MAX_SRC; src++) {
+ prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
+ }
+ memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
+ }
+}
+
+/**
+ * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
+ *
+ * This function compares all relevant fields (address, op type, shape, source inputs, op params)
+ * to determine whether the current node matches a previously recorded version.
+ *
+ * @param node The current ggml tensor node.
+ * @param graph_node_properties The stored properties of a CANN graph node.
+ * @return true if all fields match (data-address checks are skipped for GGML_OP_VIEW nodes); false otherwise.
+ */
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
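+ // Data-address checks are skipped for GGML_OP_VIEW nodes: views are not executed during
+ // capture (see the skip list in evaluate_and_capture_cann_graph), so a changed address
+ // does not invalidate the captured graph.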
+ if (node->data != graph_node_properties->node_address &&
+ node->op != GGML_OP_VIEW) {
+ return false;
+ }
+ if (node->op != graph_node_properties->node_op) {
+ return false;
+ }
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (node->ne[i] != graph_node_properties->ne[i]) {
+ return false;
+ }
+ if (node->nb[i] != graph_node_properties->nb[i]) {
+ return false;
+ }
+ }
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (node->src[i] &&
+ node->src[i]->data != graph_node_properties->src_address[i] &&
+ node->op != GGML_OP_VIEW
+ ) {
+ return false;
+ }
+ }
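+ // GGML_OP_SCALE passes its scalar arguments through op_params, so those must match as well.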
+ if (node->op == GGML_OP_SCALE &&
+ memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+ return false;
+ }
+ return true;
+}
+
+/**
+ * @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
+ *
+ * This checks whether the number or properties of ggml graph nodes have changed
+ * compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
+ *
+ * @param cann_ctx The CANN backend context.
+ * @param cgraph The current ggml computation graph.
+ * @return true if an update is required; false otherwise.
+ */
+static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
+ // The number of nodes is different, so the graph needs to be reconstructed.
+ if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+ cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+ return true;
+ }
+
+ // The number of nodes is the same; iterate over each node to check whether they match.
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ bool has_matching_properties = ggml_graph_node_has_matching_properties(
+ cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
+ if (!has_matching_properties) {
+ return true;
+ }
+ }
+ return false;
+}
+#endif // USE_ACL_GRAPH
+
+/**
+ * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
+ *
+ * If CANN graph execution is enabled and a (re)capture is required, this function begins stream
+ * capture, issues the graph nodes op-by-op so they are recorded, ends capture, and stores the
+ * resulting CANN graph. The captured (or previously cached) graph is then launched asynchronously.
+ *
+ * If CANN graph execution is disabled, it falls back to op-by-op execution using the CANN compute
+ * kernel dispatcher.
+ *
+ * @param cann_ctx The CANN backend context.
+ * @param cgraph The ggml computation graph.
+ * @param use_cann_graph Whether to use CANN graph execution.
+ * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
+ */
+static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
+ bool & use_cann_graph, bool & cann_graph_update_required) {
+#ifdef USE_ACL_GRAPH
+ if (use_cann_graph && cann_graph_update_required) {
+ if (cann_ctx->cann_graph->graph != nullptr) {
+ ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
+ cann_ctx->cann_graph->graph = nullptr;
+ }
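+ // Begin stream capture: node ops issued on this stream below are recorded into a new CANN graph.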
+ ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
+ }
+#endif // USE_ACL_GRAPH
+
+ // Issue the node ops directly only when CANN graphs are disabled or the graph is being (re)captured.
+ // When a cached CANN graph is reused, execution happens via the graph launch below.
+ if (!use_cann_graph || cann_graph_update_required) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
+ node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ continue;
+ }
+
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
+ if (!ok) {
+ GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
+ }
+ }
+
+#ifdef USE_ACL_GRAPH
+ if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
+ ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
+ }
+
+ if (use_cann_graph) {
+ // Launch the captured (or previously cached) CANN graph on the backend stream.
+ ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
+ }
+#endif // USE_ACL_GRAPH
+}
+
+
/**
* @brief Computes a computational graph using a CANN backend.
*
ggml_backend_t backend, ggml_cgraph* cgraph) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
-
ggml_cann_set_device(cann_ctx->device);
- //release temp buffer create by set tensor.
release_nz_workspace();
+#ifdef USE_ACL_GRAPH
+ bool use_cann_graph = true;
+ bool cann_graph_update_required = false;
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_tensor* node = cgraph->nodes[i];
+ // CANN graph execution currently requires LLAMA_SET_ROWS; otherwise fall back to eager execution.
+ if (!cann_ctx->support_set_rows) {
+ use_cann_graph = false;
+ }
- if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
- continue;
+ if (use_cann_graph) {
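+ // Lazily create the cached graph wrapper on first use; an initial capture will then be required.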
+ if (cann_ctx->cann_graph == nullptr) {
+ cann_ctx->cann_graph.reset(new ggml_cann_graph());
+ cann_graph_update_required = true;
}
- bool ok = ggml_cann_compute_forward(*cann_ctx, node);
-
- if (!ok) {
- GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
- node->name, ggml_op_name(node->op));
- }
- GGML_ASSERT(ok);
+ cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
+ set_ggml_graph_node_properties(cann_ctx, cgraph);
}
+#else
+ bool use_cann_graph = false;
+ bool cann_graph_update_required = false;
+#endif // USE_ACL_GRAPH
+
+ evaluate_and_capture_cann_graph(
+ cann_ctx,
+ cgraph,
+ use_cann_graph,
+ cann_graph_update_required
+ );
return GGML_STATUS_SUCCESS;
}
// only support F32 and F16.
return false;
}
-
- if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
- // unsupport dst is not contiguous.
- return false;
- }
-
return true;
} break;
case GGML_OP_CONT: {