]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : simplify attn input build for unified KV cache (#12381)
authorGeorgi Gerganov <redacted>
Fri, 14 Mar 2025 08:47:44 +0000 (10:47 +0200)
committerGitHub <redacted>
Fri, 14 Mar 2025 08:47:44 +0000 (10:47 +0200)
ggml-ci

src/llama-graph.cpp
src/llama-graph.h
src/llama-model.cpp

index 1041ba29fbb578000155e7a75b1774e949bad79b..e4af507780aa1a4447e020378779d86221f91e6a 100644 (file)
@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
-                bool   causal,
-                bool   swa) const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
-    inp->self_kq_mask = causal
-        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    if (swa) {
+    if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
-        inp->self_kq_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
index b7a66d1898736da5a9f7097a3a980ec5d89db49f..c4328e6f9e627848aa5ba07ef888750dd3731356 100644 (file)
@@ -509,9 +509,7 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
-            bool causal,
-            bool swa) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
index cce943df08a83c84820f050d4371f97e71dd1588..750a702ff77a4a825b3b3895e2352ad2f20dbd81 100644 (file)
@@ -784,9 +784,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.n_swa = 2047;
                 } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-mini-128k-instruct
+                    // note: this seems incorrect because the window is bigger than the train context?
                     hparams.n_swa = 262144;
                 } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-medium-128k-instruct
+                    // note: this seems incorrect because the window is equal to the train context?
                     hparams.n_swa = 131072;
                 }
                 bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -3710,6 +3712,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
         LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
+        LLAMA_LOG_INFO("%s: n_swa_pattern    = %u\n",     __func__, hparams.n_swa_pattern);
         LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
         LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
         LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
@@ -3871,7 +3874,7 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
@@ -4034,7 +4037,7 @@ struct llm_build_deci : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
@@ -4192,7 +4195,7 @@ struct llm_build_baichuan : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -4310,7 +4313,7 @@ struct llm_build_xverse : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -4418,7 +4421,7 @@ struct llm_build_falcon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * attn_norm;
@@ -4543,7 +4546,7 @@ struct llm_build_grok : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -4697,7 +4700,7 @@ struct llm_build_dbrx : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -4821,7 +4824,7 @@ struct llm_build_starcoder : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -4924,7 +4927,7 @@ struct llm_build_refact : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -5187,7 +5190,7 @@ struct llm_build_bloom : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         inpL = build_norm(inpL,
                 model.tok_norm,
@@ -5292,7 +5295,7 @@ struct llm_build_mpt : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         if (model.pos_embd) {
             // inp_pos - contains the positions
@@ -5436,7 +5439,7 @@ struct llm_build_stablelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -5587,7 +5590,7 @@ struct llm_build_qwen : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -5703,7 +5706,7 @@ struct llm_build_qwen2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -5818,7 +5821,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         int sections[4];
         std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -5938,7 +5941,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -6087,7 +6090,7 @@ struct llm_build_phi2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             attn_norm_output = build_norm(inpL,
@@ -6211,7 +6214,7 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
@@ -6357,7 +6360,7 @@ struct llm_build_plamo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
 
@@ -6465,7 +6468,7 @@ struct llm_build_gpt2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -6573,7 +6576,7 @@ struct llm_build_codeshell : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             cur = build_norm(inpL,
@@ -6686,7 +6689,7 @@ struct llm_build_orion : public llm_graph_context {
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv_unified(true, false);
+    auto * inp_attn = build_attn_inp_kv_unified();
 
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
@@ -6807,7 +6810,7 @@ struct llm_build_internlm2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -6937,7 +6940,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -7141,7 +7144,7 @@ struct llm_build_gemma : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -7251,7 +7254,7 @@ struct llm_build_gemma2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -7386,7 +7389,7 @@ struct llm_build_gemma3 : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             const bool is_swa = hparams.is_swa(il);
@@ -7515,7 +7518,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -7828,7 +7831,7 @@ struct llm_build_command_r : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
 
@@ -7978,7 +7981,7 @@ struct llm_build_cohere2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             const bool is_swa = hparams.is_swa(il);
@@ -8110,7 +8113,7 @@ struct llm_build_olmo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -8232,7 +8235,7 @@ struct llm_build_olmo2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -8358,7 +8361,7 @@ struct llm_build_olmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -8481,7 +8484,7 @@ struct llm_build_openelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             const int64_t n_head    = hparams.n_head(il);
@@ -8611,7 +8614,7 @@ struct llm_build_gptneox : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             cur = build_norm(inpL,
@@ -8757,7 +8760,7 @@ struct llm_build_arctic : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -8889,7 +8892,7 @@ struct llm_build_deepseek : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -9054,7 +9057,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -9274,7 +9277,7 @@ struct llm_build_bitnet : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -9532,7 +9535,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
         const int64_t n_outputs_enc = embd_enc->ne[1];
 
-        auto * inp_attn_self  = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn_self  = build_attn_inp_kv_unified();
         auto * inp_attn_cross = build_attn_inp_cross();
 
         for (int il = 0; il < n_layer; ++il) {
@@ -9698,7 +9701,7 @@ struct llm_build_jais : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             cur = build_norm(inpL,
@@ -9794,7 +9797,7 @@ struct llm_build_chatglm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -9926,7 +9929,7 @@ struct llm_build_nemotron : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -10049,7 +10052,7 @@ struct llm_build_exaone : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -10565,7 +10568,7 @@ struct llm_build_chameleon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;