]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN : refactor ACL graph cache (llama/17752)
authorWang Weixuan <redacted>
Wed, 24 Dec 2025 09:50:24 +0000 (17:50 +0800)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
Move the graph property checking code into methods of LRU cache.

Signed-off-by: Wang Weixuan <redacted>
src/ggml-cann/common.h
src/ggml-cann/ggml-cann.cpp

index 3a461ef1a72821e92944f938f61e812a85ba0458..e9a21e1b055cc7bbcd1979354aaf3ff4e501451e 100644 (file)
@@ -229,6 +229,60 @@ struct ggml_graph_node_properties {
     // op
     ggml_op node_op;
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+
+    /**
+     * @brief Check if a ggml tensor node matches this property set.
+     *
+     * This function compares all relevant fields (address, op type, shape, source inputs, op params)
+     * to determine whether the current node matches these previously recorded properties.
+     *
+     * @param node The current ggml tensor node.
+     * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
+     */
+    bool has_matching_properties(ggml_tensor * node) {
+        if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
+            return false;
+        }
+
+        if (node->op != this->node_op) {
+            return false;
+        }
+
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            if (node->ne[i] != this->ne[i]) {
+                return false;
+            }
+            if (node->nb[i] != this->nb[i]) {
+                return false;
+            }
+        }
+
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (node->src[i]) {
+                if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
+                    return false;
+                }
+
+                for (int d = 0; d < GGML_MAX_DIMS; d++) {
+                    if (node->src[i]->ne[d] != this->src_ne[i][d]) {
+                        return false;
+                    }
+                    if (node->src[i]->nb[d] != this->src_nb[i][d]) {
+                        return false;
+                    }
+                }
+            } else {
+                if (this->src_address[i] != nullptr) {
+                    return false;
+                }
+            }
+        }
+
+        if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+            return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
+        }
+        return true;
+    }
 };
 
 struct ggml_cann_graph {
@@ -241,6 +295,79 @@ struct ggml_cann_graph {
     aclmdlRI graph = nullptr;
 
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
+
+    /**
+     * @brief Create a new CANN graph from a ggml computation graph.
+     *
+     * This function creates a new ggml_cann_graph object and fills its node properties
+     * (operation type, dimensions, strides, input sources, and operation parameters)
+     * based on the current ggml computation graph.
+     *
+     * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
+     * - node address
+     * - operation type
+     * - shape (ne) and strides (nb)
+     * - source tensor addresses
+     * - operation parameters
+     *
+     * @param cgraph The current ggml computation graph.
+     * @return Pointer to the newly created ggml_cann_graph object.
+     */
+    static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
+        ggml_cann_graph * new_graph = new ggml_cann_graph();
+        new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+
+        for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
+            ggml_tensor * node = cgraph->nodes[node_idx];
+            auto &        prop = new_graph->ggml_graph_properties[node_idx];
+
+            prop.node_address = node->data;
+            prop.node_op      = node->op;
+
+            std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
+            std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
+
+            for (int src = 0; src < GGML_MAX_SRC; ++src) {
+                if (node->src[src]) {
+                    prop.src_address[src] = node->src[src]->data;
+                    std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+                    std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+                } else {
+                    prop.src_address[src] = nullptr;
+                    std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+                    std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+                }
+            }
+
+            memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
+        }
+
+        return new_graph;
+    }
+
+    /**
+     * @brief Check whether this CANN graph matches the given ggml computation graph.
+     *
+     * This function compares the number of nodes and each node's properties
+     * (operation type, dimensions, strides, inputs, and operation parameters)
+     * to determine whether this CANN graph matches the given ggml graph.
+     *
+     * @param cgraph The current ggml computation graph.
+     * @return true if this CANN graph matches the ggml graph; false otherwise.
+     */
+    bool matches_cgraph(ggml_cgraph * cgraph) {
+        if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
+            return false;
+        }
+
+        for (int i = 0; i < cgraph->n_nodes; ++i) {
+            if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
+                return false;
+            }
+        }
+
+        return true;
+    }
 };
 
 /**
@@ -272,15 +399,6 @@ struct ggml_cann_graph_lru_cache {
         cache_list.push_front(new_node);
     }
 
-    /**
-     * @brief Move an existing graph to the front of the cache.
-     * @param node Pointer to the ggml_cann_graph to move.
-     */
-    void move_to_front(ggml_cann_graph * node) {
-        cache_list.remove(node);
-        cache_list.push_front(node);
-    }
-
     /**
      * @brief Clear all graphs from the cache (also frees memory).
      */
@@ -295,6 +413,28 @@ struct ggml_cann_graph_lru_cache {
      * @brief Destructor that clears the cache and frees all cached graphs.
      */
     ~ggml_cann_graph_lru_cache() { clear(); }
+
+    /**
+     * @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
+     *
+     * This function iterates through the cached CANN graphs stored in the LRU cache and
+     * compares them against the given ggml computation graph. If a matching graph is found,
+     * it is promoted to the front of the LRU cache and returned. Otherwise, the function
+     * returns nullptr.
+     *
+     * @param cgraph The current ggml computation graph.
+     * @return true if found; false otherwise.
+     */
+    bool find_and_move_to_front(ggml_cgraph * cgraph) {
+        for (auto & graph_ptr : this->cache_list) {
+            if (graph_ptr->matches_cgraph(cgraph)) {
+                cache_list.remove(graph_ptr);
+                cache_list.push_front(graph_ptr);
+                return true;
+            }
+        }
+        return false;
+    }
 };
 #endif  // USE_ACL_GRAPH
 
index da624c587c26fc7cf4b8d8758cea55218a776b8b..402e86d705d15f2ff5e05d5ae5f1ce8c32012ea8 100644 (file)
@@ -2075,162 +2075,6 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 
-#ifdef USE_ACL_GRAPH
-/**
- * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
- *
- * This function creates a new ggml_cann_graph object and fills its node properties
- * (operation type, dimensions, strides, input sources, and operation parameters)
- * based on the current ggml computation graph.
- *
- * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
- * - node address
- * - operation type
- * - shape (ne) and strides (nb)
- * - source tensor addresses
- * - operation parameters
- *
- * After initialization, the new graph is pushed into the LRU cache owned by the
- * CANN backend context. The cache takes ownership of the graph and manages its
- * lifetime (including deletion upon eviction).
- *
- * @param cann_ctx  The CANN backend context containing the graph cache.
- * @param cgraph    The current ggml computation graph.
- */
-static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
-    // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
-    ggml_cann_graph * new_graph = new ggml_cann_graph();
-    new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-
-    for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
-        ggml_tensor * node = cgraph->nodes[node_idx];
-        auto &        prop = new_graph->ggml_graph_properties[node_idx];
-
-        prop.node_address = node->data;
-        prop.node_op      = node->op;
-
-        std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
-        std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
-
-        for (int src = 0; src < GGML_MAX_SRC; ++src) {
-            if (node->src[src]) {
-                prop.src_address[src] = node->src[src]->data;
-                std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
-                std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
-            } else {
-                prop.src_address[src] = nullptr;
-                std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
-                std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
-            }
-        }
-
-        memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
-    }
-
-    // Insert into the LRU cache (cache takes ownership and will delete it when evicted).
-    cann_ctx->graph_lru_cache.push(new_graph);
-}
-
-/**
- * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
- *
- * This function compares all relevant fields (address, op type, shape, source inputs, op params)
- * to determine whether the current node matches a previously recorded version.
- *
- * @param node                  The current ggml tensor node.
- * @param graph_node_properties The stored properties of a CANN graph node.
- * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
- */
-static bool ggml_graph_node_has_matching_properties(ggml_tensor *                node,
-                                                    ggml_graph_node_properties * graph_node_properties) {
-    if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) {
-        return false;
-    }
-
-    if (node->op != graph_node_properties->node_op) {
-        return false;
-    }
-
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (node->ne[i] != graph_node_properties->ne[i]) {
-            return false;
-        }
-        if (node->nb[i] != graph_node_properties->nb[i]) {
-            return false;
-        }
-    }
-
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i]) {
-            if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) {
-                return false;
-            }
-
-            for (int d = 0; d < GGML_MAX_DIMS; d++) {
-                if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
-                    return false;
-                }
-                if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
-                    return false;
-                }
-            }
-        } else {
-            if (graph_node_properties->src_address[i] != nullptr) {
-                return false;
-            }
-        }
-    }
-
-    if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
-        return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
-    }
-    return true;
-}
-
-/**
- * @brief Check whether there is a cached CANN graph that matches the current ggml graph.
- *
- * This function iterates through the cached CANN graphs stored in the LRU cache and
- * compares them against the given ggml computation graph. A match requires that the
- * number of nodes is the same and that each node’s properties (operation type,
- * dimensions, strides, inputs, and operation parameters) are identical.
- *
- * If a matching graph is found, it is promoted to the front of the LRU cache and the
- * function returns true. Otherwise, the function returns false, indicating that a new
- * CANN graph needs to be captured.
- *
- * @param cann_ctx  The CANN backend context containing the graph cache.
- * @param cgraph    The current ggml computation graph.
- * @return true if a matching cached graph exists; false otherwise.
- */
-static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
-    ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache;
-    for (auto & graph_ptr : lru_cache.cache_list) {
-        // Skip graphs with a different number of nodes.
-        if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
-            continue;
-        }
-
-        // Check if all nodes match.
-        bool all_match = true;
-        for (int i = 0; i < cgraph->n_nodes; ++i) {
-            if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
-                all_match = false;
-                break;
-            }
-        }
-
-        if (all_match) {
-            // update cache_list && renturn graph_ptr
-            lru_cache.move_to_front(graph_ptr);
-            return true;
-        }
-    }
-
-    return false;
-}
-#endif  // USE_ACL_GRAPH
-
 /**
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
  *
@@ -2239,23 +2083,23 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
  *
  * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
  *
- * @param cann_ctx                 The CANN backend context.
- * @param cgraph                   The ggml computation graph.
- * @param use_cann_graph           Whether to use CANN graph execution.
- * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
+ * @param cann_ctx                     The CANN backend context.
+ * @param cgraph                       The ggml computation graph.
+ * @param use_cann_graph               Whether to use CANN graph execution.
+ * @param cann_graph_capture_required  Whether graph capture is needed due to graph changes.
  */
 static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
                                             ggml_cgraph *               cgraph,
-                                            bool &                      use_cann_graph,
-                                            bool &                      cann_graph_update_required) {
+                                            bool                        use_cann_graph,
+                                            bool                        cann_graph_capture_required) {
 #ifdef USE_ACL_GRAPH
-    if (use_cann_graph && cann_graph_update_required) {  // Begin CANN graph capture
+    if (use_cann_graph && cann_graph_capture_required) {  // Begin CANN graph capture
         ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
     }
 #endif  // USE_ACL_GRAPH
     // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
     // With the use of CANN graphs, the execution will be performed by the graph launch.
-    if (!use_cann_graph || cann_graph_update_required) {
+    if (!use_cann_graph || cann_graph_capture_required) {
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
 
@@ -2274,9 +2118,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
 
 #ifdef USE_ACL_GRAPH
     if (use_cann_graph) {
+        GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
         ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
 
-        if (cann_graph_update_required) {  // End CANN graph capture
+        if (cann_graph_capture_required) {  // End CANN graph capture
             ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
         }
 
@@ -2306,7 +2151,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
     // calculate rope cache for fist layer in current device.
     cann_ctx->rope_cache.cached = false;
 
-    bool cann_graph_update_required = false;
+    bool graph_capture_required = false;
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;
 
@@ -2331,16 +2176,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
 
     if (use_cann_graph) {
         // If no matching graph is found, the graph needs to be recaptured.
-        cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
-        if (cann_graph_update_required) {
+        graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
+        if (graph_capture_required) {
             // If no matching graph is found, add a new ACL graph.
-            add_lru_matched_graph_node_properties(cann_ctx, cgraph);
+            ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
+            cann_ctx->graph_lru_cache.push(new_graph);
         }
     }
 #else
     bool use_cann_graph = false;
 #endif  // USE_ACL_GRAPH
-    evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
+    evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
 
     return GGML_STATUS_SUCCESS;
 }