return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
}
-static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
+static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
- HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
- HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
- HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ...
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
return r;
}
-static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
+static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+ const uint32_t qk = QK_Q4_0x4x2; // 256
+ const uint32_t nb = n / qk;
+ const uint32_t nloe = n % qk;
+
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+ const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
+
+ HVX_Vector_x8 r;
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+    for (i = 0; i < nb; i++) {
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
+ r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8);
+ r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8);
+ }
+
+ if (nloe) {
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
+ HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
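+        // zipping restores sequential element order, so the valid leftover
+        // elements land in a contiguous prefix of the block accumulators,
+        // which callers can then mask with a single vsetq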
+ r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8);
+ r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8);
+ }
+
+ return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
return r;
}
-static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
+static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+ const uint32_t qk = QK_Q4_0x4x2; // 256
+ const uint32_t nb = n / qk;
+ const uint32_t nloe = n % qk;
+
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
+
+ HVX_Vector_x8 r;
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+    for (i = 0; i < nb; i++) {
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
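+        // map each 4-bit code to its int8 value via the lookup table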
+ r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+ r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+ }
+
+ if (nloe) {
+ HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
+ HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
+ HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
+ r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
+ r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
+ }
+
+ return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0 = vptr[0]; // first 128 vals
return r;
}
+static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) {
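+    // q8 elements are stored unpacked in sequential order, so the leftover
+    // case can simply reuse the full load; callers mask out the unused lanes
+    (void) nloe;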
+ return hvx_vec_load_q8x4x8_full(ptr);
+}
+
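+// Scalar reference for the x8 reduce-multiply below (a sketch for intuition,
+// not part of the build): output lane b holds the dot product of 32-element
+// block b.
+//
+//   int32_t out[32];
+//   for (int b = 0; b < 32; b++) {
+//       int32_t acc = 0;
+//       for (int j = 0; j < 32; j++)
+//           acc += (int32_t) x[b * 32 + j] * (int32_t) y[b * 32 + j];
+//       out[b] = acc;
+//   }
+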
// Reduce-multiply two 1024-element int8 vectors (32x q4/q8 blocks in 8x HVX vectors).
// Accumulate each block into a single int32 value.
// Return a single HVX vector with 32x int32 accumulators.
// if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
- HVX_Vector r0 = Q6_V_vsplat_R(0);
- HVX_Vector r1 = Q6_V_vsplat_R(0);
- HVX_Vector r2 = Q6_V_vsplat_R(0);
- HVX_Vector r3 = Q6_V_vsplat_R(0);
- HVX_Vector r4 = Q6_V_vsplat_R(0);
- HVX_Vector r5 = Q6_V_vsplat_R(0);
- HVX_Vector r6 = Q6_V_vsplat_R(0);
- HVX_Vector r7 = Q6_V_vsplat_R(0);
+ HVX_Vector r0 = Q6_V_vzero();
+ HVX_Vector r1 = Q6_V_vzero();
+ HVX_Vector r2 = Q6_V_vzero();
+ HVX_Vector r3 = Q6_V_vzero();
+ HVX_Vector r4 = Q6_V_vzero();
+ HVX_Vector r5 = Q6_V_vzero();
+ HVX_Vector r6 = Q6_V_vzero();
+ HVX_Vector r7 = Q6_V_vzero();
HVX_VectorPair p3;
HVX_VectorPair p2;
}
static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
- return hvx_vec_rmpy_x8_n(x, y, 1024);
+ HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
+ HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
+ HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
+ HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
+ HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
+ HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
+ HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
+ HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
+
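+    // fold 8 vectors of 32 int32 partial sums into one vector with three
+    // vdeal/vadd rounds; each round sums adjacent word pairs, leaving each
+    // output lane with a full 32-element block accumulator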
+ HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
+ HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
+ HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
+ HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
+
+ r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+ r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
+ r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
+ r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
+
+ p0 = Q6_W_vdeal_VVR(r1, r0, -4);
+ p1 = Q6_W_vdeal_VVR(r3, r2, -4);
+
+ r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+ r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
+
+ p0 = Q6_W_vdeal_VVR(r1, r0, -4);
+ r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+
+ return r0;
}
-// Handle most common cases of tensors not multiple of 1024.
-static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
- if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
- if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
- if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
- return hvx_vec_rmpy_x8_n(x, y, 1024);
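+// Reduce-multiply for the leftover superblock; n is the number of valid
+// elements (n < 1024). Accumulators past the valid blocks are masked by the
+// callers.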
+static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+ if (n >= 512)
+ return hvx_vec_rmpy_x8_full(x, y);
+
+    // fewer than 512 leftover elements occupy only the first four vectors,
+    // so a constexpr n = 512 reduce covers them
+    return hvx_vec_rmpy_x8_n(x, y, 512);
}
static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
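+    // Per 32-element block b: acc += d_x[b] * d_y[b] * sum_j(q_x[b][j] * q_y[b][j]);
+    // hvx_vec_rmpy_x8 produces the int32 block sums, and the fp16 scales are
+    // combined in qf32.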
uint32_t i = 0;
for (; i < nb; i++) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+ // Process leftovers
if (nloe) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
- // Zero out unused scales
+ // Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_sum = Q6_V_vzero();
+ HVX_Vector r1_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
uint32_t i = 0;
for (; i < nb; i++) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+ // Process leftovers
if (nloe) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
- HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
- // Zero out unused scales
+ // Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
- HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
- HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
- HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
- HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
- HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
// Process leftovers
if (nloe) {
- HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
- HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
-
- HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
- HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
- HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
- HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
-
- HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
- HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
uint32_t i = 0;
for (; i < nb; i++) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+ // Process leftovers
if (nloe) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
- // Zero out unused scales
+ // Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (qf32)
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_sum = Q6_V_vzero();
+ HVX_Vector r1_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
uint32_t i = 0;
for (; i < nb; i++) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
- HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
- // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+ // Process leftovers
if (nloe) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
- HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
- // Zero out unused scales
+ // Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
- HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
- HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
- HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
- HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
- HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
// Process leftovers
if (nloe) {
- HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
- HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
-
- HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
- HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
- HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
- HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
-
- HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
- HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
- // Zero out unused scales
+ // Zero out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
uint32_t i = 0;
for (; i < nb; i++) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
// Process leftovers
if (nloe) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
- HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
- HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
        // Convert vy_d from fp16 to fp32, applying the 0.5 factor used for e8m0 halving
const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
// Row sum (sf)
- HVX_Vector r0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_sum = Q6_V_vzero();
+ HVX_Vector r1_sum = Q6_V_vzero();
// Multiply and accumulate into int32.
// Compute combined scale (fp32).
uint32_t i = 0;
for (; i < nb; i++) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
// Process leftovers
if (nloe) {
- HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
- HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
- HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
- HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c0_sum = Q6_V_vzero();
+ HVX_Vector r0_c1_sum = Q6_V_vzero();
+ HVX_Vector r1_c0_sum = Q6_V_vzero();
+ HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
- HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
- HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
// Process leftovers
if (nloe) {
- HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
- HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
- HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
- HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
- HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
- HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
- HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
- HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
- HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
- HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair rsum0_p = Q6_W_vzero();
+ HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2×2 tile
- HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
- HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
- HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
- HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+ HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
+ HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
+ HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
+ HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- HVX_Vector rsum = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
- const HVX_Vector zero = Q6_V_vsplat_R(0);
+ const HVX_Vector zero = Q6_V_vzero();
- HVX_Vector rsum = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
- HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector zero = Q6_V_vzero();
// Use reduce max fp32 to find max(abs(e)) first
HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
- HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector zero = Q6_V_vzero();
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
- HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector zero = Q6_V_vzero();
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements