]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix broken oob check for FA vec f32 kernel (llama/7904)
authorJohannes Gäßler <redacted>
Wed, 12 Jun 2024 15:41:51 +0000 (17:41 +0200)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
src/ggml-cuda/fattn-vec-f32.cuh

index ddf0c83740f5ae0bfcbd747b02ba412c641fc5ab..11a5e355fd40b8d59c9bd9176112505392dfdba5 100644 (file)
@@ -149,7 +149,7 @@ static __global__ void flash_attn_vec_ext_f32(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
                 Q_f2[j][i0/WARP_SIZE].x *= scale;
                 Q_f2[j][i0/WARP_SIZE].y *= scale;
             }