]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix FlashAttention on Turing (#13415)
authorJohannes Gäßler <redacted>
Sat, 10 May 2025 07:16:52 +0000 (09:16 +0200)
committerGitHub <redacted>
Sat, 10 May 2025 07:16:52 +0000 (09:16 +0200)
ggml/src/ggml-cuda/fattn-mma-f16.cuh

index 2b6bdc30c024ba9d66caa98447278b64227c59f1..b2f95fa3f00e69fc04472b5798af151978a8dd33 100644 (file)
@@ -546,7 +546,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
         const int i0_diff = i0_stop - i0_start;
 
-        if (nstages == 1) {
+        if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
                 (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);