]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix FA out-of-bounds writes (llama/7465)
authorJohannes Gäßler <redacted>
Wed, 22 May 2024 15:58:25 +0000 (17:58 +0200)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 15:19:48 +0000 (18:19 +0300)
ggml-cuda/fattn-tile-f16.cu
ggml-cuda/fattn-tile-f32.cu

index 4a07ac6adad717870711400b20871a6066f76f73..586d469c049d10b4c7e5bc4956dedaf9794b4551 100644 (file)
@@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum(kqsum_j);
 
index b8b2f69e19edb86c237be2cc99823cc1df8a2bc0..b6ef8eb48d992bc0e82cf30d6fcf294b11cc1810 100644 (file)
@@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);