git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Adding Gemma 2 2B configs (#8784)
author: pculliton <redacted>
Wed, 31 Jul 2024 15:12:10 +0000 (11:12 -0400)
committer: GitHub <redacted>
Wed, 31 Jul 2024 15:12:10 +0000 (17:12 +0200)
* Adding Gemma 2 2B configs

Updates to Q scaling and Gemma 2 model sizes to match v2 2B model.

* Update src/llama.cpp

Co-authored-by: slaren <redacted>
---------

Co-authored-by: slaren <redacted>
src/llama.cpp

index a207451f585071016d2571b429989984eec0ea4c..e6f303d31b3bff99fa28c78c1818e0d0d1f23aeb 100644 (file)
@@ -4969,6 +4969,7 @@ static void llm_load_hparams(
                 hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
+                    case 26: model.type = e_model::MODEL_2B; break;
                     case 42: model.type = e_model::MODEL_9B; break;
                     case 46: model.type = e_model::MODEL_27B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -11736,6 +11737,7 @@ struct llm_build_context {
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
+                    case e_model::MODEL_2B:
                     case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
                     case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");