]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix FA VKQ accumulator overflow (#17746)
authorJohannes Gäßler <redacted>
Fri, 5 Dec 2025 08:18:10 +0000 (09:18 +0100)
committerGitHub <redacted>
Fri, 5 Dec 2025 08:18:10 +0000 (09:18 +0100)
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn-tile.cuh
ggml/src/ggml-cuda/fattn-vec.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu

index 02443b8c638294ee143e8c6bd8cc354be78a30df..2750117aa973d2f2ef72b113e04244f38455c146 100644 (file)
 #define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
+//     by the VKQ accumulators is effectively being shifted up by a factor of 8.
+// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
+// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
+#define FATTN_KQ_MAX_OFFSET 0.6931f
+
 typedef void (* fattn_kernel_t)(
         const char * __restrict__ Q,
         const char * __restrict__ K,
index b6250cf7949d160e1b5bc0de6a915dbb722ba7e2..ade0773dad84c68ff6d127ca601bddd76cb0ea6c 100644 (file)
@@ -532,7 +532,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
+                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -585,7 +585,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
                     // Turing + Volta:
-                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
+                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
index 63b235674eb8d19c38caca2c0994fcc8183f26bd..8afc1daaeb73092b6e93a28f4802e7a967357a76 100644 (file)
@@ -572,7 +572,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
                     slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
 
-                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
             }
         }
 
index 0bae9849a96fb306e4388ef71bdf55ec577686e8..4d167b95a075852e8cdaa7b7656ed8cac7ec2538 100644 (file)
@@ -270,7 +270,7 @@ static __global__ void flash_attn_ext_vec(
                     sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
                 }
 
-                KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
+                KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
 
                 if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
                     KQ_reg[j] = sum;
index 0d81f0aae0a75bbd192e8fa1e029ae23a13f81df..8694fd06c7bc4f31d9702933ca0aa720389ef240 100644 (file)
@@ -220,7 +220,7 @@ static __global__ void flash_attn_ext_f16(
 
                     KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
                         __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
                 }
                 KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);