git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : DeepSeek V2/V3 MLA implementation (#12801)
author Juk Armstrong <redacted>
Tue, 15 Apr 2025 06:49:57 +0000 (07:49 +0100)
committer GitHub <redacted>
Tue, 15 Apr 2025 06:49:57 +0000 (09:49 +0300)
* Merged using squash to remove all noise commit messages

* Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large

* Removed 3 conts (2x RoPE and 1x RMS-norm)

* Changed to use `<cmath>` instead of `<math.h>`

* Reverted removal of the 3 conts

* Used `reshape` in `llm_graph_context::build_attn_mha()`

* Use `k_pe = ggml_reshape`

* Removed the 3 conts again

* Removed the 3D views of `wk_b` and `wv_b`, and just save them as 3D in GGUF

* Removed the MQA optimisation from `build_attn_mha()` as it no longer gives any gains

* Simplified `is_mla` branch in `llm_build_deepseek2()`

* Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls

* Fixed call to `build_attn` in `llm_build_t5_enc`
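
For reference, below is a minimal numpy sketch of the MLA "absorption" path these changes implement: `q_nope` is absorbed through `wk_b`, attention runs against the shared compressed KV (i.e. MQA with a single KV head), and `wv_b` decompresses the result back to per-head values. All sizes are hypothetical, RoPE/masking/`kq_scale` are omitted, and the axis order is logical rather than ggml's `ne[]` layout.

    import numpy as np

    # hypothetical sizes, not taken from any real DeepSeek config
    n_head, kv_lora_rank, qk_rope, qk_nope, v_head_dim, n_tokens, n_kv = 8, 64, 16, 32, 32, 4, 10

    wk_b = np.random.rand(n_head, kv_lora_rank, qk_nope)     # per-head k_b_proj (pre-transposed at convert time)
    wv_b = np.random.rand(n_head, v_head_dim, kv_lora_rank)  # per-head v_b_proj

    q_nope  = np.random.rand(n_head, n_tokens, qk_nope)
    q_pe    = np.random.rand(n_head, n_tokens, qk_rope)
    kv_cmpr = np.random.rand(n_kv, kv_lora_rank)              # compressed KV, shared by all heads (MQA)
    k_pe    = np.random.rand(n_kv, qk_rope)                   # shared RoPE key part

    # absorb q_nope through wk_b so attention runs in the compressed space
    q_abs  = np.einsum('htn,hcn->htc', q_nope, wk_b)          # [n_head, n_tokens, kv_lora_rank]
    scores = np.einsum('htc,kc->htk', q_abs, kv_cmpr) + np.einsum('htr,kr->htk', q_pe, k_pe)
    probs  = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)  # plain softmax, no scaling/mask

    ctx = np.einsum('htk,kc->htc', probs, kv_cmpr)            # attend over the compressed values
    out = np.einsum('htc,hvc->htv', ctx, wv_b)                # "decompress" back to per-head v_head_dim
    print(out.shape)                                          # (8, 4, 32)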

13 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-context.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.h
src/llama-kv-cache.cpp
src/llama-model.cpp
src/llama-model.h

index 2bf97475f78dd1405108ca5ec8aab494c1b4915f..89522dee8b8add210a032ef4024cbff75d62c575 100755 (executable)
@@ -4422,6 +4422,10 @@ class DeepseekV2Model(Model):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
+
+        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+        self.hparams["num_key_value_heads"] = 1
+
         super().set_gguf_parameters()
         hparams = self.hparams
 
@@ -4430,8 +4434,13 @@ class DeepseekV2Model(Model):
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
         self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
-        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["v_head_dim"])
+
+        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
@@ -4500,6 +4509,26 @@ class DeepseekV2Model(Model):
             else:
                 return []
 
+        # note: MLA with the absorption optimization needs these two split and k_b_proj transposed
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+
+            return [
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):
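
A small torch sketch of the split above, with hypothetical sizes, checking only the view/split/transpose shape logic (remember that ggml's `ne[]` order is the reverse of torch's):

    import torch

    # hypothetical sizes
    n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 24, 16, 64

    kv_b = torch.randn(n_head_kv * (v_head_dim + qk_nope_head_dim), kv_lora_rank)  # stands in for kv_b_proj.weight

    kv = kv_b.view(n_head_kv, v_head_dim + qk_nope_head_dim, kv_b.shape[-1])
    k_b, v_b = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=1)
    k_b = k_b.transpose(1, 2)   # transposed so wk_b contracts over qk_nope when applied to q_nope in the graph

    print(k_b.shape)  # torch.Size([4, 64, 24]) -> {n_embd_head_qk_nope, kv_lora_rank, n_head} in ggml order
    print(v_b.shape)  # torch.Size([4, 16, 64]) -> {kv_lora_rank, n_embd_head_v_mla, n_head} in ggml order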
index 162070e6e193adc8f7c20d74fe108dca7a723223..8fcde2626aa7c48d756dc339384e54c412e112ad 100644 (file)
@@ -139,6 +139,8 @@ class Keys:
         REL_BUCKETS_COUNT            = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW               = "{arch}.attention.sliding_window"
         SCALE                        = "{arch}.attention.scale"
+        KEY_LENGTH_MLA               = "{arch}.attention.key_length_mla"
+        VALUE_LENGTH_MLA             = "{arch}.attention.value_length_mla"
 
     class Rope:
         DIMENSION_COUNT         = "{arch}.rope.dimension_count"
@@ -382,6 +384,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B             = auto()
     ATTN_KV_A_MQA        = auto()
     ATTN_KV_B            = auto()
+    ATTN_K_B             = auto()
+    ATTN_V_B             = auto()
     ATTN_Q_A_NORM        = auto()
     ATTN_KV_A_NORM       = auto()
     FFN_SUB_NORM         = auto()
@@ -590,6 +594,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_Q_B:                  "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA:             "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B:                 "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B:                  "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B:                  "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM:             "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM:            "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:             "blk.{bid}.attn_sub_norm",
@@ -1517,6 +1523,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,
index 485550aad6da49f0eb8671d83ac9f6bc1dba9735..aef03db1577a71de01cd0a9e81862c737d657747 100644 (file)
@@ -689,6 +689,12 @@ class GGUFWriter:
     def add_value_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
 
+    def add_key_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
+
+    def add_value_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
 
index 35154e9b5da905f8253b7995cc859b1b9508ea5e..0bc75cf513a9f9b4f44b628223e256369da41d66 100644 (file)
@@ -677,6 +677,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),
 
+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj",  # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj",  # deepseek2
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),
index a6fddc7fd2e543f30412bf6b90d42ae94f156fee..62e1480bb5881aea613182ba522c890ac3511d2d 100644 (file)
@@ -140,6 +140,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
@@ -1103,6 +1105,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_Q_B,           "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA,      "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B,          "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B,           "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B,           "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
@@ -1563,23 +1567,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V,                 {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
index 2c2099b3c38517381071f824d61f5564a03265b2..98ca00a1bd0b0aa3a3a2bca52e21341bd062b05c 100644 (file)
@@ -144,6 +144,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -306,6 +308,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
index 4735e98ea040ffe3abeec306fc14625059304880..d3ef1cbdeb65e9283453916b1586ba594f05de13 100644 (file)
@@ -10,6 +10,7 @@
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
+#include <cmath>
 
 //
 // llama_context
@@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
-    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
 
@@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type;
 
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+
     ggml_tensor * tmp;
 
     if (ggml_is_quantized(cur->type)) {
@@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
 
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
 
         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
     }
 
     return tmp;
@@ -2274,6 +2278,11 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
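
A quick Python check of the DeepSeek2 special case added to `build_rope_shift()` above (hypothetical YaRN factor of 40, i.e. `freq_scale = 1/40`): the scaled attention factor is the reciprocal of the `(1 + 0.1*ln(1/freq_scale))` magnitude term, so the idea, per the linked discussion, appears to be that the scaling YaRN would otherwise apply inside the RoPE op cancels out and the full `mscale^2` factor is folded into `kq_scale` in `llm_build_deepseek2()` instead.

    import math

    def yarn_attn_factor_scaled(freq_scale: float) -> float:
        # mirrors the DeepSeek2 branch added to build_rope_shift()
        return 1.0 / (1.0 + 0.1 * math.log(1.0 / freq_scale))

    fs = 1.0 / 40.0                                    # hypothetical YaRN scaling factor of 40
    f  = yarn_attn_factor_scaled(fs)
    print(f)                                           # ~0.7305
    print(f * (1.0 + 0.1 * math.log(1.0 / fs)))        # 1.0 (up to float error): the two factors cancel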
index cd955d63bc390729f7a9f4863264a804cd2c53e9..5d0222b9810584e6e63720d6121aef583500ef9f 100644 (file)
@@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_tensor * v,
          ggml_tensor * kq_b,
          ggml_tensor * kq_mask,
+         ggml_tensor * v_mla,
              bool      v_trans,
              float     kq_scale) const {
   //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1199,7 +1200,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
   //const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
+    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1267,6 +1269,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
 
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
         ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
@@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
     GGML_UNUSED(n_tokens);
@@ -1325,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1464,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1523,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
-
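
To isolate just the new `v_mla` step in `build_attn_mha()`: a tiny numpy sketch of the extra mat-mul that "decompresses" the MQA output back to MHA-sized heads (hypothetical sizes, logical axis order only):

    import numpy as np

    n_head, n_tokens, n_kv, kv_lora_rank, n_embd_head_v_mla = 8, 4, 10, 64, 32

    kq  = np.random.rand(n_head, n_tokens, n_kv)     # attention weights after softmax
    v   = np.random.rand(n_kv, kv_lora_rank)         # compressed values, shared by all heads (MQA)
    kqv = np.einsum('htk,kc->htc', kq, v)            # [n_head, n_tokens, kv_lora_rank]

    # the extra v_mla mat-mul: per-head decompression from kv_lora_rank to n_embd_head_v_mla
    v_mla = np.random.rand(n_head, n_embd_head_v_mla, kv_lora_rank)
    out   = np.einsum('hvc,htc->htv', v_mla, kqv)    # [n_head, n_tokens, n_embd_head_v_mla]
    print(out.shape)                                 # (8, 4, 32)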
index 5b6618f9e55f1e57d36a45ce423d7e69322062b9..d192dc14957873bea71e3469cf9c9ef48865fa8b 100644 (file)
@@ -505,11 +505,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
              ggml_cgraph * gf,
-             ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
-             ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
-             ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+             ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
+             ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
+             ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
              ggml_tensor * kq_b,
              ggml_tensor * kq_mask,
+             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                     bool   v_trans,
                    float   kq_scale) const;
 
@@ -524,6 +525,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
 
@@ -538,6 +540,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
 
@@ -552,6 +555,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
 
index 4e0b57190a3a75af9a4cc102e4fca4137af9cd1b..80fcd65df0d3c5c38405474d1dbb1f680949ca5b 100644 (file)
@@ -43,6 +43,10 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
 
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet   posnet;
     struct llama_hparams_convnext convnext;
index dbf5f1187d9e557c0bc7640b159e4326a6035191..7c9d46d8119b39d70bc0c5acb6580d73b7c6dcac 100644 (file)
@@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init(
 
     recurrent = llama_model_is_recurrent(&model);
     v_trans   = !recurrent && !cparams.flash_attn;
-    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    can_shift = !recurrent;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
             __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
index b74dd72cfbf25e823071b10cd52d1e35acfa4480..248c61748eaa8e881e9e40b8bab8c9dd9c2f0369 100644 (file)
@@ -1156,6 +1156,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
                 }
                 ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK,     hparams.n_lora_kv);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA,   hparams.n_embd_head_k_mla, false);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,        hparams.n_expert_shared);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale);
@@ -3205,8 +3207,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     const bool is_lite = (hparams.n_layer == 27);
 
+                    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+                    // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
+                    const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
+                    const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+
                     const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
 
                     const int64_t q_lora_rank  = hparams.n_lora_q;
                     const int64_t kv_lora_rank = hparams.n_lora_kv;
@@ -3232,14 +3240,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         if (!is_lite) {
                             layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0);
                         } else {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0);
                         }
 
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0);
+
+                        // note: only old legacy GGUF files will have the unsplit wkv_b tensor in
+                        if (is_mla) {
+                            layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0);
+                            layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);
+                        } else {
+                            layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0);
+                        }
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
@@ -4290,6 +4306,8 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
+        LLAMA_LOG_INFO("%s: n_embd_head_k_mla    = %d\n",     __func__, hparams.n_embd_head_k_mla);
+        LLAMA_LOG_INFO("%s: n_embd_head_v_mla    = %d\n",     __func__, hparams.n_embd_head_v_mla);
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
@@ -4496,7 +4514,7 @@ struct llm_build_llama : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -4709,7 +4727,7 @@ struct llm_build_deci : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1) {
@@ -4851,7 +4869,7 @@ struct llm_build_baichuan : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -4966,7 +4984,7 @@ struct llm_build_xverse : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5091,7 +5109,7 @@ struct llm_build_falcon : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5221,7 +5239,7 @@ struct llm_build_grok : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1) {
@@ -5372,7 +5390,7 @@ struct llm_build_dbrx : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5486,7 +5504,7 @@ struct llm_build_starcoder : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5585,7 +5603,7 @@ struct llm_build_refact : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5739,7 +5757,7 @@ struct llm_build_bert : public llm_graph_context {
 
             cur = build_attn(inp_attn, gf,
                     model.layers[il].wo, model.layers[il].bo,
-                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             cb(cur, "kqv_out", il);
 
             if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
@@ -5856,7 +5874,7 @@ struct llm_build_bloom : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5997,7 +6015,7 @@ struct llm_build_mpt : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6143,7 +6161,7 @@ struct llm_build_stablelm : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6266,7 +6284,7 @@ struct llm_build_qwen : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6386,7 +6404,7 @@ struct llm_build_qwen2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6507,7 +6525,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6634,7 +6652,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6787,7 +6805,7 @@ struct llm_build_qwen3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -6908,7 +6926,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -7048,7 +7066,7 @@ struct llm_build_phi2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1) {
@@ -7177,7 +7195,7 @@ struct llm_build_phi3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1) {
@@ -7312,7 +7330,7 @@ struct llm_build_plamo : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
             ggml_tensor * sa_out = cur;
 
@@ -7419,7 +7437,7 @@ struct llm_build_gpt2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -7535,7 +7553,7 @@ struct llm_build_codeshell : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -7664,7 +7682,7 @@ struct llm_build_orion : public llm_graph_context {
 
             cur = build_attn(inp_attn, gf,
                     model.layers[il].wo, NULL,
-                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
         }
 
         if (il == n_layer - 1) {
@@ -7791,7 +7809,7 @@ struct llm_build_internlm2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -7988,7 +8006,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        q_states, k_states, v_states, nullptr, kq_scale, il);
+                        q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1) {
@@ -8118,7 +8136,7 @@ struct llm_build_gemma : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1) {
@@ -8240,7 +8258,7 @@ struct llm_build_gemma2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             cur = build_norm(cur,
@@ -8381,7 +8399,7 @@ struct llm_build_gemma3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
             }
 
             cur = build_norm(cur,
@@ -8521,7 +8539,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -8856,7 +8874,7 @@ struct llm_build_command_r : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -8991,7 +9009,7 @@ struct llm_build_cohere2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -9122,7 +9140,7 @@ struct llm_build_olmo : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -9242,7 +9260,7 @@ struct llm_build_olmo2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             cur = build_norm(cur,
@@ -9375,7 +9393,7 @@ struct llm_build_olmoe : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -9508,7 +9526,7 @@ struct llm_build_openelm : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -9622,7 +9640,7 @@ struct llm_build_gptneox : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -9772,7 +9790,7 @@ struct llm_build_arctic : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -9927,7 +9945,7 @@ struct llm_build_deepseek : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1) {
@@ -10017,16 +10035,23 @@ struct llm_build_deepseek2 : public llm_graph_context {
     llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         bool is_lite = (hparams.n_layer == 27);
 
+        const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+        // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
+        const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
+        const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+
+        const int64_t n_embd_head_qk_rope = hparams.n_rot;
+        const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
+
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
         // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
         // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
         const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
-        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
+        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
         const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
 
-        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-        const uint32_t kv_lora_rank = hparams.n_lora_kv;
-
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -10051,16 +10076,14 @@ struct llm_build_deepseek2 : public llm_graph_context {
             {
                 ggml_tensor * q = NULL;
                 if (!is_lite) {
-                    // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
                     q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
                     cb(q, "q", il);
 
                     q = build_norm(q,
-                            model.layers[il].attn_q_a_norm, NULL,
+                            model.layers[il].attn_q_a_norm, nullptr,
                             LLM_NORM_RMS, il);
                     cb(q, "q", il);
 
-                    // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
                     q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
                     cb(q, "q", il);
                 } else {
@@ -10068,96 +10091,125 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     cb(q, "q", il);
                 }
 
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                // split into {n_embd_head_qk_nope, n_head, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q,
+                        n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, n_embd_head_k),
+                        ggml_row_size(q->type, n_embd_head_k) * n_head,
                         0);
                 cb(q_nope, "q_nope", il);
 
-                // and {n_head * n_embd_head_qk_rope, n_tokens}
-                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                // and {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q,
+                        n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, n_embd_head_k),
+                        ggml_row_size(q->type, n_embd_head_k) * n_head,
                         ggml_row_size(q->type, n_embd_head_qk_nope));
                 cb(q_pe, "q_pe", il);
 
-                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+                ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_cmpr_pe, "kv_cmpr_pe", il);
 
                 // split into {kv_lora_rank, n_tokens}
-                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
-                        kv_pe_compresseed->nb[1],
+                ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe,
+                        kv_lora_rank, n_tokens,
+                        ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                         0);
-                cb(kv_compressed, "kv_compressed", il);
+                cb(kv_cmpr, "kv_cmpr", il);
+
+                // and {n_embd_head_qk_rope, 1, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe,
+                        n_embd_head_qk_rope, 1, n_tokens,
+                        ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                        ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                        ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
 
-                // and {n_embd_head_qk_rope, n_tokens}
-                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        kv_pe_compresseed->nb[1],
-                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                );
+                cb(q_pe, "q_pe", il);
+
+                k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                );
                 cb(k_pe, "k_pe", il);
 
-                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
-                kv_compressed = ggml_cont(ctx0, kv_compressed);
-                kv_compressed = build_norm(kv_compressed,
-                        model.layers[il].attn_kv_a_norm, NULL,
+                kv_cmpr = build_norm(kv_cmpr,
+                        model.layers[il].attn_kv_a_norm, nullptr,
                         LLM_NORM_RMS, il);
-                cb(kv_compressed, "kv_compressed", il);
+                cb(kv_cmpr, "kv_cmpr", il);
 
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
+                if (is_mla) {
+                    // {n_embd_head_qk_nope, n_tokens, n_head}
+                    q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                    cb(q_nope, "q_nope_perm", il);
 
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
+                    // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                    ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
+                    cb(q_nope_absorbed, "q_nope_absorbed", il);
 
-                // and {n_head * n_embd_head_v, n_tokens}
-                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
+                    // {kv_lora_rank, n_head, n_tokens}
+                    q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                    cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
 
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
+                    // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                    // note: rope must go first for in-place context shifting in build_rope_shift()
+                    ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                    cb(Qcur, "Qcur", il);
 
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                        0);
-                cb(v_states, "v_states", il);
+                    kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                    cb(kv_cmpr, "kv_cmpr_reshape", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                q_pe = ggml_rope_ext(
-                        ctx0, q_pe, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                        );
-                cb(q_pe, "q_pe", il);
+                    // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+                    ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                    cb(Kcur, "Kcur", il);
 
-                // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                k_pe = ggml_rope_ext(
-                        ctx0, k_pe, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                        );
-                cb(k_pe, "k_pe", il);
+                    // {kv_lora_rank, 1, n_tokens}
+                    ggml_tensor * Vcur = kv_cmpr;
+                    cb(Vcur, "Vcur", il);
 
-                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
+                    // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
+                    cur = build_attn(inp_attn, gf,
+                            model.layers[il].wo, NULL,
+                            Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
+                } else {
+                    ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
+                    cb(kv, "kv", il);
+
+                    // split into {n_embd_head_qk_nope, n_head, n_tokens}
+                    ggml_tensor * k_nope = ggml_view_3d(ctx0, kv,
+                            n_embd_head_qk_nope, n_head, n_tokens,
+                            ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
+                            ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
+                            0);
+                    cb(k_nope, "k_nope_view", il);
 
-                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
+                    // and {n_embd_head_v, n_head, n_tokens}
+                    ggml_tensor * Vcur = ggml_view_3d(ctx0, kv,
+                            n_embd_head_v, n_head, n_tokens,
+                            ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
+                            ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
+                            ggml_row_size(kv->type, n_embd_head_qk_nope));
+                    cb(Vcur, "Vcur_view", il);
 
-                cur = build_attn(inp_attn, gf,
-                        model.layers[il].wo, NULL,
-                        q_states, k_states, v_states, nullptr, kq_scale, il);
+                    Vcur = ggml_cont(ctx0, Vcur);
+                    cb(Vcur, "Vcur_cont", il);
+
+                    // note: rope must go first for in-place context shifting in build_rope_shift()
+                    ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+                    cb(Qcur, "Qcur", il);
+
+                    ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+                    cb(Kcur, "Kcur", il);
+
+                    // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
+                    cur = build_attn(inp_attn, gf,
+                            model.layers[il].wo, NULL,
+                            Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                }
             }
 
             if (il == n_layer - 1) {
@@ -10323,7 +10375,7 @@ struct llm_build_bitnet : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         NULL, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
                 cur = build_norm(cur,
                         model.layers[il].attn_sub_norm, NULL,
@@ -10446,7 +10498,7 @@ struct llm_build_t5_enc : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo_enc, nullptr,
-                        Qcur, Kcur, Vcur, kq_b, 1.0f, il);
+                        Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -10552,7 +10604,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
                 cur = build_attn(inp_attn_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, kq_b, 1.0f, il);
+                        Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -10584,7 +10636,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
                 cur = build_attn(inp_attn_cross, gf,
                         model.layers[il].wo_cross, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
 
                 //ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -10717,7 +10769,7 @@ struct llm_build_jais : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
             }
 
             if (il == n_layer - 1) {
@@ -10849,7 +10901,7 @@ struct llm_build_chatglm : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -10982,7 +11034,7 @@ struct llm_build_glm4 : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -11126,7 +11178,7 @@ struct llm_build_nemotron : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -11257,7 +11309,7 @@ struct llm_build_exaone : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -12159,7 +12211,7 @@ struct llm_build_chameleon : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
                 if (hparams.swin_norm) {
                     cur = build_norm(cur,
@@ -12515,7 +12567,7 @@ struct llm_build_plm : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        q_states, k_states, v_states, nullptr, kq_scale, il);
+                        q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1) {
@@ -12638,7 +12690,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
             }
 
             if (il == n_layer - 1) {
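
A hedged numeric reading of the pre-scaled `kq_scale` in `llm_build_deepseek2()` above. All values are hypothetical placeholders (the real ones come from the model's GGUF metadata); the point is just that the YaRN magnitude correction ends up as `mscale^2` on the KQ product rather than inside the RoPE op (see the sketch after the llama-context.cpp hunk above).

    import math

    n_embd_head_k     = 192          # hypothetical MLA key head size (qk_nope + qk_rope)
    freq_scale        = 1.0 / 40.0   # hypothetical YaRN scaling factor of 40
    rope_yarn_log_mul = 0.1          # hypothetical; read from hparams in the real code
    attn_factor       = 1.0

    mscale   = attn_factor * (1.0 + rope_yarn_log_mul * math.log(1.0 / freq_scale))
    kq_scale = 1.0 * mscale * mscale / math.sqrt(float(n_embd_head_k))

    print(mscale)     # ~1.369
    print(kq_scale)   # ~0.135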
index 0f18dac16733b00aaa4788054a8cfd2bffdbbdef..fd82d106ccda8120ce48f5f547acfd12311a2edb 100644 (file)
@@ -171,6 +171,8 @@ struct llama_layer {
     struct ggml_tensor * wq_b      = nullptr;
     struct ggml_tensor * wkv_a_mqa = nullptr;
     struct ggml_tensor * wkv_b     = nullptr;
+    struct ggml_tensor * wk_b      = nullptr;
+    struct ggml_tensor * wv_b      = nullptr;
     struct ggml_tensor * wq_cross  = nullptr;
     struct ggml_tensor * wk_cross  = nullptr;
     struct ggml_tensor * wv_cross  = nullptr;