From: lhez
Date: Wed, 15 Oct 2025 17:48:28 +0000 (-0700)
Subject: opencl: fix FA for f32 (llama/16584)
X-Git-Tag: upstream/0.9.4.185~123
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=70e35abfce31240e36921dd2b7c4a34867651fe0;p=pkg%2Fggml%2Fsources%2Fggml

opencl: fix FA for f32 (llama/16584)
---

diff --git a/src/ggml-opencl/kernels/flash_attn_f32.cl b/src/ggml-opencl/kernels/flash_attn_f32.cl
index 9c0bab13..a6d74790 100644
--- a/src/ggml-opencl/kernels/flash_attn_f32.cl
+++ b/src/ggml-opencl/kernels/flash_attn_f32.cl
@@ -4,6 +4,7 @@
 #define ACC_TYPE4 float4
 #define DATA_TYPE float
 #define DATA_TYPE4 float4
+#define MASK_DATA_TYPE half
 #define CONVERT_ACC4(x) (x)
 #define CONVERT_DATA4(x) (x)
 
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
             if (k_row1 >= n_kv) score1 = -INFINITY;
 
             if (mask_base != NULL) {
-                const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+                const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
                 if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
                 if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
             }
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
         }
         ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
         if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
             score += slope * (ACC_TYPE)mask_ptr[k_idx];
         }
         if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
         }
         ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
         if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
             score += slope * (ACC_TYPE)mask_ptr[k_idx];
         }
         if (logit_softcap > 0.0f) {