]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Make graph_max_nodes vary by ubatch size (#17794)
authorPiotr Wilkin (ilintar) <redacted>
Mon, 8 Dec 2025 13:32:41 +0000 (14:32 +0100)
committerGitHub <redacted>
Mon, 8 Dec 2025 13:32:41 +0000 (14:32 +0100)
* Make graph_max_nodes vary by ubatch size for models where chunking might explode the graph

* Update src/llama-context.h

Co-authored-by: Georgi Gerganov <redacted>
* Add missing const

---------

Co-authored-by: Georgi Gerganov <redacted>
src/llama-context.cpp
src/llama-context.h

index e04f0fc4f9aeda73fa9760ee44e9a290a50d6fd2..4171400713d32573e2dbe5d99405f1914a3bf2d6 100644 (file)
@@ -248,7 +248,10 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-        const size_t max_nodes = this->graph_max_nodes();
+        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        const size_t max_nodes = this->graph_max_nodes(n_tokens);
 
         LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
@@ -300,9 +303,6 @@ llama_context::llama_context(
 
         cross.v_embd.clear();
 
-        const uint32_t n_seqs = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
         // avoid reserving graphs with zero outputs - assume one output per sequence
         n_outputs = n_seqs;
 
@@ -1386,9 +1386,9 @@ void llama_context::output_reorder() {
 // graph
 //
 
-uint32_t llama_context::graph_max_nodes() const {
+uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
     if (model.arch == LLM_ARCH_QWEN3NEXT) {
-        return std::max<uint32_t>(8192u, 32u*model.n_tensors());
+        return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
index 20cbd78955412c3bdec9e462bd23e02d7347b99e..cd26eafe18942522b05bb31bf2ec5509b2cc312a 100644 (file)
@@ -197,7 +197,7 @@ private:
     //
 
 public:
-    uint32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes(uint32_t n_tokens) const;
 
     // can reuse the llm_graph_result instance of the context (for example to update a memory module)
     llm_graph_result * get_gf_res_reserve() const;