From: mahorozte Date: Tue, 3 Dec 2024 13:11:43 +0000 (+0800) Subject: CUDA: remove unnecessary warp reduce in FA (#1032) X-Git-Tag: upstream/0.0.1642~151 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=b903ffe79daf18c0aaacbebe44a7b93a6b8d0982;p=pkg%2Fggml%2Fsources%2Fggml CUDA: remove unnecessary warp reduce in FA (#1032) * kqmax_new_j is already identical across all threads within a warp after the operation at line 199, so this warp reduce can be omitted * the same issue applies to the vec32 variant --------- Co-authored-by: ZhaoXiaoYu --- diff --git a/src/ggml-cuda/fattn-vec-f16.cuh b/src/ggml-cuda/fattn-vec-f16.cuh index 5ec3b91a..34a2992c 100644 --- a/src/ggml-cuda/fattn-vec-f16.cuh +++ b/src/ggml-cuda/fattn-vec-f16.cuh @@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } diff --git a/src/ggml-cuda/fattn-vec-f32.cuh b/src/ggml-cuda/fattn-vec-f32.cuh index 3d93f4a8..a28fc8b7 100644 --- a/src/ggml-cuda/fattn-vec-f32.cuh +++ b/src/ggml-cuda/fattn-vec-f32.cuh @@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float kqmax_new_j = kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; }