file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
-if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
- message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
+if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
+ message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT points to the correct Hexagon SDK installation.")
+endif()
+
+if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+ message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
+ file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
+ string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
+ message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
+ set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
+ file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+ if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+ message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT points to the correct Hexagon tools installation.")
+ endif()
endif()
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
#include "htp-msg.h"
#include "htp-ops.h"
+static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
+ return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+}
+
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
- HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
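+ // Sum the qf32 product halves, fold in the running sf total, and renormalize to sf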
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
if (nloe) {
// Load y (fp32) and convert into fp16
- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
- HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
- rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
- hvx_vec_store_u(r, 4, rsum);
+ rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
+ hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+// Dot products of one FP32 vector with two FP16 rows (x0, x1), accumulating two floats
+static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
+ const void * restrict y,
+ const void * restrict x0,
+ const void * restrict x1,
+ unsigned int n,
+ float s) {
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
+ // Load x (fp16)
+ HVX_Vector x0_hf = vx0[i];
+ HVX_Vector x1_hf = vx1[i];
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
+
+ // Load x (fp16)
+ HVX_Vector x0_hf = vx0[i];
+ HVX_Vector x1_hf = vx1[i];
+
+ // Zero-out unused elements
+ // Note that we need to clear both x and y because they may contain NANs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x0_hf = Q6_V_vand_QV(bmask, x0_hf);
+ x1_hf = Q6_V_vand_QV(bmask, x1_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
+ hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of two F16 vectors, accumulating to float
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
if (nloe) {
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
+ hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+
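+// Dot products of one F16 vector with two F16 rows (x0, x1), accumulating two floats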
+static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
+ const void * restrict y,
+ const void * restrict x0,
+ const void * restrict x1,
+ unsigned int n,
+ float s) {
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = vy[i];
+ HVX_Vector x0_hf = vx0[i];
+ HVX_Vector x1_hf = vx1[i];
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ if (nloe) {
+ HVX_Vector y_hf = vy[i];
+
+ // Load x (fp16) and zero-out unused elements
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
- rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
- hvx_vec_store_u(r, 4, rsum);
+ HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
+ hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// MAD: y (F32) += x (F16) * s (float)
// Inner loop processing the block from VTCM
uint32_t ic = 0;
+ const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
+
// Process in blocks of 32 (VLEN_FP32)
- static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
+ static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
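+ // scores_x4 provides room for 4 HVX vectors, enough for any block size accepted by the assert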
HVX_Vector_x4 scores_x4;
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
- float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
- for (int j = 0; j < VLEN_FP32; ++j) {
+ float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
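+ // Process two K rows per iteration; the rx2 dot products share the single q row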
+ for (int j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
- if (q->type == HTP_TYPE_F32) {
- hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+ if (is_q_fp32) {
+ hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else {
- hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+ hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
}
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
- if (q->type == HTP_TYPE_F32) {
+ if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
- union {
- HVX_Vector v;
- float d[32];
- } u = { .v = v };
+ HVX_VectorAlias u = { .v = v };
const uint32_t n0 = n / 16;
const uint32_t n1 = n % 16;
int i = 0;
for (; i < n0; i++) {
- hex_dump_f32_line(pref, u.d + (16 * i), 16);
+ hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
}
if (n1) {
- hex_dump_f32_line(pref, u.d + (16 * i), n1);
+ hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
}
}
return hvx_vec_reduce_sum_n_qf32(in, 32);
}
+#if __HVX_ARCH__ > 75
+
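+// Reduce two f32 vectors at once: even lanes end up holding sum(in0) and odd lanes sum(in1),
+// so elements 0 and 1 carry both results and a single 8-byte store returns the pair.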
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+ HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+ HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
+ return sum_sf;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
+ unsigned int total = n * 4; // total vec nbytes
+ unsigned int width = 4; // fp32 nbytes
+
+ HVX_Vector sum = in, sum_t;
+ while (width < total) {
+ sum_t = Q6_V_vror_VR(sum, width); // rotate right
+ sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
+ width = width << 1;
+ }
+ return sum;
+}
+
+#else
+
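+// Fallback with the same even/odd result layout for targets without native sf adds,
+// accumulating through qf32 instead.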
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+ HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+ HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
+ return Q6_Vsf_equals_Vqf32(sum_qf);
+}
+
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
unsigned int total = n * 4; // total vec nbytes
unsigned int width = 4; // fp32 nbytes
return sum;
}
+#endif
+
static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
return hvx_vec_reduce_sum_n_f32(in, 32);
}
#include "hex-dma.h"
#include "hvx-utils.h"
+#include "hvx-dump.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
- // Row sum (qf32)
+ // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32.
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
- // Zero out unused scales
+ // Zero out unused values
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
- // Reduce and convert into fp32
- r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
hvx_vec_store_u(&s[0], 4, r0_sum);
}
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
- // Row sum (qf32)
+ // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
- // Convert into fp32 and reduce
- r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
- r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+ hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
- // Row sum (qf32)
+ // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32.
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
- // Zero out unused scales
+ // Zero out unused values
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
- // Reduce and convert into fp32
- r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
hvx_vec_store_u(&s[0], 4, r0_sum);
}
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
- // Convert into fp32 and reduce
- r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
- r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+ hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
- // Row sum (qf32)
+ // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32.
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
// Process leftovers
- // Zero-out unused scales
+ // Zero-out unused values
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
- // Reduce and convert into fp32
- r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
hvx_vec_store_u(&s[0], 4, r0_sum);
}
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
- // Row sum (qf32)
+ // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
- // Apply scale to acc and accumulate into the row sum (qf32).
+ // Apply scale to acc and accumulate into the row sum (sf).
const uint32_t nb = n / qk; // num full blocks
int32_t nloe = n % qk; // num leftover elements (must be signed)
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
// Process leftovers
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
- // Zero-out unused scales
+ // Zero-out unused values
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
- r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
- r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
- // Convert into fp32 and reduce
- r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
- r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+ hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
}
- rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0));
- rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1));
- HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
-
- hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
+ hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+ // Convert into fp32 and reduce
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
v_pad[i] = v3;
}
- v = hvx_vec_reduce_sum_qf32(sum_vec);
- sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
+ v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
+ sum_vec = hvx_vec_repl4(v);
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
}
- HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v);
- sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
+ HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
+ sum_v = hvx_vec_repl4(reduced_sum);
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);