]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : extend ggml_can_fuse to work with non-sequential nodes (llama/16123)
authorGeorgi Gerganov <redacted>
Mon, 22 Sep 2025 08:12:37 +0000 (11:12 +0300)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:09 +0000 (15:18 +0300)
* ggml : extend ggml_can_fuse to work with non-sequential nodes in the graph

* cont : fix wrong bounds check condition

* cont : remove unnecessary overload

ggml/src/ggml-impl.h

index 6e01a42cef6756dcced0d3e486e334f699afbaf6..c2eaea22fecdb51acf4bcb77afa7aa88530e165a 100644 (file)
@@ -583,27 +583,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
     return true;
 }
 
-// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
+// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
 // and are fusable. Nodes are considered fusable according to this function if:
 // - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
 // - all nodes except the last are a src of the following node.
 // - all nodes are the same shape.
 // TODO: Consider allowing GGML_OP_NONE nodes in between
-static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
-    if (node_idx + num_ops > cgraph->n_nodes) {
-        return false;
-    }
-
+static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
     for (int i = 0; i < num_ops; ++i) {
-        struct ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node_idxs[i] >= cgraph->n_nodes) {
+            return false;
+        }
+
+        struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
         if (node->op != ops[i]) {
             return false;
         }
-        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
             return false;
         }
         if (i > 0) {
-            struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
             if (node->src[0] != prev && node->src[1] != prev) {
                 return false;
             }
@@ -615,6 +615,22 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
     return true;
 }
 
+// same as above, for sequential indices starting at node_idx
+static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
+    assert(num_ops < 32);
+
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+    for (int i = 0; i < num_ops; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
+}
+
 #ifdef __cplusplus
 }
 #endif