From: Johannes Gäßler
Date: Wed, 22 May 2024 15:58:25 +0000 (+0200)
Subject: CUDA: fix FA out-of-bounds writes (llama/7465)
X-Git-Tag: upstream/0.0.1642~655
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d63866af2ba553ff6cb1261d84c0e9a88559e055;p=pkg%2Fggml%2Fsources%2Fggml

CUDA: fix FA out-of-bounds writes (llama/7465)
---

diff --git a/src/ggml-cuda/fattn-tile-f16.cu b/src/ggml-cuda/fattn-tile-f16.cu
index 4a07ac6a..586d469c 100644
--- a/src/ggml-cuda/fattn-tile-f16.cu
+++ b/src/ggml-cuda/fattn-tile-f16.cu
@@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum(kqsum_j);
 
diff --git a/src/ggml-cuda/fattn-tile-f32.cu b/src/ggml-cuda/fattn-tile-f32.cu
index b8b2f69e..b6ef8eb4 100644
--- a/src/ggml-cuda/fattn-tile-f32.cu
+++ b/src/ggml-cuda/fattn-tile-f32.cu
@@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
diff --git a/src/ggml-cuda/fattn-vec-f16.cu b/src/ggml-cuda/fattn-vec-f16.cu
index 54e1ac5d..7352dcab 100644
--- a/src/ggml-cuda/fattn-vec-f16.cu
+++ b/src/ggml-cuda/fattn-vec-f16.cu
@@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
@@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols) {
+    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
diff --git a/src/ggml-cuda/fattn-vec-f32.cu b/src/ggml-cuda/fattn-vec-f32.cu
index 5bcabd09..11476a6c 100644
--- a/src/ggml-cuda/fattn-vec-f32.cu
+++ b/src/ggml-cuda/fattn-vec-f32.cu
@@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
@@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols) {
+    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }
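
The guards above all address the same failure mode: the grid is sized in units of ncols columns per block, so when ne01 is not a multiple of ncols the last block is assigned column indices ic0 + j_VKQ >= ne01 that do not exist, and its writeback loop writes past the end of dst. The tile kernels can return outright (j_VKQ only grows across iterations, so every later column handled by that warp is also out of range), while the vec kernels break out of the loop; the added ic0 + tid < ne01 condition applies the same check to the dst_meta writes used when parallel_blocks != 1. The following minimal standalone sketch shows the hazard; the kernel toy_writeback is hypothetical and mirrors only the blocking scheme, not the attention math.

    // Hypothetical toy kernel, not from the commit: each block handles `ncols`
    // consecutive columns starting at ic0 = blockIdx.x*ncols, as in the FA
    // kernels above, so when ne01 is not a multiple of ncols the last block
    // must skip the columns past the end instead of writing out of bounds.

    #include <cstdio>
    #include <cuda_runtime.h>

    template <int ncols>
    static __global__ void toy_writeback(float * dst, const int ne01) {
        const int ic0 = blockIdx.x*ncols; // first column assigned to this block

        for (int j = 0; j < ncols; ++j) {
            // The same style of guard the commit adds: without it, the last
            // block writes up to ncols-1 elements past the end of dst
            // whenever ne01 % ncols != 0.
            if (ic0 + j >= ne01) {
                break;
            }
            dst[ic0 + j] = float(ic0 + j);
        }
    }

    int main() {
        constexpr int ncols = 4;
        const     int ne01  = 10; // not a multiple of ncols -> last block is partial

        float * dst_d = nullptr;
        cudaMalloc(&dst_d, ne01*sizeof(float));

        const int nblocks = (ne01 + ncols - 1)/ncols; // 3 blocks, last covers only 2 columns
        toy_writeback<ncols><<<nblocks, 1>>>(dst_d, ne01);

        float dst_h[ne01];
        cudaMemcpy(dst_h, dst_d, ne01*sizeof(float), cudaMemcpyDeviceToHost);
        for (int i = 0; i < ne01; ++i) {
            printf("%.0f ", dst_h[i]);
        }
        printf("\n");

        cudaFree(dst_d);
        return 0;
    }

Removing the guard from the sketch and running it under compute-sanitizer reports the invalid global writes from the last block, which is the class of error this commit eliminates in the FA kernels.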