]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : fix KQ mask, lora, cvec reuse checks (#19644)
authorGeorgi Gerganov <redacted>
Mon, 16 Feb 2026 07:21:11 +0000 (09:21 +0200)
committerGitHub <redacted>
Mon, 16 Feb 2026 07:21:11 +0000 (09:21 +0200)
* graph : fix KQ mask reuse condition

* cont : dedup KQ mask build and can_reuse

* cont : fix build

* graph : fix adapter check for reuse

src/llama-adapter.h
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp

index d275d25425e85a75ec14527ae47017f90a7936fa..aa3ab63ad7531c737b371eca3e12e78ec9c7476b 100644 (file)
@@ -39,6 +39,8 @@ private:
     std::vector<ggml_tensor *> tensors; // per layer
 };
 
+using llama_adapter_cvec_ptr = std::shared_ptr<llama_adapter_cvec>;
+
 //
 // llama_adapter_lora
 //
@@ -84,3 +86,4 @@ struct llama_adapter_lora {
 };
 
 using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
+using llama_adapter_loras_ptr = std::unique_ptr<llama_adapter_loras>;
index 99035b6caced81c35cbb9c45554219812be0bb1a..fc05989aa556b2dc443e2c817a0f6c62ef509790 100644 (file)
@@ -22,6 +22,8 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
+    cvec(std::make_unique<llama_adapter_cvec>()),
+    loras(std::make_unique<llama_adapter_loras>()),
     balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
     //     may need to be backend-dependent
@@ -1065,11 +1067,11 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
         return;
     }
 
-    loras.clear();
+    loras.reset(new llama_adapter_loras());
 
     for (size_t i = 0; i < n_adapters; i ++) {
         if (scales[i] != 0.0f) {
-            loras[adapters[i]] = scales[i];
+            loras->insert({adapters[i], scales[i]});
         }
     }
 
@@ -1079,14 +1081,14 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
 bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
     LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
 
-    if (n_adapters != loras.size()) {
+    if (n_adapters != loras->size()) {
         return false;
     }
 
     for (size_t i = 0; i < n_adapters; i ++) {
-        auto it = loras.find(adapters[i]);
+        auto it = loras->find(adapters[i]);
 
-        if (it == loras.end() || it->second != scales[i]) {
+        if (it == loras->end() || it->second != scales[i]) {
             return false;
         }
     }
@@ -1104,7 +1106,7 @@ bool llama_context::set_adapter_cvec(
 
     // TODO: should we reserve?
 
-    return cvec.apply(model, data, len, n_embd, il_start, il_end);
+    return cvec->apply(model, data, len, n_embd, il_start, il_end);
 }
 
 llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
@@ -2081,8 +2083,8 @@ llm_graph_params llama_context::graph_params(
         /*.gtype       =*/ gtype,
         /*.sched       =*/ sched.get(),
         /*.backend_cpu =*/ backend_cpu,
-        /*.cvec        =*/ &cvec,
-        /*.loras       =*/ &loras,
+        /*.cvec        =*/ cvec.get(),
+        /*.loras       =*/ loras.get(),
         /*.mctx        =*/ mctx,
         /*.cross       =*/ &cross,
         /*.samplers    =*/ sampling.samplers,
index a8e53f335cc63ded1cd6e24e69d1d0dbe53bdc1c..e0d0085c1c3fb6fd673b7ea2b6c4452446d7b122 100644 (file)
@@ -256,9 +256,10 @@ private:
 
     const llama_model & model;
 
-    llama_cparams       cparams;
-    llama_adapter_cvec  cvec;
-    llama_adapter_loras loras;
+    llama_cparams cparams;
+
+    llama_adapter_cvec_ptr  cvec;
+    llama_adapter_loras_ptr loras;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
index bba747d37b59324dd5f060bc0928a3c52407ce72..70d8ff02a922fa510bd6b79e9ebb67d6210fb782 100644 (file)
 #include <sstream>
 #include <unordered_set>
 
+// dedup helpers
+
+static ggml_tensor * build_kq_mask(
+        ggml_context * ctx,
+        const llama_kv_cache_context * mctx,
+        const llama_ubatch & ubatch,
+        const llama_cparams & cparams) {
+    const auto n_kv     = mctx->get_n_kv();
+    const auto n_tokens = ubatch.n_tokens;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+    return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+}
+
+static bool can_reuse_kq_mask(
+        ggml_tensor * kq_mask,
+        const llama_kv_cache_context * mctx,
+        const llama_ubatch & ubatch,
+        const llama_cparams & cparams) {
+    const auto n_kv     = mctx->get_n_kv();
+    const auto n_tokens = ubatch.n_tokens;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+    bool res = true;
+
+    res &= (kq_mask->ne[0] == n_kv);
+    res &= (kq_mask->ne[1] == n_tokens/n_stream);
+    res &= (kq_mask->ne[2] == 1);
+    res &= (kq_mask->ne[3] == n_stream);
+
+    return res;
+}
+
+// impl
+
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -403,8 +438,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
   //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
-    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
-    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+    res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
 
     return res;
 }
@@ -424,8 +458,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
 
     res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
 
-    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
-    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+    res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
 
     return res;
 }
@@ -455,11 +488,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
     res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
   //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
-    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
-    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
-
-    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
-    res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+    res &= can_reuse_kq_mask(self_kq_mask,     mctx->get_base(), params.ubatch, params.cparams);
+    res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(),  params.ubatch, params.cparams);
 
     return res;
 }
@@ -521,8 +551,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
   //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
-    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
-    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
 
     res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
 
@@ -565,8 +594,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
 
     res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
 
-    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
-    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+    res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
 
     res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
 
@@ -625,8 +653,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
         res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
       //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
-        res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
-        res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+        res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
     }
 
     // swa tensors may not be allocated if there are no SWA attention layers
@@ -634,8 +661,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
         res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
       //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
-        res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
-        res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+        res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
     }
 
     res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
@@ -1891,14 +1917,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
-        const auto n_kv     = mctx_cur->get_n_kv();
-        const auto n_tokens = ubatch.n_tokens;
-        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
-
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
+
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1983,13 +2006,9 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
-        const auto n_kv     = mctx_cur->get_n_kv();
-        const auto n_tokens = ubatch.n_tokens;
-        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
-
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -2188,15 +2207,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
-
     {
-        const auto n_kv = mctx_cur->get_base()->get_n_kv();
-
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
         ggml_set_input(inp->self_kq_mask);
         ggml_set_name(inp->self_kq_mask, "self_kq_mask");
 
@@ -2207,12 +2222,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
-        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
-
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
         ggml_set_input(inp->self_kq_mask_swa);
         ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
 
@@ -2374,27 +2387,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
 
     auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
 
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
-
     {
-        const auto n_kv = attn_ctx->get_base()->get_n_kv();
-
         inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
         ggml_set_input(inp_attn->self_kq_mask);
 
         inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
     }
 
     {
-        const auto n_kv = attn_ctx->get_swa()->get_n_kv();
-
         inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
         ggml_set_input(inp_attn->self_kq_mask_swa);
 
         inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;