]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: fix FA for f32 (#16584)
authorlhez <redacted>
Wed, 15 Oct 2025 17:48:28 +0000 (10:48 -0700)
committerGitHub <redacted>
Wed, 15 Oct 2025 17:48:28 +0000 (10:48 -0700)
ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

index 9c0bab135a912a7dc931c05915a452f990a80e91..a6d74790375117c8a9e19d41d4d416721b298f69 100644 (file)
@@ -4,6 +4,7 @@
 #define ACC_TYPE4 float4
 #define DATA_TYPE float
 #define DATA_TYPE4 float4
+#define MASK_DATA_TYPE half
 #define CONVERT_ACC4(x) (x)
 #define CONVERT_DATA4(x) (x)
 
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
             if (k_row1 >= n_kv) score1 = -INFINITY;
 
             if (mask_base != NULL) {
-                const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+                const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
                 if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
                 if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
             }
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
         }
         ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
         if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
             score += slope * (ACC_TYPE)mask_ptr[k_idx];
         }
         if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
         }
         ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
         if (mask_base != NULL) {
-            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
             score += slope * (ACC_TYPE)mask_ptr[k_idx];
         }
         if (logit_softcap > 0.0f) {