llama : fix Gemma-2 Query scaling factors (#8473)
author     Georgi Gerganov <redacted>
           Sun, 14 Jul 2024 11:05:09 +0000 (14:05 +0300)
committer  GitHub <redacted>
           Sun, 14 Jul 2024 11:05:09 +0000 (14:05 +0300)
* 9B - query_pre_attn_scalar = 256 not 224

See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e

Gemma-2 9B should use 256 (the head dimension) and not 224 (self.config.hidden_size // self.config.num_attention_heads).
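
For reference, a minimal Python sketch of the two scaling factors, assuming the published Gemma-2 config values (9B: hidden_size 3584, 16 attention heads, head_dim 256; 27B: hidden_size 4608, 32 heads, head_dim 128); the helper name query_scale is illustrative and not part of the codebase:

    import math

    def query_scale(hidden_size: int, num_heads: int, head_dim: int, use_head_dim: bool) -> float:
        # Queries are multiplied by 1/sqrt(query_pre_attn_scalar) before attention.
        scalar = head_dim if use_head_dim else hidden_size // num_heads
        return 1.0 / math.sqrt(scalar)

    # Gemma-2 9B: query_pre_attn_scalar = head_dim = 256 (not 3584 // 16 = 224)
    print(query_scale(3584, 16, 256, use_head_dim=True))   # 1/sqrt(256) = 0.0625
    # Gemma-2 27B: query_pre_attn_scalar = hidden_size // num_heads = 4608 // 32 = 144
    print(query_scale(4608, 32, 128, use_head_dim=False))  # 1/sqrt(144) ~= 0.0833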

* llama : fix Gemma-2 Query scaling factor

ggml-ci

---------

Co-authored-by: Daniel Han <redacted>
convert_hf_to_gguf.py
src/llama.cpp

convert_hf_to_gguf.py
index af82cd6cd05c996c18ebc195df71a75340c3e082..42dace219f20fcb87248ac4e989e6eed2d102bd9 100755 (executable)
@@ -2504,11 +2504,6 @@ class Gemma2Model(Model):
         )
         self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
 
-        # sanity check
-        attn_scalar = self.hparams["query_pre_attn_scalar"]
-        if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
-            raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
-
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
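The sanity check removed above would reject the corrected 9B conversion, because query_pre_attn_scalar (256) no longer equals hidden_size / num_attention_heads (3584 / 16 = 224). A minimal sketch of the failing condition, assuming the published 9B config values:

    # Subset of the Gemma-2 9B hparams (values assumed from the published config).
    hparams = {"query_pre_attn_scalar": 256, "hidden_size": 3584, "num_attention_heads": 16}

    attn_scalar = hparams["query_pre_attn_scalar"]
    # Old check: 256 != 3584 / 16 == 224.0, so conversion would raise for the fixed 9B values.
    if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
        print("would raise: query_pre_attn_scalar must be equal to n_embd / n_head")
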
src/llama.cpp
index 77d34dca280fc5fb56ffa1b1c17bbe372d98458f..400a4232beeb09dff60a71fe8bd8ccf11d0a7a48 100644 (file)
@@ -11680,7 +11680,12 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: GGML_ASSERT(false);
+                };
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(