git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix negative KV_max values in FA (#15321)
author Johannes Gäßler <redacted>
Thu, 14 Aug 2025 21:21:24 +0000 (23:21 +0200)
committer GitHub <redacted>
Thu, 14 Aug 2025 21:21:24 +0000 (23:21 +0200)
ggml/src/ggml-cuda/fattn-common.cuh

index e46f0e2081bdfb73c5bfe6582a5fff485c8d2c70..d4ed938391b478f7c1560cb82ce48ff1f37e4eef 100644 (file)
@@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
         all_inf = warp_reduce_all(all_inf);
 
         if (!all_inf) {
-            KV_max_sj += FATTN_KQ_STRIDE;
             break;
         }
     }
 
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
     if (threadIdx.x != 0) {
         return;
     }