From: Johannes Gäßler Date: Sat, 10 May 2025 07:16:52 +0000 (+0200) Subject: CUDA: fix FlashAttention on Turing (llama/13415) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=637981b2afd5c0d23eddc14799b314692c839453;p=pkg%2Fggml%2Fsources%2Fggml CUDA: fix FlashAttention on Turing (llama/13415) --- diff --git a/src/ggml-cuda/fattn-mma-f16.cuh b/src/ggml-cuda/fattn-mma-f16.cuh index 2b6bdc30..b2f95fa3 100644 --- a/src/ggml-cuda/fattn-mma-f16.cuh +++ b/src/ggml-cuda/fattn-mma-f16.cuh @@ -546,7 +546,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; const int i0_diff = i0_stop - i0_start; - if (nstages == 1) { + if (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);