git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix Gemma 2 numerical issues for FA (#9166)
author    Johannes Gäßler <redacted>
          Sun, 25 Aug 2024 20:11:48 +0000 (22:11 +0200)
committer GitHub <redacted>
          Sun, 25 Aug 2024 20:11:48 +0000 (22:11 +0200)
src/llama.cpp

index aeea54cffe02032bee08d3ed3d52b63e6f28550e..fc8fb3e0ddef2a87ca274fb676279678f7b67ecc 100644 (file)
@@ -8877,7 +8877,7 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
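
Note: the hunk above adds Gemma 2 to the architectures whose flash-attention result is forced to F32 precision. Gemma 2 applies logit soft-capping in attention (the hparams.attn_soft_cap path already visible in the call), and the commit title indicates the default lower-precision FA accumulation on CUDA produced numerical issues for it. Below is a minimal standalone sketch of the same pattern against ggml.h; the helper name build_flash_attn and the use_softcap/softcap_value parameters are illustrative only and are not llama.cpp API.

#include <stdbool.h>
#include "ggml.h"

// Sketch: build a flash-attention node and, when the model soft-caps its
// attention logits (as Gemma 2 does), request F32 precision for the result
// so the soft-capped KQ values are accumulated in full precision.
static struct ggml_tensor * build_flash_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        float                 kq_scale,
        float                 max_alibi_bias,
        bool                  use_softcap,      // illustrative parameter
        float                 softcap_value) {  // illustrative parameter
    struct ggml_tensor * cur = ggml_flash_attn_ext(
            ctx, q, k, v, kq_mask, kq_scale, max_alibi_bias,
            use_softcap ? softcap_value : 0.0f);

    if (use_softcap) {
        // mirror the diff: opt out of the default (lower-precision) accumulator
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
    }

    return cur;
}

In llama.cpp itself the opt-in is keyed on model.arch rather than on a boolean, as the hunk shows; the sketch conditions on soft-capping directly only to make the intent of the fix explicit.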