#include "hex-dma.h"
#include "hvx-utils.h"
+#include "hvx-dump.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-msg.h"
#include "htp-ops.h"
+// Must be a multiple of 32
+#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
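+// (each block is consumed in sub-blocks of VLEN_FP32 = 32 scores; see
+// sb_scores and hvx_dot_f16_f16_aa_rx32 below)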
+
+// This is a bit of a hack: the compiler struggles to properly inline the
+// default hvx_vec_f32_to_f16 when the output goes into a local array.
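+// Stores hvx_vec_f32_to_f16(v0, v1) at ptr: the 64 F32 lanes of v0 and v1
+// packed into a single vector of 64 F16 lanes (ptr is assumed vector-aligned).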
+static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
+{
+ *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
+}
+
// Dot product of two F16 vectors, scaled by s, accumulating to float
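// Scalar reference: *r = s * sum(x[k] * y[k]) for k in [0, n)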
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- HVX_Vector rsum = Q6_V_vsplat_R(0);
+ HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
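+ // The pair keeps two independent partial-sum vectors, one per half of the
+ // widening F16 multiply; hvx_vec_mpyacc_f32_f16 presumably accumulates into
+ // both halves, and the halves are combined only after the loop.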
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
- HVX_Vector y_hf = vy[i];
- HVX_Vector x_hf = vx[i];
-
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-
- rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
}
if (nloe) {
- // Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-
- rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
- rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
- hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+ HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
+ hvx_vec_store_u(r, 4, rsum);
}
-static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
- const void * restrict y,
- const void * restrict x0,
- const void * restrict x1,
- unsigned int n,
- float s) {
- const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
- const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
- const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
-
- uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
- uint32_t nloe = n % VLEN_FP16; // leftover elements
-
- HVX_Vector rsum0 = Q6_V_vsplat_R(0);
- HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
+ const uint8_t * restrict x,
+ const size_t stride_x,
+ const size_t nvec,
+ const size_t nloe) {
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
+ const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
+ const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+ HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
- #pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
+ HVX_Vector x2_hf = vx2[i];
+ HVX_Vector x3_hf = vx3[i];
- HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
- HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
-
- rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
- rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
+ rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
+ rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
if (nloe) {
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
- HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
- HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
- HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
+ HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
+ HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
+
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
+ rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
+ rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
+ }
+
+ HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
+ HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
+ HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
+ HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
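+ // Each lo/hi half holds 32 partial F32 sums; adding them merges the two
+ // accumulation chains before the cross-lane reduction below.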
- HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
- HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+ HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
+ return hvx_vec_reduce_sum_f32x4(rsum0123);
+}
- rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
- rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
+ const uint8_t * restrict x,
+ const size_t stride_x,
+ const size_t n,
+ float s) {
+
+ const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ const size_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector sums; // fully initialized on the first iteration (j == 0) by the vmux below
+ const size_t stride_x_4 = stride_x * 4;
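+ // Gather the four-at-a-time results into distinct F32 lanes: at step j the
+ // vmux keeps lanes [0, j) from earlier steps (Q6_Q_vsetq_R marks the first
+ // j * 4 bytes) and takes the remaining lanes from sums_x4, whose four sums
+ // are presumably replicated across the vector by hvx_vec_reduce_sum_f32x4.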
+ for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
+ HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
+ HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
+ sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
+ x += stride_x_4;
}
- HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
- hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
+ sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
+ return Q6_Vsf_equals_Vqf32(sums);
}
-// MAD: y (F32) += x (F16) * s (F32)
-static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
- const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
- HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+// MAD: y (F32) += x (F16) * s (F16)
+static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) {
+ const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
+
+ HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
+ HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- HVX_Vector S = hvx_vec_splat_f16(s);
+ HVX_Vector S0 = hvx_vec_splat_f16(*s);
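+ // x is pre-shuffled with Q6_Vh_vshuff_Vh so the even/odd lane split of the
+ // widening F16 multiply lands in memory order for the F32 destination, as in
+ // the original Q6_Wqf32_vmpy_VhfVhf code path.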
uint32_t i = 0;
- #pragma unroll(4)
+
+ #pragma unroll(2)
for (i = 0; i < nvec; ++i) {
- // Multiply x * s -> pair of F32 vectors
- HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
- ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
- ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
}
if (nloe) {
- HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+ HVX_VectorPair xy_p = vy_p[i];
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
- HVX_Vector xs = Q6_V_lo_W(xs_p);
- i = 2 * i; // index for ptr_y
+ HVX_Vector xy = Q6_V_lo_W(xy_p);
+ i = 2 * i; // index for vy
- if (nloe >= 32) {
- ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
- nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
+ if (nloe >= VLEN_FP32) {
+ vy[i] = xy;
+ nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
- HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
- hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
+ hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
-// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
-static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
- const void * restrict x0,
- const void * restrict x1,
- float s0,
- float s1,
- int n) {
- const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
- const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
- HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
+static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
+ const __fp16 * restrict s0, const __fp16 * restrict s1, int n) {
+ const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
+ const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
+
+ HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
+ HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- HVX_Vector S0 = hvx_vec_splat_f16(s0);
- HVX_Vector S1 = hvx_vec_splat_f16(s1);
+ HVX_Vector S0 = hvx_vec_splat_f16(*s0);
+ HVX_Vector S1 = hvx_vec_splat_f16(*s1);
uint32_t i = 0;
+
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
- // Multiply x * s -> pair of F32 vectors
- HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
- HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
-
- HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
- HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
-
- ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
- ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
+ vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
}
if (nloe) {
- HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
- HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
+ HVX_VectorPair xy_p = vy_p[i];
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
+ xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
- HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
- HVX_Vector xs = xs_p_lo;
- i = 2 * i; // index for ptr_y
+ HVX_Vector xy = Q6_V_lo_W(xy_p);
+ i = 2 * i; // index for vy
- if (nloe >= 32) {
- ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
- nloe -= 32; ++i;
- xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
+ if (nloe >= VLEN_FP32) {
+ vy[i] = xy;
+ nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
- HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
- hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
+ hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
-#define FLASH_ATTN_BLOCK_SIZE 128
-
struct htp_fa_context {
const struct htp_ops_context * octx;
size_t size_v_block;
size_t size_m_block;
+ uint32_t qrows;
+ uint32_t qrows_per_thread;
+
bool is_q_fp32;
+
+ uint64_t t_start;
};
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
const uint32_t nb3 = dst->nb[3];
// total rows in q
- const uint32_t nr = neq1*neq2*neq3;
-
- const uint32_t dr = (nr + nth - 1) / nth;
+ const uint32_t nr = factx->qrows;
+ const uint32_t dr = factx->qrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, nr);
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
- const uint32_t h = iq2; // head index
- const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
-
- HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
- HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
-
- // Clear accumulator
- hvx_splat_f32_a(spad_a, 0, DV);
- float * VKQ32 = (float *) spad_a;
+ // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row,
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
const __fp16 * mp_base = NULL;
if (mask) {
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
+
+ // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
+ // ith, ir, ib, iq1, iq2, iq3,
+ // size_k_row, size_v_row, current_block_size,
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
}
+ const uint32_t h = iq2; // head index
+ const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
+
+ HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
+ HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
+
+ // Clear accumulator
+ hvx_splat_f32_a(spad_a, 0, DV);
+ float * VKQ32 = (float *) (spad_a + 0);
+
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
if (factx->is_q_fp32) {
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
uint8_t * v_base = dma_queue_pop(dma).dst; // V
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
+ // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u",
+ // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm,
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
+
// Inner loop processing the block from VTCM
uint32_t ic = 0;
- // Process in blocks of 32 (VLEN_FP32)
- static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
- HVX_Vector_x4 scores_x4;
+ // Process in sub-blocks of 32 (VLEN_FP32)
+ HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
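+ // Stage the block scores so the softmax pass below can revisit them; with
+ // FLASH_ATTN_BLOCK_SIZE = 64 this is only two vectors and can stay in
+ // registers.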
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
- float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
- for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
- const uint32_t cur_ic = ic + j;
- const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
- hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
- }
-
- HVX_Vector scores = *(HVX_Vector *) scores_arr;
+ HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
// 2. Softcap
if (factx->logit_softcap != 0.0f) {
scores = Q6_Vsf_equals_Vqf32(scores);
}
- scores_x4.v[iv] = scores;
+ sb_scores[iv] = scores;
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
}
{
// 4. Online Softmax Update
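// With M_new = max(M_old, block_max), terms already accumulated as
// exp(s - M_old) are rescaled by ms = exp(M_old - M_new) so that they become
// exp(s - M_new); both S and the VKQ32 accumulator are scaled by ms.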
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
- HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
- HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
+ HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
M_vec = M_new_vec;
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
- HVX_Vector scores = scores_x4.v[iv];
+ HVX_Vector scores = sb_scores[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
// 5. Accumulate V
- float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
- *(HVX_Vector *) p_arr = P;
+ __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
+ hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
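+ // P is narrowed to F16 (the splat(0) zero-fills the unused upper half) so
+ // the V accumulation below can take per-column scales in the F16 MAD path.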
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
- hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
+ hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
}
}
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}
- // Sync scalars for leftover/next block if needed
- float M = hvx_vec_get_f32(M_vec);
- float S = hvx_vec_get_f32(S_vec);
+ if (ic < current_block_size) {
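+ // (skips the costly vector-to-scalar sync when the block length is a
+ // multiple of VLEN_FP32, the common case given FLASH_ATTN_BLOCK_SIZE)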
+ // Sync scalars for leftover/next block if needed
+ float M = hvx_vec_get_f32(M_vec);
+ float S = hvx_vec_get_f32(S_vec);
+
+ // Leftover
+ for (; ic < current_block_size; ++ic) {
+ float s_val;
+ const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
+ hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
+ if (factx->logit_softcap != 0.0f) {
+ s_val = factx->logit_softcap * tanhf(s_val);
+ }
- // Leftover
- for (; ic < current_block_size; ++ic) {
- float s_val;
- const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
- hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
- if (factx->logit_softcap != 0.0f) {
- s_val = factx->logit_softcap * tanhf(s_val);
- }
+ if (mask) {
+ const float m_val = m_base[ic];
+ s_val += slope * m_val;
+ }
- if (mask) {
- const float m_val = m_base[ic];
- s_val += slope * m_val;
- }
+ const float Mold = M;
+ __fp16 vs = 1.0f;
+
+ if (s_val > M) {
+ M = s_val;
+ HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
+ HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
+ hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
+
+ float ms = hvx_vec_get_f32(ms_vec);
+ S = S * ms + vs;
+ } else {
+ HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
+ vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
+ S += vs;
+ }
- const float Mold = M;
- float vs = 1.0f;
-
- if (s_val > M) {
- M = s_val;
- HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
- HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
- hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
-
- float ms = hvx_vec_get_f32(ms_vec);
- S = S * ms + vs;
- } else {
- HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
- vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
- S += vs;
- }
+ const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
- const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
+ }
- hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
+ M_vec = hvx_vec_splat_f32(M);
+ S_vec = hvx_vec_splat_f32(S);
}
- M_vec = hvx_vec_splat_f32(M);
- S_vec = hvx_vec_splat_f32(S);
// Issue DMA for next+1 block (if exists)
if (ib + 2 < factx->n_blocks) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}
+
+ // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
+ // ith, ir, next_ib, iq1, iq2, iq3,
+ // size_k_row, size_v_row, next_block_size,
+ // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
}
}
struct htp_fa_context factx;
factx.octx = octx;
+ factx.t_start = HAP_perf_get_qtimer_count();
+
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
+ // total rows in q
+ const uint32_t neq0 = q->ne[0];
+ const uint32_t neq1 = q->ne[1];
+ const uint32_t neq2 = q->ne[2];
+ const uint32_t neq3 = q->ne[3];
+
+ factx.qrows = neq1*neq2*neq3;
+ factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads;
+
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- HVX_Vector rsum = Q6_V_vsplat_R(0);
+ HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
-
- HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
- rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
- hvx_vec_store_u(&s[0], 4, rsum);
+ HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
+ hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
}
static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
- HVX_Vector rsum0 = Q6_V_vsplat_R(0);
- HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+ HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = y[i];
- HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
- HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
-
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
- rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
- HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
-
- HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
- HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
-
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
- rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+ rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
+ rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
}
- HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
+ HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
+ HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2×2 tile
- HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+ HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
HVX_Vector c1_hf = y1[i];
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
- HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
- HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
- HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
- HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
-
- HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
- HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
- HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
- HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
-
- r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
- r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
- r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
- r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+ r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
+ r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
+ r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
+ r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
if (nloe) {
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
- HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
- HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
- HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
- HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
-
- HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
- HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
- HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
- HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
-
- r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
- r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
- r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
- r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
-
+ r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
+ r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
+ r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
+ r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
+ HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
+ HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
+ HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
+ HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
+
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);