]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : avoid ggml_cont_3d for fused QKV weights (#15662)
authorGeorgi Gerganov <redacted>
Mon, 8 Sep 2025 07:25:33 +0000 (10:25 +0300)
committerGitHub <redacted>
Mon, 8 Sep 2025 07:25:33 +0000 (10:25 +0300)
* model : avoid ggml_cont_3d for fused QKV weights

ggml-ci

* kv-cache : make cpy_k and cpy_v implementation more readable

ggml-ci

* cont : add comments

ggml-ci

* cont : minor fix [no ci]

* cont : one more fix

* cont : clarity

ggml-ci

* kv-cache : require contiguous heads of k_cur and v_cur

ggml-ci

src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-model.cpp

index ae35f74201e9cccd51a04102acf5cfac43929d6c..885be072a75c8536c02e45521ecb840ebb1e9b6b 100644 (file)
@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     const int32_t ikv = map_layer_ids.at(il);
 
-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;
+
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    const int64_t n_tokens = k_cur->ne[2];
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
 
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }
 
+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
 
@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
+    const int64_t n_stream = v->ne[2];
+
+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d   (ctx, v_cur, n_embd_gqa, n_tokens);
+    }
+
     // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
     }
 
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
 
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
index b545bf5b9cf7113867688c0ff797d94ccf4f321c..30de013f5f7f36b0901716ae63cc10e7bbd907eb 100644 (file)
@@ -317,9 +317,17 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
+    // note: the heads in k_cur and v_cur should be layed out contiguously in memory
+    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
+    //   - k_idxs [n_tokens]
+    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
+    //   - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
+    // create destination indices for each head of the current batch for where it would be written in the KV cache
+    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
+    //   helps understand the implementation logic of cpy_k and cpy_v
     ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
     ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
index 1813f06d7b308b7014d89ca8cd2c72851c3a3d57..b9e4634a7061c81f1eaf0b7c7a2db6e6a495e0d8 100644 (file)
@@ -6927,9 +6927,7 @@ struct llm_build_falcon : public llm_graph_context {
 
                 ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
@@ -7207,9 +7205,7 @@ struct llm_build_dbrx : public llm_graph_context {
 
                 Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -7329,13 +7325,9 @@ struct llm_build_starcoder : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
-                ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -7551,14 +7543,16 @@ struct llm_build_bert : public llm_graph_context {
                         cb(cur, "bqkv", il);
                     }
 
-                    Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
-                    Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-                    Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                    Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                    Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 } else {
                     Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
@@ -7569,8 +7563,6 @@ struct llm_build_bert : public llm_graph_context {
                             LLM_NORM, il);
 
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                } else {
-                    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 }
 
                 if (model.layers[il].attn_k_norm) {
@@ -7580,8 +7572,6 @@ struct llm_build_bert : public llm_graph_context {
                             LLM_NORM, il);
 
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                } else {
-                    Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
                 // RoPE
@@ -7727,9 +7717,7 @@ struct llm_build_neo_bert : public llm_graph_context {
 
                 Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 // RoPE
                 Qcur = ggml_rope_ext(
@@ -7836,13 +7824,9 @@ struct llm_build_bloom : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
-                ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -7958,13 +7942,9 @@ struct llm_build_mpt : public llm_graph_context {
                     cb(cur, "wqkv_clamped", il);
                 }
 
-                ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
-                ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 // Q/K Layernorm
                 if (model.layers[il].attn_q_norm) {
@@ -7972,26 +7952,16 @@ struct llm_build_mpt : public llm_graph_context {
                             model.layers[il].attn_q_norm,
                             model.layers[il].attn_q_norm_b,
                             LLM_NORM, il);
-                    cb(Qcur, "Qcur", il);
 
                     Kcur = build_norm(Kcur,
                             model.layers[il].attn_k_norm,
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, il);
-                    cb(Kcur, "Kcur", il);
 
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                } else {
-                    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                    cb(Kcur, "Kcur", il);
                 }
 
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
@@ -8240,11 +8210,9 @@ struct llm_build_qwen : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,   n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd));
-
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd));
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
@@ -9219,21 +9187,17 @@ struct llm_build_phi2 : public llm_graph_context {
 
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-                    Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 } else {
                     Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9357,21 +9321,17 @@ struct llm_build_phi3 : public llm_graph_context {
 
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
-                    Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
-                    Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
                 } else {
                     Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9621,18 +9581,14 @@ struct llm_build_gpt2 : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
-                ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -9727,9 +9683,7 @@ struct llm_build_codeshell : public llm_graph_context {
 
                 ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -12601,9 +12555,7 @@ struct llm_build_gptneox : public llm_graph_context {
 
                 ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -13736,18 +13688,14 @@ struct llm_build_jais : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd));
-                ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd));
-                ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa));
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il);
@@ -13859,8 +13807,7 @@ struct llm_build_chatglm : public llm_graph_context {
                     }
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-                    Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 }
 
                 //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
@@ -13993,8 +13940,7 @@ struct llm_build_glm4 : public llm_graph_context {
                     }
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
-                    Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 }
 
                 Qcur = ggml_rope_ext(
@@ -17293,16 +17239,14 @@ private:
             const int64_t k_offset = n_embd_head_q * n_head;
             const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
 
-            ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head,    n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
             ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
-            ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv));
 
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
 
-            Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
-
             Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
             cb(Qcur, "Qcur_normed", il);