]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix misaligned synchronization in FA (#13469)
authorJohannes Gäßler <redacted>
Mon, 12 May 2025 08:51:21 +0000 (10:51 +0200)
committerGitHub <redacted>
Mon, 12 May 2025 08:51:21 +0000 (10:51 +0200)
ggml/src/ggml-cuda/fattn-mma-f16.cuh

index 9873ea755a599767ae4e6904415daf881535f911..491780abd40626d4518a7c7d514654704c3c957d 100644 (file)
@@ -895,6 +895,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
+    } else if (np > 1) {
+        // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
+        // Therefore, all other warps also need to execute a __syncthreads().
+        // Otherwise the points at which warps synchronize with each other would become misaligned.
+        __syncthreads();
     }
 
 #pragma unroll