]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix FP16 overflow in tile FA kernel (#17875)
authorJohannes Gäßler <redacted>
Tue, 9 Dec 2025 08:34:02 +0000 (09:34 +0100)
committerGitHub <redacted>
Tue, 9 Dec 2025 08:34:02 +0000 (09:34 +0100)
ggml/src/ggml-cuda/fattn-tile.cuh

index 8afc1daaeb73092b6e93a28f4802e7a967357a76..7c4d6fe67feae18e13a6d54b3015c85c88d6493c 100644 (file)
@@ -564,6 +564,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
             const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
 
+#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+            // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
+            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
+#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+
             if (use_logit_softcap) {
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
             }
@@ -858,6 +864,11 @@ static __global__ void flash_attn_tile(
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                     tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+                    // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
+                    tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);
+#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
                 }
                 ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                     &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],