From: Johannes Gäßler
Date: Sat, 10 May 2025 20:22:48 +0000 (+0200)
Subject: CUDA: fix race conditions FlashAttention kernels (#13438)
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0208355f42bdab88a08507ead4a6302790a08323;p=pkg%2Fggml%2Fsources%2Fllama.cpp

CUDA: fix race conditions FlashAttention kernels (#13438)
---

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index b2f95fa3..9873ea75 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    __syncthreads();
+
     // Write back combined meta data:
 #pragma unroll
     for (int imeta = 0; imeta < nmeta; ++imeta) {
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index ef0addc1..d96e3921 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int j = 0; j < ncols; ++j) {
         KQ[j*D + tid] = -HALF_MAX_HALF;
    }
+    __syncthreads();
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
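
Editor's note on the pattern being fixed: both hunks insert a block-wide barrier between shared-memory writes by one set of threads and later accesses to that same shared memory by other threads in the block. The standalone sketch below is not taken from the commit; the kernel name init_then_reduce and the size N are illustrative assumptions. It shows the same class of race: each thread initializes one slot of a shared buffer, then thread 0 reads all slots. Without the __syncthreads(), thread 0 may read slots that threads in other warps have not yet written.

// Minimal sketch of the race class fixed above (illustrative, not llama.cpp code).
#include <cstdio>

#define N 256

__global__ void init_then_reduce(float * out) {
    __shared__ float buf[N];
    const int tid = threadIdx.x;

    buf[tid] = (float) tid; // each thread writes exactly one shared slot

    __syncthreads(); // barrier: without it, the reads below race with the writes above

    if (tid == 0) {
        // Thread 0 reads slots written by *other* threads:
        float sum = 0.0f;
        for (int i = 0; i < N; ++i) {
            sum += buf[i];
        }
        *out = sum;
    }
}

int main() {
    float * d_out;
    cudaMalloc((void **) &d_out, sizeof(float));
    init_then_reduce<<<1, N>>>(d_out);
    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %.1f (expected %.1f)\n", h_out, (float) (N*(N - 1)/2));
    cudaFree(d_out);
    return 0;
}

The same reasoning presumably applies to the hunks: in fattn-vec-f16.cuh the barrier orders the KQ initialization loop against later threads that read KQ, and in fattn-mma-f16.cuh it orders the preceding shared-memory traffic against the meta-data write-back that follows.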