]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix floating-point range of attention scores in FA kernels (#13090)
authorGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 07:38:30 +0000 (10:38 +0300)
committerGitHub <redacted>
Thu, 24 Apr 2025 07:38:30 +0000 (10:38 +0300)
ggml-ci

ggml/src/ggml-metal/ggml-metal.metal

index 8d6e99e621e9e3933b149a15e6eb1007f593ab90..9f4147e93974d21c33b39be468145cd7bf8d8cd5 100644 (file)
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
 
     {
         float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
         float S = { 0.0f };
-        float M = { -__FLT16_MAX__/2 };
+        float M = { -__FLT_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
 
     {
         float S = 0.0f;
-        float M = -__FLT16_MAX__/2;
+        float M = -__FLT_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;