]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix FA out-of-bounds reads (llama/7479)
authorJohannes Gäßler <redacted>
Wed, 22 May 2024 22:31:20 +0000 (00:31 +0200)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cu
src/ggml-cuda/fattn-vec-f32.cu

index 586d469c049d10b4c7e5bc4956dedaf9794b4551..cdb5eaff7953595c94a624467103fda11f658d8f 100644 (file)
@@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
         }
     }
index b6ef8eb48d992bc0e82cf30d6fcf294b11cc1810..5a3de2918c7a3dd43725e518f22c60c157c89528 100644 (file)
@@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
-            float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
+            float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
         }
index 7352dcabf62919c05d4e59f1f48964a0948086fa..808e8f36246a721005384ba6ffeb9210f4ba863f 100644 (file)
@@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
         }
     }
@@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ic0 + j_VKQ >= ne01) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
             break;
         }
 
@@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
index 11476a6c0fbbca7d458c909426f33eb559beb048..b4652301b87e03c0622a9adcb05b562d9ca15474 100644 (file)
@@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            Q_h2[j][i0/WARP_SIZE]    = Q_f2[j*(nb01/sizeof(float2)) + i];
+            Q_h2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i0/WARP_SIZE].x *= scale;
             Q_h2[j][i0/WARP_SIZE].y *= scale;
         }
@@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ic0 + j_VKQ >= ne01) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
             break;
         }
 
@@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }