From: Johannes Gäßler Date: Wed, 28 May 2025 11:33:37 +0000 (+0200) Subject: CUDA: fix FA tg at long context for CC >= 8.9 (llama/13852) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9a500394ad1ef1190dde3ac91d399d8cb05bac16;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp CUDA: fix FA tg at long context for CC >= 8.9 (llama/13852) --- diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index a4fbd823..cfab2b5e 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; - if (tid < 2*parallel_blocks) { - ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i]; } __syncthreads();