]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix race condition in FA vector kernels (#13742)
authorJohannes Gäßler <redacted>
Sat, 24 May 2025 09:46:19 +0000 (11:46 +0200)
committerGitHub <redacted>
Sat, 24 May 2025 09:46:19 +0000 (11:46 +0200)
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh

index 798a59b2778612a77f15b4ae4d64047edd45412d..35e649cb3c81bdbbd81914bc97345c654db3e65f 100644 (file)
@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
                 continue;
             }
 #endif // GGML_USE_HIP
index 49c592ea59224ccae38b24c7e72e7697c46070d3..95396791779698ce15e692ef8b8e32bbc99ac396 100644 (file)
@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
                 continue;
             }
 #endif // GGML_USE_HIP