]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix FA out-of-bounds writes (llama/7465)
authorJohannes Gäßler <redacted>
Wed, 22 May 2024 15:58:25 +0000 (17:58 +0200)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cu
src/ggml-cuda/fattn-vec-f32.cu

index 4a07ac6adad717870711400b20871a6066f76f73..586d469c049d10b4c7e5bc4956dedaf9794b4551 100644 (file)
@@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum(kqsum_j);
 
index b8b2f69e19edb86c237be2cc99823cc1df8a2bc0..b6ef8eb48d992bc0e82cf30d6fcf294b11cc1810 100644 (file)
@@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
index 54e1ac5d1605010a03190dc02abb394b753ecfc0..7352dcabf62919c05d4e59f1f48964a0948086fa 100644 (file)
@@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
@@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols) {
+    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
index 5bcabd0928451e8ebbaee831e5b7a7cfd131afa5..11476a6c0fbbca7d458c909426f33eb559beb048 100644 (file)
@@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
@@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols) {
+    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }