]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : differentiate the KV dims in the attention (#4657)
authorpostmasters <redacted>
Tue, 2 Jan 2024 11:51:28 +0000 (03:51 -0800)
committerGitHub <redacted>
Tue, 2 Jan 2024 11:51:28 +0000 (13:51 +0200)
* Add n_key_dim and n_value_dim

Some models use values that are not derived from `n_embd`.
Also remove `n_embd_head` and `n_embd_gqa` because it is not clear
which "head" is referred to (key or value).

Fix issue #4648.

* Fix `llm_build_kqv` to use `n_value_gqa`

* Rebase

* Rename variables

* Fix llm_build_kqv to be more generic wrt n_embd_head_k

* Update default values for n_embd_head_k and n_embd_head_v

Co-authored-by: Georgi Gerganov <redacted>
* Fix llm_load_tensors: the asserts were not backcompat

---------

Co-authored-by: Georgi Gerganov <redacted>
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
llama.cpp

index ae62cc575499b9bc8154edb9da2de5ccc5e8bb95..f0a1c51f8dbe863789de841dd5dd5047baf93f8b 100644 (file)
@@ -46,6 +46,8 @@ class Keys:
         HEAD_COUNT_KV     = "{arch}.attention.head_count_kv"
         MAX_ALIBI_BIAS    = "{arch}.attention.max_alibi_bias"
         CLAMP_KQV         = "{arch}.attention.clamp_kqv"
+        KEY_LENGTH        = "{arch}.attention.key_length"
+        VALUE_LENGTH      = "{arch}.attention.value_length"
         LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
 
index 73e02160750b252140d247e8b018904a31e21fe8..d93aaa877171fa6b5c8f76c16a03ecd95245e03f 100644 (file)
@@ -333,6 +333,12 @@ class GGUFWriter:
     def add_head_count_kv(self, count: int) -> None:
         self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
 
+    def add_key_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
+
+    def add_value_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
 
index a833d4c15a9d0d6b4edadaf8c69eca32655844ad..7044640396c95dd259372ad4c43901ca9e15d055 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -245,6 +245,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
     LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
     LLM_KV_ATTENTION_CLAMP_KQV,
+    LLM_KV_ATTENTION_KEY_LENGTH,
+    LLM_KV_ATTENTION_VALUE_LENGTH,
     LLM_KV_ATTENTION_LAYERNORM_EPS,
     LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
 
@@ -297,6 +299,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv"          },
     { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,      "%s.attention.max_alibi_bias"         },
     { LLM_KV_ATTENTION_CLAMP_KQV,           "%s.attention.clamp_kqv"              },
+    { LLM_KV_ATTENTION_KEY_LENGTH,          "%s.attention.key_length"             },
+    { LLM_KV_ATTENTION_VALUE_LENGTH,        "%s.attention.value_length"           },
     { LLM_KV_ATTENTION_LAYERNORM_EPS,       "%s.attention.layer_norm_epsilon"     },
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,   "%s.attention.layer_norm_rms_epsilon" },
 
@@ -1284,6 +1288,8 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
+    uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
@@ -1310,6 +1316,8 @@ struct llama_hparams {
         if (this->n_head_kv     != other.n_head_kv)     return true;
         if (this->n_layer       != other.n_layer)       return true;
         if (this->n_rot         != other.n_rot)         return true;
+        if (this->n_embd_head_k != other.n_embd_head_k) return true;
+        if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff          != other.n_ff)          return true;
         if (this->n_expert      != other.n_expert)      return true;
         if (this->n_expert_used != other.n_expert_used) return true;
@@ -1331,12 +1339,12 @@ struct llama_hparams {
         return n_head/n_head_kv;
     }
 
-    uint32_t n_embd_head() const {
-        return n_embd/n_head;
+    uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads
+        return n_embd_head_k * n_head_kv;
     }
 
-    uint32_t n_embd_gqa() const {
-        return n_embd/n_gqa();
+    uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
+        return n_embd_head_v * n_head_kv;
     }
 };
 
@@ -1645,8 +1653,9 @@ static bool llama_kv_cache_init(
                           uint32_t   n_ctx,
                                int   n_gpu_layers,
                               bool   offload) {
-    const uint32_t n_embd  = hparams.n_embd_gqa();
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const uint32_t n_layer      = hparams.n_layer;
 
     cache.has_shift = false;
 
@@ -1677,8 +1686,8 @@ static bool llama_kv_cache_init(
     const int i_gpu_start = (int) n_layer - n_gpu_layers;
 
     for (int i = 0; i < (int) n_layer; i++) {
-        ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx);
-        ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx);
+        ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd_k_gqa*n_ctx);
+        ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd_v_gqa*n_ctx);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
@@ -2672,6 +2681,12 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }
 
+    hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+
+    hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -3082,8 +3097,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_head           = %u\n",     __func__, hparams.n_head);
     LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
-    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
+    LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
+    LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %u\n",     __func__, hparams.n_embd_k_gqa());
+    LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %u\n",     __func__, hparams.n_embd_v_gqa());
     LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
     LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
@@ -3173,10 +3192,11 @@ static bool llm_load_tensors(
 
     // create tensors for the weights
     {
-        const int64_t n_embd     = hparams.n_embd;
-        const int64_t n_embd_gqa = hparams.n_embd_gqa();
-        const int64_t n_layer    = hparams.n_layer;
-        const int64_t n_vocab    = hparams.n_vocab;
+        const int64_t n_embd       = hparams.n_embd;
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const int64_t n_layer      = hparams.n_layer;
+        const int64_t n_vocab      = hparams.n_vocab;
 
         const auto tn = LLM_TN(model.arch);
         switch (model.arch) {
@@ -3202,7 +3222,10 @@ static bool llm_load_tensors(
                         model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3270,7 +3293,10 @@ static bool llm_load_tensors(
                         model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3318,7 +3344,10 @@ static bool llm_load_tensors(
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3368,7 +3397,10 @@ static bool llm_load_tensors(
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3420,7 +3452,11 @@ static bool llm_load_tensors(
                         model.output         = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+
                     const int i_gpu_start = n_layer - n_gpu_layers;
                     model.layers.resize(n_layer);
                     for (uint32_t i = 0; i < n_layer; ++i) {
@@ -3469,7 +3505,10 @@ static bool llm_load_tensors(
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3520,7 +3559,10 @@ static bool llm_load_tensors(
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3567,7 +3609,10 @@ static bool llm_load_tensors(
                         model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3665,7 +3710,10 @@ static bool llm_load_tensors(
                         model.output_b      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab},         backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3714,7 +3762,10 @@ static bool llm_load_tensors(
                         model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -3761,7 +3812,10 @@ static bool llm_load_tensors(
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
+                    const uint32_t n_ff      = hparams.n_ff;
+                    const int64_t n_embd_gqa = n_embd_v_gqa;
+                    GGML_ASSERT(n_embd_gqa == n_embd / hparams.n_gqa());
+                    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
                     const int i_gpu_start = n_layer - n_gpu_layers;
 
@@ -4000,8 +4054,8 @@ static struct ggml_tensor * llm_build_inp_embd(
     return inpL;
 }
 
-// Persimmon: n_rot = n_embd_head/2
-// Other:     n_rot = n_embd_head
+// Persimmon: n_rot = n_embd_head_k/2
+// Other:     n_rot = n_embd_head_k
 static void llm_build_k_shift(
       struct ggml_context * ctx,
       const llama_hparams & hparams,
@@ -4014,17 +4068,17 @@ static void llm_build_k_shift(
                   float     freq_base,
                   float     freq_scale,
        const llm_build_cb & cb) {
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
-    const float   ext_factor  = cparams.yarn_ext_factor;
-    const float   attn_factor = cparams.yarn_attn_factor;
-    const float   beta_fast   = cparams.yarn_beta_fast;
-    const float   beta_slow   = cparams.yarn_beta_slow;
-
-    GGML_ASSERT(n_embd_head % n_rot == 0);
+    const int64_t n_layer       = hparams.n_layer;
+    const int64_t n_head_kv     = hparams.n_head_kv;
+    const int64_t n_embd_head_k = hparams.n_embd_head_k;
+    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const int32_t n_orig_ctx    = cparams.n_yarn_orig_ctx;
+    const float   ext_factor    = cparams.yarn_ext_factor;
+    const float   attn_factor   = cparams.yarn_attn_factor;
+    const float   beta_fast     = cparams.yarn_beta_fast;
+    const float   beta_slow     = cparams.yarn_beta_slow;
+
+    GGML_ASSERT(n_embd_head_k % n_rot == 0);
 
     struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
     cb(K_shift, "K_shift", -1);
@@ -4042,9 +4096,9 @@ static void llm_build_k_shift(
             // we rotate only the first n_rot dimensions
             ggml_rope_custom_inplace(ctx,
                     ggml_view_3d(ctx, kv.k_l[il],
-                        n_embd_head, n_head_kv, n_ctx,
-                        ggml_row_size(kv.k_l[il]->type, n_embd_head),
-                        ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
+                        n_embd_head_k, n_head_kv, n_ctx,
+                        ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
+                        ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
                         0),
                     K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
@@ -4065,18 +4119,19 @@ static void llm_build_kv_store(
                     int32_t   kv_head,
          const llm_build_cb & cb,
                     int64_t   il) {
-    const int64_t n_embd_gqa = hparams.n_embd_gqa();
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
     // compute the transposed [n_tokens, n_embd] V matrix
-    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens));
+    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
     //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
     cb(v_cur_t, "v_cur_t", il);
 
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
-            (ggml_row_size(kv.k_l[il]->type, n_embd_gqa))*kv_head);
+    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
+            (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
-    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
+    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
             (  n_ctx)*ggml_element_size(kv.v_l[il]),
             (kv_head)*ggml_element_size(kv.v_l[il]));
     cb(v_cache_view, "v_cache_view", il);
@@ -4226,20 +4281,20 @@ static struct ggml_tensor * llm_build_kqv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int64_t n_head        = hparams.n_head;
+    const int64_t n_head_kv     = hparams.n_head_kv;
+    const int64_t n_embd_head_k = hparams.n_embd_head_k;
+    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const int64_t n_embd_head_v = hparams.n_embd_head_v;
 
     struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
     cb(q, "q", il);
 
     struct ggml_tensor * k =
         ggml_view_3d(ctx, kv.k_l[il],
-                n_embd_head, n_kv, n_head_kv,
-                ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
-                ggml_row_size(kv.k_l[il]->type, n_embd_head),
+                n_embd_head_k, n_kv, n_head_kv,
+                ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
                 0);
     cb(k, "k", il);
 
@@ -4278,9 +4333,9 @@ static struct ggml_tensor * llm_build_kqv(
     // split cached v into n_head heads
     struct ggml_tensor * v =
         ggml_view_3d(ctx, kv.v_l[il],
-                n_kv, n_embd_head, n_head_kv,
+                n_kv, n_embd_head_v, n_head_kv,
                 ggml_element_size(kv.v_l[il])*n_ctx,
-                ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
+                ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
                 0);
     cb(v, "v", il);
 
@@ -4290,7 +4345,7 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
     cb(kqv_merged, "kqv_merged", il);
 
-    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
+    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
     cb(cur, "kqv_merged_cont", il);
 
     cur = ggml_mul_mat(ctx, wo, cur);
@@ -4317,8 +4372,10 @@ struct llm_build_context {
     const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
     const int64_t n_head;
     const int64_t n_head_kv;
-    const int64_t n_embd_head;
-    const int64_t n_embd_gqa;
+    const int64_t n_embd_head_k;
+    const int64_t n_embd_k_gqa;
+    const int64_t n_embd_head_v;
+    const int64_t n_embd_v_gqa;
     const int64_t n_expert;
     const int64_t n_expert_used;
 
@@ -4360,8 +4417,10 @@ struct llm_build_context {
         n_ctx            (cparams.n_ctx),
         n_head           (hparams.n_head),
         n_head_kv        (hparams.n_head_kv),
-        n_embd_head      (hparams.n_embd_head()),
-        n_embd_gqa       (hparams.n_embd_gqa()),
+        n_embd_head_k    (hparams.n_embd_head_k),
+        n_embd_k_gqa     (hparams.n_embd_k_gqa()),
+        n_embd_head_v    (hparams.n_embd_head_v),
+        n_embd_v_gqa     (hparams.n_embd_v_gqa()),
         n_expert         (hparams.n_expert),
         n_expert_used    (hparams.n_expert_used),
         freq_base        (cparams.rope_freq_base),
@@ -4404,6 +4463,8 @@ struct llm_build_context {
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
         GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
@@ -4588,6 +4649,9 @@ struct llm_build_context {
     struct ggml_cgraph * build_baichuan() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -4705,6 +4769,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_falcon() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -4824,6 +4893,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_starcoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
@@ -4920,7 +4994,12 @@ struct llm_build_context {
     struct ggml_cgraph * build_persimmon() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        const int64_t n_rot = n_embd_head / 2;
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
+        const int64_t n_rot = n_embd_head_k / 2;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -5129,6 +5208,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_refact() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -5217,6 +5301,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_bloom() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -5308,6 +5397,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_mpt() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -5403,6 +5497,9 @@ struct llm_build_context {
     struct ggml_cgraph * build_stablelm() {
         struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -5513,6 +5610,9 @@ struct llm_build_context {
     struct ggml_cgraph * build_qwen() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -5624,6 +5724,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_phi2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * attn_norm_output;
         struct ggml_tensor * ffn_output;
@@ -5736,6 +5841,9 @@ struct llm_build_context {
     struct ggml_cgraph * build_plamo() {
         struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
@@ -5840,6 +5948,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_gpt2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_gqa  == n_embd);
+
         struct ggml_tensor * cur;
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
@@ -9627,8 +9740,8 @@ struct llama_context * llama_new_context_with_model(
     const ggml_type type_k = params.type_k;
     const ggml_type type_v = params.type_v;
 
-    GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0);
-    GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0);
+    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
 
     // reserve memory for context buffers
     if (!hparams.vocab_only) {
@@ -10172,9 +10285,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const auto   n_layer = hparams.n_layer;
-        const auto   n_embd  = hparams.n_embd_gqa();
-        const auto   n_ctx   = cparams.n_ctx;
+        const auto   n_layer      = hparams.n_layer;
+        const auto   n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const auto   n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const auto   n_ctx        = cparams.n_ctx;
 
         const size_t   kv_buf_size = ggml_backend_buffer_get_size(kv_self.buf);
         const uint32_t kv_head     = kv_self.head;
@@ -10196,15 +10310,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             std::vector<struct ggml_tensor *> vout2d(n_layer);
 
             for (int il = 0; il < (int) n_layer; ++il) {
-                kout2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
-                vout2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
+                kout2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd_k_gqa, kv_head);
+                vout2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd_v_gqa);
 
                 ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
-                        n_embd, kv_head,
-                        elt_size*n_embd, 0);
+                        n_embd_k_gqa, kv_head,
+                        elt_size*n_embd_k_gqa, 0);
 
                 ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
-                        kv_head, n_embd,
+                        kv_head, n_embd_v_gqa,
                         elt_size*n_ctx, 0);
 
                 ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d[il]));
@@ -10311,9 +10425,10 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const int    n_layer = hparams.n_layer;
-        const int    n_embd  = hparams.n_embd_gqa();
-        const int    n_ctx   = cparams.n_ctx;
+        const int    n_layer      = hparams.n_layer;
+        const int    n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const int    n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const int    n_ctx        = cparams.n_ctx;
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -10337,15 +10452,15 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             std::vector<struct ggml_tensor *> vin2d(n_layer);
 
             for (int il = 0; il < n_layer; ++il) {
-                kin2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
-                vin2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
+                kin2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd_k_gqa, kv_head);
+                vin2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd_v_gqa);
 
                 ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
-                    n_embd, kv_head,
-                    elt_size*n_embd, 0);
+                    n_embd_k_gqa, kv_head,
+                    elt_size*n_embd_k_gqa, 0);
 
                 ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
-                    kv_head, n_embd,
+                    kv_head, n_embd_v_gqa,
                     elt_size*n_ctx, 0);
 
                 ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d[il], k2d));