From: Johannes Gäßler Date: Wed, 4 Jun 2025 06:57:05 +0000 (+0200) Subject: CUDA: fix FTZ in FA for Gemma 3 (llama/13991) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=7f4d110f531648edaee6d0bab2989a75e9ee1927;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp CUDA: fix FTZ in FA for Gemma 3 (llama/13991) --- diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 925f39e8..e230f6d4 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + const float KQ_max_diff = KQ_max[col] - KQ_max_new[col]; + KQ_max_scale[col] = expf(KQ_max_diff); KQ_max[col] = KQ_max_new[col]; + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; }