// Precision selection for this file: accumulator and data vectors are float4,
// scalar data is float.
#define ACC_TYPE4 float4
#define DATA_TYPE float
#define DATA_TYPE4 float4
// Element type through which the mask buffer is read (mask_ptr is declared as
// `const global MASK_DATA_TYPE*`). It is deliberately distinct from DATA_TYPE,
// so mask elements are read as half and then cast to ACC_TYPE before use.
// NOTE(review): directly dereferencing a half pointer presumably requires
// cl_khr_fp16 to be enabled elsewhere in this file - confirm.
+#define MASK_DATA_TYPE half
// Both conversion macros expand to their argument unchanged.
#define CONVERT_ACC4(x) (x)
#define CONVERT_DATA4(x) (x)
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
- const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+ const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
- const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+ const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
- const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+ const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {