From: Johannes Gäßler
Date: Sat, 24 May 2025 09:46:19 +0000 (+0200)
Subject: CUDA: fix race condition in FA vector kernels (llama/13742)
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=f1576b26598c6cf051fb983b1eccfca762b628e7;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp

CUDA: fix race condition in FA vector kernels (llama/13742)
---

diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 798a59b2..35e649cb 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
                 continue;
             }
 #endif // GGML_USE_HIP
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 49c592ea..95396791 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
                 continue;
             }
 #endif // GGML_USE_HIP