]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : normalize Q, K, V shapes + sync cross attention (#12449)
authorGeorgi Gerganov <redacted>
Tue, 18 Mar 2025 19:35:19 +0000 (21:35 +0200)
committerGitHub <redacted>
Tue, 18 Mar 2025 19:35:19 +0000 (21:35 +0200)
* graph : normalize Q, K, V shapes and add comments

ggml-ci

* context : synchronize before getting cross attention data

* model : fix command-r attention norm check

src/llama-context.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-model.cpp

index 42332acf1e39d7e64f88cdfd017c6664a38adbf2..664703a896701a9988a23fd79405f00154622bb0 100644 (file)
@@ -1143,6 +1143,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
 
+        synchronize();
+
         cross.n_embd = t_embd->ne[0];
         cross.n_enc  = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
index 4e90873397ca4c230f3c79f82095a415c25e455f..0bd40174438cce54a077beafc8a4d4e2ce46a870 100644 (file)
@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
         // note: storing RoPE-ed version of K in the KV cache
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
-        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
         ggml_tensor * v_cache_view = nullptr;
 
index c4328e6f9e627848aa5ba07ef888750dd3731356..bdf19ed015e35a898b7f54d1bafe0a6c59968c90 100644 (file)
@@ -487,9 +487,9 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
              ggml_cgraph * gf,
-             ggml_tensor * q,
-             ggml_tensor * k,
-             ggml_tensor * v,
+             ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
+             ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
+             ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
              ggml_tensor * kq_b,
              ggml_tensor * kq_mask,
                     bool   v_trans,
@@ -502,9 +502,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;
@@ -516,9 +516,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;
@@ -530,9 +530,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;
index 9171585bd9d915abc2f5cc4070450afe63dc6557..d286176c1ff839f7a2b6f40ad9de603bc104b90c 100644 (file)
@@ -4093,19 +4093,25 @@ struct llm_build_llama : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        ctx0, Kcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -4267,19 +4273,25 @@ struct llm_build_deci : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        ctx0, Kcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -4396,28 +4408,32 @@ struct llm_build_baichuan : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 switch (model.type) {
                     case LLM_TYPE_7B:
                         Qcur = ggml_rope_ext(
-                                ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                                ctx0, Qcur, inp_pos, nullptr,
                                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                                 ext_factor, attn_factor, beta_fast, beta_slow
                                 );
                         Kcur = ggml_rope_ext(
-                                ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                                ctx0, Kcur, inp_pos, nullptr,
                                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                                 ext_factor, attn_factor, beta_fast, beta_slow
                                 );
                         break;
                     case LLM_TYPE_13B:
-                        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
-                        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
                     default:
                         GGML_ABORT("fatal error");
                 }
+
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -4514,19 +4530,25 @@ struct llm_build_xverse : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -4632,25 +4654,26 @@ struct llm_build_falcon : public llm_graph_context {
                 ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -4762,19 +4785,25 @@ struct llm_build_grok : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -4907,23 +4936,25 @@ struct llm_build_dbrx : public llm_graph_context {
                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -5031,12 +5062,14 @@ struct llm_build_starcoder : public llm_graph_context {
                 ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -5128,11 +5161,13 @@ struct llm_build_refact : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cb(Kcur, "Kcur", il);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -5267,21 +5302,21 @@ struct llm_build_bert : public llm_graph_context {
                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Kcur, "Kcur", il);
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
             }
 
             cb(Qcur, "Qcur", il);
@@ -5397,12 +5432,14 @@ struct llm_build_bloom : public llm_graph_context {
                 ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -5534,20 +5571,19 @@ struct llm_build_mpt : public llm_graph_context {
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, il);
                     cb(Kcur, "Kcur", il);
+                }
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                    cur = build_attn(inp_attn, gf,
-                            model.layers[il].wo, model.layers[il].bo,
-                            Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-                } else {
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-                    cur = build_attn(inp_attn, gf,
-                            model.layers[il].wo, model.layers[il].bo,
-                            Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-                }
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -5656,9 +5692,8 @@ struct llm_build_stablelm : public llm_graph_context {
                 }
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                cb(Qcur, "Qcur", il);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cb(Kcur, "Kcur", il);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 if (model.layers[il].attn_q_norm) {
                     Qcur = build_norm(Qcur,
@@ -5667,6 +5702,7 @@ struct llm_build_stablelm : public llm_graph_context {
                             LLM_NORM, il);
                     cb(Qcur, "Qcur", il);
                 }
+
                 if (model.layers[il].attn_k_norm) {
                     Kcur = build_norm(Kcur,
                             model.layers[il].attn_k_norm,
@@ -5675,20 +5711,21 @@ struct llm_build_stablelm : public llm_graph_context {
                     cb(Kcur, "Kcur", il);
                 }
 
-
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -5792,25 +5829,26 @@ struct llm_build_qwen : public llm_graph_context {
                 ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -5901,33 +5939,36 @@ struct llm_build_qwen2 : public llm_graph_context {
             {
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6019,35 +6060,36 @@ struct llm_build_qwen2vl : public llm_graph_context {
             {
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_multi(
-                        ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_multi(
-                        ctx0,
-                        ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6136,33 +6178,36 @@ struct llm_build_qwen2moe : public llm_graph_context {
             {
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6307,23 +6352,27 @@ struct llm_build_phi2 : public llm_graph_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
                 cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 // with phi2, we scale the Q to avoid precision issues
                 // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                        );
-                cb(Kcur, "Kcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6433,21 +6482,26 @@ struct llm_build_phi3 : public llm_graph_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6564,17 +6618,25 @@ struct llm_build_plamo : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -6679,7 +6741,9 @@ struct llm_build_gpt2 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6773,27 +6837,29 @@ struct llm_build_codeshell : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-                cb(tmpq, "tmpq", il);
-                cb(tmpk, "tmpk", il);
-                cb(Vcur, "Vcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                ggml_tensor * Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -6904,19 +6970,25 @@ struct llm_build_orion : public llm_graph_context {
             //     cb(Vcur, "Vcur", il);
             // }
 
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
             Qcur = ggml_rope_ext(
-                ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                ctx0, Qcur, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow
             );
-            cb(Qcur, "Qcur", il);
 
             Kcur = ggml_rope_ext(
-                ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                ctx0, Kcur, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor, beta_fast, beta_slow
             );
+
+            cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
 
             cur = build_attn(inp_attn, gf,
                     model.layers[il].wo, NULL,
@@ -7025,19 +7097,25 @@ struct llm_build_internlm2 : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -7311,7 +7389,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
 
 struct llm_build_gemma : public llm_graph_context {
     llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -7345,20 +7423,26 @@ struct llm_build_gemma : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
-                cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur_scaled", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -7421,7 +7505,7 @@ struct llm_build_gemma : public llm_graph_context {
 
 struct llm_build_gemma2 : public llm_graph_context {
     llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+        const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -7455,27 +7539,33 @@ struct llm_build_gemma2 : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
+
                 cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
                     case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    case LLM_TYPE_9B:
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
                     default: GGML_ABORT("fatal error");
                 };
                 cb(Qcur, "Qcur_scaled", il);
 
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, 1.0f, il);
@@ -7552,7 +7642,7 @@ struct llm_build_gemma2 : public llm_graph_context {
 
 struct llm_build_gemma3 : public llm_graph_context {
     llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+        const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -7593,7 +7683,10 @@ struct llm_build_gemma3 : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                 cb(Qcur, "Qcur_normed", il);
 
@@ -7601,9 +7694,7 @@ struct llm_build_gemma3 : public llm_graph_context {
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
                 Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
                 cb(Kcur, "Kcur_normed", il);
 
@@ -7611,7 +7702,10 @@ struct llm_build_gemma3 : public llm_graph_context {
                         ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -7733,19 +7827,25 @@ struct llm_build_starcoder2 : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -8046,24 +8146,25 @@ struct llm_build_command_r : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
-                            ggml_element_size(Qcur) * n_embd_head,
-                            ggml_element_size(Qcur) * n_embd_head * n_head,
-                            0);
-                    cb(Qcur, "Qcur", il);
-                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
-                            ggml_element_size(Kcur) * n_embd_head,
-                            ggml_element_size(Kcur) * n_embd_head * n_head_kv,
-                            0);
-                    cb(Kcur, "Kcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
+                if (model.layers[il].attn_q_norm) {
                     Qcur = build_norm(Qcur,
                             model.layers[il].attn_q_norm,
                             NULL,
                             LLM_NORM, il);
                     cb(Qcur, "Qcur", il);
+                }
 
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                if (model.layers[il].attn_k_norm) {
                     Kcur = build_norm(Kcur,
                             model.layers[il].attn_k_norm,
                             NULL,
@@ -8071,19 +8172,15 @@ struct llm_build_command_r : public llm_graph_context {
                     cb(Kcur, "Kcur", il);
                 }
 
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
-                cb(Qcur, "Qcur", il);
-
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -8198,25 +8295,28 @@ struct llm_build_cohere2 : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                if (is_swa) {
-                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
-                            beta_fast, beta_slow);
-                    cb(Qcur, "Qcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                            rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                            attn_factor, beta_fast, beta_slow);
-                    cb(Kcur, "Kcur", il);
-                } else {
-                    // For non-sliding layers, just reshape without applying RoPE
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                    cb(Qcur, "Qcur", il);
+                if (is_swa) {
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
 
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                    cb(Kcur, "Kcur", il);
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
                 }
 
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -8328,19 +8428,25 @@ struct llm_build_olmo : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, nullptr,
@@ -8442,22 +8548,25 @@ struct llm_build_olmo2 : public llm_graph_context {
                         LLM_NORM_RMS, il);
                 cb(Kcur, "Kcur_normed", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur_rope", il);
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Kcur, "Kcur_rope", il);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8572,22 +8681,25 @@ struct llm_build_olmoe : public llm_graph_context {
                         LLM_NORM_RMS, il);
                 cb(Kcur, "Kcur_normed", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur_rope", il);
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Kcur, "Kcur_rope", il);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8687,7 +8799,7 @@ struct llm_build_openelm : public llm_graph_context {
 
                 cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
 
-                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, cur->nb[1], cur->nb[2], 0));
                 cb(Qcur, "Qcur", il);
 
                 ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
@@ -8707,18 +8819,19 @@ struct llm_build_openelm : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
 
                 Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Qcur, inp_pos, NULL,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
-                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        ctx0, Kcur, inp_pos, NULL,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Kcur, "Kcur", il);
 
-                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
                 cb(Qcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
@@ -8815,23 +8928,25 @@ struct llm_build_gptneox : public llm_graph_context {
                 ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -8963,19 +9078,25 @@ struct llm_build_arctic : public llm_graph_context {
                 ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -9112,19 +9233,25 @@ struct llm_build_deepseek : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        ctx0, Kcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -9502,19 +9629,25 @@ struct llm_build_bitnet : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         NULL, NULL,
@@ -9906,7 +10039,9 @@ struct llm_build_jais : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -10019,29 +10154,30 @@ struct llm_build_chatglm : public llm_graph_context {
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
                 }
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur_rope", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Kcur, "Kcur_rope", il);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-
             }
 
             if (il == n_layer - 1) {
@@ -10145,19 +10281,25 @@ struct llm_build_nemotron : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -10270,19 +10412,25 @@ struct llm_build_exaone : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        ctx0, Kcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, model.layers[il].bo,
@@ -11166,19 +11314,25 @@ struct llm_build_chameleon : public llm_graph_context {
                     cb(Kcur, "Kcur", il);
                 }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
                 Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+
+                cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, nullptr,