From: Johannes Gäßler Date: Mon, 12 May 2025 08:51:21 +0000 (+0200) Subject: CUDA: fix misaligned synchronization in FA (#13469) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=95e18884fc7ea4031f70f1a518d5d1df616e5717;p=pkg%2Fggml%2Fsources%2Fllama.cpp CUDA: fix misaligned synchronization in FA (#13469) --- diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 9873ea75..491780ab 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -895,6 +895,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } + } else if (np > 1) { + // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. + // Therefore, all other warps also need to execute a __syncthreads(). + // Otherwise the points at which warps synchronize with each other would become misaligned. + __syncthreads(); } #pragma unroll