]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix Gemma3 SWA KV cache shift (#12373)
authorGeorgi Gerganov <redacted>
Thu, 13 Mar 2025 17:08:07 +0000 (19:08 +0200)
committerGitHub <redacted>
Thu, 13 Mar 2025 17:08:07 +0000 (19:08 +0200)
* llama : fix Gemma3 SWA KV cache shift

ggml-ci

* hparams : add comment [no ci]

src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-model.cpp

index 0a43a3af8e0035a73bee0972a99d5d24e4b7c314..89fb33cbcdae25ad551196555ed213ad421ead46 100644 (file)
@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-    const auto & freq_base  = cparams.rope_freq_base;
-    const auto & freq_scale = cparams.rope_freq_scale;
 
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        float freq_base_l  = cparams.rope_freq_base;
+        float freq_scale_l = cparams.rope_freq_scale;
+
+        // TODO: improve
+        if (model.arch == LLM_ARCH_GEMMA3) {
+            const bool is_sliding = hparams.is_sliding(il);
+
+            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
+            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
+        }
+
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
         ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
 
         ggml_build_forward_expand(gf, cur);
     }
index 71d702e8baeeb0b14584a1f61c0dcd9ab4ff1a2e..88df8950e4cb0f6039ea50c87aa7f1b45d944f59 100644 (file)
@@ -168,6 +168,8 @@ private:
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const;
 
     llm_graph_result_ptr build_kv_self_shift(
index 1e3f2efc89d2c05f92dddafbffca780793ef134d..4a53e83929f4173f8d833df71961cd3db4612a4f 100644 (file)
@@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    // TODO: improve
-    bool is_sliding = false;
-
-    switch (arch) {
-        case LLM_ARCH_COHERE2:
-            {
-                const int32_t sliding_window_pattern = 4;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                const int32_t sliding_window_pattern = 2;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA3:
-            {
-                const int32_t sliding_window_pattern = 6;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_PHI3:
-            {
-                is_sliding = hparams.n_swa > 0;
-            } break;
-        default:
-            {
-                is_sliding = false;
-            }
-    };
+    const bool is_sliding = hparams.is_sliding(il);
 
     const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
index ea87b2953d9ddb7af448bc84b4657d8882d1ba53..58e98bf2311dba89b8f9e6ac93e550214c04a515 100644 (file)
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
 }
+
+bool llama_hparams::is_sliding(uint32_t il) const {
+    if (il < n_layer) {
+        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+    }
+
+    GGML_ABORT("fatal error");
+}
index 1fe45410371b9a87d6425feb30df0fa14f5868f8..e3091c8127dd5d500ec085e4a5545d6333e978ea 100644 (file)
@@ -36,6 +36,7 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -133,6 +134,8 @@ struct llama_hparams {
 
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
+
+    bool is_sliding(uint32_t il) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
index 522219c0122428efeb593301e2beabbc6218aa07..5647d2ad6245b494bd17a3461fbdfac8386630f4 100644 (file)
@@ -858,11 +858,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA2:
             {
                 hparams.n_swa = 4096; // default value of gemma 2
+                hparams.n_swa_pattern = 2;
+                hparams.attn_soft_cap = true;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
-                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_2B; break;
@@ -873,6 +875,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
+                hparams.n_swa_pattern = 6;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -952,6 +956,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COHERE2:
             {
+                hparams.n_swa_pattern = 4;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
@@ -7374,12 +7380,8 @@ struct llm_build_gemma3 : public llm_graph_context {
         // TODO: is causal == true correct? might need some changes
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
-        // "5-to-1 interleaved attention"
-        // 5 layers of local attention followed by 1 layer of global attention
-        static const int sliding_window_pattern = 6;
-
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);
 
             const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
             const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
@@ -7970,13 +7972,8 @@ struct llm_build_cohere2 : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
-        // sliding window switch pattern
-        const int32_t sliding_window_pattern = 4;
-
         for (int il = 0; il < n_layer; ++il) {
-            // three layers sliding window attention (window size 4096) and ROPE
-            // fourth layer uses global attention without positional embeddings
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);