]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix FTZ in FA for Gemma 3 (llama/13991)
authorJohannes Gäßler <redacted>
Wed, 4 Jun 2025 06:57:05 +0000 (08:57 +0200)
committerGeorgi Gerganov <redacted>
Tue, 10 Jun 2025 09:40:33 +0000 (12:40 +0300)
ggml/src/ggml-cuda/fattn-mma-f16.cuh

index 925f39e890db927eee9c27cef3f50a575838e7af..e230f6d494d77fdceec51d94ad79bcff46d03601 100644 (file)
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float KQ_max_scale[cols_per_thread];
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
-            KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+            KQ_max_scale[col] = expf(KQ_max_diff);
             KQ_max[col] = KQ_max_new[col];
 
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
         }