]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : separate compute buffer reserve from fattn check (#15696)
authorDiego Devesa <redacted>
Sun, 31 Aug 2025 13:49:03 +0000 (06:49 -0700)
committerGitHub <redacted>
Sun, 31 Aug 2025 13:49:03 +0000 (15:49 +0200)
Exposes ggml_backend_sched_split_graph() to allow splitting the graph without allocating compute buffers and uses it to split the graph for the automatic Flash Attention check.

ggml/include/ggml-backend.h
ggml/src/ggml-backend.cpp
src/llama-context.cpp
src/llama-context.h

index a2977ea2e56d935100b03d5f4e183647a6f1e4d2..4f246f6ccd62922283d0545a851d648c058dce47 100644 (file)
@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
+    // Split graph without allocating it
+    GGML_API void                 ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
index 02375337c4dd68076b68d79941a0dd6d90a446b5..0cdbf180172566cc087b6e6f5a911001a58c5be1 100644 (file)
@@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
 }
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
     sched->n_graph_inputs = 0;
@@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
+    ggml_backend_sched_reset(sched);
+
     ggml_backend_sched_synchronize(sched);
 
     ggml_backend_sched_split_graph(sched, measure_graph);
index ac8453ab741d4a29272007c96e83658b1fdefc06..7e20ee9f8b383c0d06067572240529f04d9b43b2 100644 (file)
@@ -270,19 +270,7 @@ llama_context::llama_context(
         }
     }
 
-    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
-        int n_splits_pp = -1;
-        int n_nodes_pp  = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg  = -1;
-
         llama_memory_context_ptr mctx;
         if (memory) {
             LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -293,6 +281,59 @@ llama_context::llama_context(
         }
 
         cross.v_embd.clear();
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+        int n_splits_pp = -1;
+        int n_nodes_pp  = -1;
+
+        int n_splits_tg = -1;
+        int n_nodes_tg  = -1;
 
         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
@@ -301,48 +342,6 @@ llama_context::llama_context(
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-                const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-                bool fa_device_mismatch = false;
-                for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                    ggml_tensor * n = ggml_graph_node(gf, i);
-                    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                        continue;
-                    }
-                    ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                        ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
-                    // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                    GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                    const int il = std::stoi(n->name + prefix_len);
-                    ggml_backend_dev_t device_kv = model.dev_layer(il);
-                    if (device_fa != device_kv) {
-                        LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                            "is assigned to device %s (usually due to missing support)\n",
-                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                        // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                        fa_device_mismatch = true;
-                        break;
-                    }
-                }
-                if (fa_device_mismatch) {
-                    cparams.flash_attn = false;
-                    LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                    if (ggml_is_quantized(params.type_v)) {
-                        throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                    }
-                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                    if (!gf) {
-                        throw std::runtime_error("failed to allocate compute pp buffers");
-                    }
-                } else {
-                    cparams.flash_attn = true;
-                    LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
-                }
-            }
-
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp  = ggml_graph_n_nodes(gf);
         }
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
index a372bcfbe41aa6883f98b685ed83f729a108c22d..f23aa8ee1368dae4b40218537385b4a4e56851ec 100644 (file)
@@ -196,7 +196,7 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
 private:
     llm_graph_params graph_params(