llama : separate compute buffer reserve from fattn check (llama/15696)
author Diego Devesa <redacted>
Sun, 31 Aug 2025 13:49:03 +0000 (06:49 -0700)
committer Georgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:09 +0000 (12:54 +0300)
Exposes ggml_backend_sched_split_graph() so that a graph can be split without allocating compute buffers, and uses it to split the graph for the automatic Flash Attention check.
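
A minimal usage sketch, not part of the patch below: assuming a scheduler sched and a built graph gf created with the usual ggml APIs, the newly exported function lets a caller assign backends and split the graph without allocating any compute buffers, then inspect the result. The helper name inspect_splits is illustrative.

#include <stdio.h>

#include "ggml.h"
#include "ggml-backend.h"

static void inspect_splits(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    // assign backends to ops and split into per-backend subgraphs; no compute buffers are allocated
    ggml_backend_sched_split_graph(sched, gf);

    // inspect which backend each node was assigned to
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * node    = ggml_graph_node(gf, i);
        ggml_backend_t       backend = ggml_backend_sched_get_tensor_backend(sched, node);
        fprintf(stderr, "%-24s -> %s\n", node->name, backend ? ggml_backend_name(backend) : "(unassigned)");
    }

    // reset the scheduler before it is used for a real reserve or compute
    ggml_backend_sched_reset(sched);
}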

include/ggml-backend.h
src/ggml-backend.cpp

diff --git a/include/ggml-backend.h b/include/ggml-backend.h
index a2977ea2e56d935100b03d5f4e183647a6f1e4d2..4f246f6ccd62922283d0545a851d648c058dce47 100644 (file)
@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
+    // Split graph without allocating it
+    GGML_API void                 ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
diff --git a/src/ggml-backend.cpp b/src/ggml-backend.cpp
index 02375337c4dd68076b68d79941a0dd6d90a446b5..0cdbf180172566cc087b6e6f5a911001a58c5be1 100644 (file)
@@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
 }
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
     sched->n_graph_inputs = 0;
@@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
+    ggml_backend_sched_reset(sched);
+
     ggml_backend_sched_synchronize(sched);
 
     ggml_backend_sched_split_graph(sched, measure_graph);
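
A hedged sketch of the pattern this separation enables; the actual llama.cpp code differs and the helper name graph_fattn_supported is illustrative. For the automatic Flash Attention check, the caller only needs the split (backend assignment), not the compute buffers: split the graph, ask whether the backend device assigned to each FLASH_ATTN_EXT node supports the op, then reset the scheduler. The full ggml_backend_sched_reserve(), which does allocate compute buffers, runs later for the final graph.

static bool graph_fattn_supported(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    // split only: assigns backends, creates no compute buffers
    ggml_backend_sched_split_graph(sched, gf);

    bool ok = true;
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * node = ggml_graph_node(gf, i);
        if (node->op != GGML_OP_FLASH_ATTN_EXT) {
            continue;
        }
        ggml_backend_t     backend = ggml_backend_sched_get_tensor_backend(sched, node);
        ggml_backend_dev_t dev     = backend ? ggml_backend_get_device(backend) : NULL;
        ok = ok && dev != NULL && ggml_backend_dev_supports_op(dev, node);
    }

    // leave the scheduler clean; as the hunk above shows, reserve now resets it as well
    ggml_backend_sched_reset(sched);
    return ok;
}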