]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix pointer incrementation in FA (llama/14916)
authorJohannes Gäßler <redacted>
Mon, 28 Jul 2025 12:30:22 +0000 (14:30 +0200)
committerGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh

index e9b5c306365a2e3b2a52152eabfbdb81c9e630a0..109253838f26c20af2d884e02f1e66fb6cc078f6 100644 (file)
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
             }
         }
 
-        K     += gridDim.y*D * nb11;
-        V     += gridDim.y*D * nb21;
-        maskh += gridDim.y*D;
-
         __syncthreads();
     }
 
index 6a4bdc0ff9aac847f4e3b4e74aefdbe39bddaa6c..2cf2e408e92c5871c0e94b2648cfe09b0b70967a 100644 (file)
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
             }
         }
 
-        K     += gridDim.y*D * nb11;
-        V     += gridDim.y*D * nb21;
-        maskh += gridDim.y*D;
-
         __syncthreads();
     }