]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: add ggml_can_fuse_subgraph (#16662)
authorAman Gupta <redacted>
Tue, 21 Oct 2025 08:43:14 +0000 (16:43 +0800)
committerGitHub <redacted>
Tue, 21 Oct 2025 08:43:14 +0000 (16:43 +0800)
* ggml: add ggml_can_fuse_subgraph

* ggml-cuda: use ggml_can_fuse_subgraph for topk-moe

* format

* 1. remove inputs from signature as they are transient nodes
2. add check for views: view_src should be part of the subgraph

* - combine check into one loop
- check all view_src parents
- other minor review comments

* remove redudant if test

* - rename and other minor review comments

* add assert about count < 32

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-impl.h
ggml/src/ggml.c

index 75fd6db14c514566640472ef4627db0e036f51d7..015b37be0708e028a93f6d0a0a0937e0b2c219ec 100644 (file)
@@ -2821,15 +2821,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
     std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
 
-    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
-
-        if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
-            return false;
-        }
-
-        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
-        }
+    if (ops.size() == topk_moe_ops_with_norm.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx+8];
 
@@ -2838,16 +2831,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
-
-        if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
-            return false;
-        }
-
-        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
-        }
-
+    if (ops.size() == topk_moe_ops.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx+4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
index 18f095b8960104075d8a87cc782af5a05d088430..e9201cdc685dcd5efc8aa1b0554acbe0d82a1a3d 100644 (file)
@@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
     return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
 }
 
+GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
+                                         const int *                node_idxs,
+                                         int                        count,
+                                         const enum ggml_op *       ops,
+                                         const int *                outputs,
+                                         int                        num_outputs);
+
+// Returns true if the subgraph formed by {node_idxs} can be fused
+// checks whethers all nodes which are not part of outputs can be elided
+// by checking if their num_uses are confined to the subgraph
+static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
+                                          int                        node_idx,
+                                          int                        count,
+                                          const enum ggml_op *       ops,
+                                          const int *                outputs,
+                                          int                        num_outputs) {
+    GGML_ASSERT(count < 32);
+    if (node_idx + count > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+
+    for (int i = 0; i < count; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
+}
+
 #ifdef __cplusplus
 }
 #endif
@@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
     return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
 }
 
+inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph *          cgraph,
+                                   int                                 start_idx,
+                                   std::initializer_list<enum ggml_op> ops,
+                                   std::initializer_list<int>          outputs = {}) {
+    return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
index 86f1c31afd7a6254e4b9013ca80cd6010ac0e267..9be35c1be8456bf538ddb85e91be6c13fefb78df 100644 (file)
@@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     GGML_LOG_INFO("========================================\n");
 }
 
+static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
+                                      const int *                idxs,
+                                      int                        count,
+                                      const struct ggml_tensor * tensor) {
+    GGML_ASSERT(cgraph && idxs);
+    for (int i = 0; i < count; ++i) {
+        const int node_idx = idxs[i];
+
+        if (node_idx >= cgraph->n_nodes) {
+            return -1;
+        }
+        if (cgraph->nodes[node_idx] == tensor) {
+            return i;
+        }
+    }
+    return -1;
+}
+
+bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
+                                const int *                node_idxs,
+                                int                        count,
+                                const enum ggml_op *       ops,
+                                const int *                outputs,
+                                int                        num_outputs) {
+    GGML_ASSERT(outputs && num_outputs > 0);
+
+    for (int i = 0; i < count; ++i) {
+        if (node_idxs[i] >= cgraph->n_nodes) {
+            return false;
+        }
+
+        const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
+
+        if (node->op != ops[i]) {
+            return false;
+        }
+
+        if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
+            continue;
+        }
+
+        if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+            return false;
+        }
+
+        int subgraph_uses = 0;
+        for (int j = i + 1; j < count; ++j) {
+            const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
+            for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
+                if (other_node->src[src_idx] == node) {
+                    subgraph_uses++;
+                }
+            }
+        }
+
+        if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
+            return false;
+        }
+
+        // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
+        struct ggml_tensor * view_src = node->view_src;
+        while (view_src) {
+            if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
+                return false;
+            }
+            view_src = view_src->view_src;
+        }
+    }
+
+    return true;
+}
+
 // check if node is part of the graph
 static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
     if (cgraph == NULL) {