]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix numerical issues in tile FA kernel (#16540)
authorJohannes Gäßler <redacted>
Mon, 13 Oct 2025 14:29:45 +0000 (16:29 +0200)
committerGitHub <redacted>
Mon, 13 Oct 2025 14:29:45 +0000 (17:29 +0300)
ggml/src/ggml-cuda/fattn-tile.cuh

index 2efc9cc880cf8fa3608870be37dd35dd1fc52dda..2b60b3bb13563d33da26cc63e1e6fceca683da89 100644 (file)
@@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
             }
 
-            KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
-                slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
+            if (!oob_check || i_KQ < k_VKQ_sup) {
+                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
+                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
 
-            KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+            }
         }
 
         KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
             float KQ_sum_add = 0.0f;
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
-                const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
-                if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
-                    KQ_sum_add += val;
-                }
+                const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
+                    expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
+                KQ_sum_add += val;
                 tmp[i0/(np*warp_size)][jc1] = val;
             }
             KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
@@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
         }
     }
 
-    if (gridDim.y == 1) {
-#pragma unroll
-        for (int jc0 = 0; jc0 < cpw; ++jc0) {
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
-#pragma unroll
-            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
-                VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
-            }
-#else
-            const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
-#pragma unroll
-            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
-                VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
-                VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
-            }
-#endif // FAST_FP16_AVAILABLE
-        }
-    }
-
     // Write back results:
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
             return;
         }
 
+        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
+
         const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
+                tmp[i1].x *= scale;
+                tmp[i1].y *= scale;
             }
             if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                 ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
             if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
+                }
                 ggml_cuda_memcpy_1<cpy_ne_D*4>(
                     &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                     &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);