]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: remove unnecessary warp reduce in FA (ggml/1032)
authormahorozte <redacted>
Tue, 3 Dec 2024 13:11:43 +0000 (21:11 +0800)
committerGeorgi Gerganov <redacted>
Sun, 8 Dec 2024 18:14:35 +0000 (20:14 +0200)
* kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit

* same problem in vec32

---------

Co-authored-by: ZhaoXiaoYu <redacted>
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh

index 5ec3b91ae2b39f88997f0874cbdc5f96453db154..34a2992c769b9d4d338664bd338db0526c4b5c00 100644 (file)
@@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
 
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
index 3d93f4a8acdf27bec1b44d75056632907099de35..a28fc8b7fc893ea5ed8a4fb9a746e84fcdf5a2b3 100644 (file)
@@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float kqmax_new_j = kqmax_new_arr[j];
 
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }