]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix KV shift for qwen2vl (#13870)
authorXuan-Son Nguyen <redacted>
Wed, 28 May 2025 20:35:31 +0000 (22:35 +0200)
committerGitHub <redacted>
Wed, 28 May 2025 20:35:31 +0000 (22:35 +0200)
* llama : fix KV shift for qwen2vl

* add ref to the PR

src/llama-graph.cpp
src/llama-kv-cache.cpp

index 7a91ff3df8885ccc393c809b0b39cb2d9d624789..7c383e2eb3f27e53b5f0678f33ec724b94456714 100644 (file)
@@ -455,7 +455,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     }
 
 int64_t llm_graph_context::n_pos_per_embd() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
index 4a42d6ecdc4556f7528810afce958d2854d50f6f..766f8d079afb26e5242cb1f1f519b597069028a9 100644 (file)
@@ -757,11 +757,19 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
     const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
 
     const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type;
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+                                // @ngxson : this is a workaround
+                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
+                                // a normal RoPE should work, we just need to use the correct ordering
+                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
+                                ? LLAMA_ROPE_TYPE_NEOX
+                                : hparams.rope_type;
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
+                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
+                                    : cparams.yarn_attn_factor;
 
     ggml_tensor * tmp;