]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix fattn reserve call n_seqs parameter (#15699)
authorDiego Devesa <redacted>
Sun, 31 Aug 2025 15:47:05 +0000 (08:47 -0700)
committerGitHub <redacted>
Sun, 31 Aug 2025 15:47:05 +0000 (18:47 +0300)
ggml-ci

src/llama-context.cpp

index 7e20ee9f8b383c0d06067572240529f04d9b43b2..2de6fcf0cb20909aa5b393a564a4d68c69e578b1 100644 (file)
@@ -281,9 +281,15 @@ llama_context::llama_context(
         }
 
         cross.v_embd.clear();
+
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
         // resolve automatic Flash Attention use
         if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
             if (!gf) {
                 throw std::runtime_error("failed to split graph for Flash Attention check");
             }
@@ -324,11 +330,6 @@ llama_context::llama_context(
         }
 
         // reserve worst-case graph
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;