]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix FA out-of-bounds reads (llama/7479)
authorJohannes Gäßler <redacted>
Wed, 22 May 2024 22:31:20 +0000 (00:31 +0200)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 15:19:48 +0000 (18:19 +0300)
ggml-cuda/fattn-tile-f16.cu
ggml-cuda/fattn-tile-f32.cu

index 586d469c049d10b4c7e5bc4956dedaf9794b4551..cdb5eaff7953595c94a624467103fda11f658d8f 100644 (file)
@@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
         }
     }
index b6ef8eb48d992bc0e82cf30d6fcf294b11cc1810..5a3de2918c7a3dd43725e518f22c60c157c89528 100644 (file)
@@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
-            float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
+            float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
         }