From: nullname Date: Sat, 24 Jan 2026 06:02:07 +0000 (+0800) Subject: ggml-hexagon: flash-attn opt (llama/19025) X-Git-Tag: upstream/1.8.3+155~117 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=79f1bb3d355c198e390f622555cb22225683a2bf;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml-hexagon: flash-attn opt (llama/19025) * optimize flash attention kernel by improving score computation and online softmax update * wip * Refactor online softmax update in flash attention kernel for improved performance * Optimize flash attention kernel by replacing float array with HVX_Vector for score computation * wip --- diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1de47d0f..c7cb2a4e 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -2,9 +2,9 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" +#include #include #include - #include #include @@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict hvx_vec_store_u(r, 4, rsum); } -// MAD: y (F32) += x (F16) * v (float) +// MAD: y (F32) += x (F16) * s (float) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint32_t ic = 0; // Process in blocks of 32 (VLEN_FP32) - for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + HVX_Vector_x4 scores_x4; + HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); + for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE]; for (int j = 0; j < VLEN_FP32; ++j) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; @@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in scores = Q6_Vsf_equals_Vqf32(scores); } + scores_x4.v[iv] = scores; + v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + } + + { // 4. Online Softmax Update - HVX_Vector v_max = hvx_vec_reduce_max_f32(scores); + v_max = hvx_vec_reduce_max_f32(v_max); float m_block = hvx_vec_get_f32(v_max); - float M_old = M; float M_new = (m_block > M) ? m_block : M; M = M_new; - float ms = expf(M_old - M_new); - + const float ms = expf(M_old - M_new); hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - S = S * ms; HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); - - HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P); - float p_sum = hvx_vec_get_f32(p_sum_vec); - S += p_sum; - - // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; - - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); + for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } } + + p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); + S = S * ms + hvx_vec_get_f32(p_sum_vec); } // Leftover