]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gemma : more consistent attention scaling for v2 and v3 (#13951)
authorGeorgi Gerganov <redacted>
Mon, 2 Jun 2025 17:54:26 +0000 (20:54 +0300)
committerGitHub <redacted>
Mon, 2 Jun 2025 17:54:26 +0000 (20:54 +0300)
* gemma : fix attn scale for 27B

* cont : apply scale before attn

* cont : consistent attention scaling

src/llama-model.cpp

index 50264a69aac0ebd882c54009d893a3402bbfd069..afef8487030fbdf70c0d47981e50ae503a28fc90 100644 (file)
@@ -956,6 +956,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 46: type = LLM_TYPE_27B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                }
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
         case LLM_ARCH_GEMMA3:
             {
@@ -976,6 +981,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
 
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                 hparams.f_attention_scale = type == LLM_TYPE_27B
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -8484,14 +8490,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8632,9 +8631,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             cur = build_norm(cur,