]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix FA tg at long context for CC >= 8.9 (llama/13852)
authorJohannes Gäßler <redacted>
Wed, 28 May 2025 11:33:37 +0000 (13:33 +0200)
committerGeorgi Gerganov <redacted>
Sun, 1 Jun 2025 11:01:05 +0000 (14:01 +0300)
src/ggml-cuda/fattn-common.cuh

index a4fbd823638fab0db558c214cf48c1b971cac68b..cfab2b5ebaccc6000ddc0e10a55e1ac18bbedd51 100644 (file)
@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
-    if (tid < 2*parallel_blocks) {
-        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
     }
 
     __syncthreads();