]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CANN: Improve ACL graph matching (llama/16166)
authorChenguang Li <redacted>
Thu, 9 Oct 2025 07:50:25 +0000 (15:50 +0800)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
* CANN: improve ACL graph matching

Record `ne` and `nb` information for src tensors and include them in the
graph matching check. This enhances the robustness of ACL graph matching
by preventing incorrect matches when src tensors share the same data
address but differ in shape or stride.

* CANN: add op_params match

ggml/src/ggml-cann/common.h
ggml/src/ggml-cann/ggml-cann.cpp

index b707b843593c7305eaa5be3c918ff82706687cd2..debbcadc1e4c553baaccbcb0e1acdd8d9270ebb9 100755 (executable)
@@ -341,11 +341,18 @@ private:
 
 #ifdef USE_ACL_GRAPH
 struct ggml_graph_node_properties {
+    // dst tensor
     void * node_address;
-    ggml_op node_op;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
+
+    // src tensor
     void * src_address[GGML_MAX_SRC];
+    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
+    size_t  src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
+
+    // op
+    ggml_op node_op;
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
index b51b554e752e1364c04193840e84f703cb2581e2..ad1adba6b3a8ae05a90e0f7f6f990d6a5ef00050 100755 (executable)
@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
         std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
 
         for (int src = 0; src < GGML_MAX_SRC; ++src) {
-            prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
+            if (node->src[src]) {
+                prop.src_address[src] = node->src[src]->data;
+                std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+                std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+            } else {
+                prop.src_address[src] = nullptr;
+                std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+                std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+            }
         }
 
         memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
  * @param graph_node_properties The stored properties of a CANN graph node.
  * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
  */
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+static bool ggml_graph_node_has_matching_properties(
+        ggml_tensor * node,
+        ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
-           node->op != GGML_OP_VIEW) {
+            node->op != GGML_OP_VIEW) {
         return false;
     }
+
     if (node->op != graph_node_properties->node_op) {
         return false;
     }
+
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != graph_node_properties->ne[i]) {
             return false;
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
             return false;
         }
     }
+
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+        if (node->src[i]) {
+            if (node->src[i]->data != graph_node_properties->src_address[i] &&
+                node->op != GGML_OP_VIEW) {
+                return false;
+            }
+
+            for (int d = 0; d < GGML_MAX_DIMS; d++) {
+                if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
+                    return false;
+                }
+                if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
+                    return false;
+                }
+            }
+        } else {
+            if (graph_node_properties->src_address[i] != nullptr) {
+                return false;
+            }
         }
     }
-    if (node->op == GGML_OP_SCALE &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
-        return false;
+
+    if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+        return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
     }
     return true;
 }