#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+#include <assert.h>
#include <HAP_farf.h>
#include <HAP_perf.h>
-
#include <math.h>
#include <string.h>
hvx_vec_store_u(r, 4, rsum);
}
-// MAD: y (F32) += x (F16) * v (float)
+// MAD: y (F32) += x (F16) * s (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t ic = 0;
// Process in blocks of 32 (VLEN_FP32)
- for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
+ static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
+ HVX_Vector_x4 scores_x4;
+ HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
+ for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
- float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+ float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
scores = Q6_Vsf_equals_Vqf32(scores);
}
+ scores_x4.v[iv] = scores;
+ v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
+ }
+
+ {
// 4. Online Softmax Update
- HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
+ v_max = hvx_vec_reduce_max_f32(v_max);
float m_block = hvx_vec_get_f32(v_max);
-
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
- float ms = expf(M_old - M_new);
-
+ const float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
- S = S * ms;
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
- HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
- HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
-
- HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
- float p_sum = hvx_vec_get_f32(p_sum_vec);
- S += p_sum;
-
- // 5. Accumulate V
- float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
- *(HVX_Vector*)p_arr = P;
-
- for (int j = 0; j < VLEN_FP32; ++j) {
- const uint32_t cur_ic = ic + j;
- const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
- hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+ HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
+ for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
+ HVX_Vector scores = scores_x4.v[iv];
+ HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+ HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+ p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
+
+ // 5. Accumulate V
+ float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+ *(HVX_Vector*)p_arr = P;
+
+ for (int j = 0; j < VLEN_FP32; ++j) {
+ const uint32_t cur_ic = ic2 + j;
+ const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+ }
}
+
+ p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
+ S = S * ms + hvx_vec_get_f32(p_sum_vec);
}
// Leftover