]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : use n_embd_head_v when reshaping kqv (#7327)
authorfairydreaming <redacted>
Fri, 17 May 2024 11:24:38 +0000 (13:24 +0200)
committerGitHub <redacted>
Fri, 17 May 2024 11:24:38 +0000 (14:24 +0300)
* llama : use n_embd_head_v instead of n_embd_head_k when reshaping kqv

* llama : use n_embd_v_gqa and n_embd_head_v instead of n_embd_k_gqa and n_embd_head_k when making a view of cached value vectors.

---------

Co-authored-by: Stanisław Szymczyk <redacted>
llama.cpp

index c5a1fa0f60a1f8645f8bdeafbef0b86f8d4369c6..e11f0ac4b72dc80062a4a9ed2f5b5f546943c1b3 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv(
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
     const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
+    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
 
     struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
     cb(q, "q", il);
@@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * v =
             ggml_view_3d(ctx, kv.v_l[il],
                     n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+                    ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
+                    ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
                     0);
         cb(v, "v", il);
 
@@ -6655,7 +6656,7 @@ static struct ggml_tensor * llm_build_kqv(
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
 
-        cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
@@ -6700,7 +6701,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
         cb(kqv_merged, "kqv_merged", il);
 
-        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
         cb(cur, "kqv_merged_cont", il);
     }