#define MM_SPAD_SRC1_NROWS 16
#define MM_SPAD_DST_NROWS 2
-struct htp_matmul_type {
+struct htp_matmul_context {
const char * type;
- void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
+ struct htp_ops_context * octx;
+
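+    // vec_dot_MxN kernels compute an MxN tile of dot products:
+    // M rows of src0 against N columns of src1 per call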
+ void (*vec_dot_1x1)(const int n, float * restrict s0,
+ const void * restrict vx0,
+ const void * restrict vy0);
+
+ void (*vec_dot_2x1)(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0);
+
+ void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1);
+
+ // Precomputed values
+ uint32_t src0_nrows_per_thread;
+ uint32_t src1_nrows_per_thread;
+
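+    // fastdiv constants for the index math in matmul_4d: ne12*ne1 / ne1 decompose the
+    // flat src1 row index, r2 / r3 map it back to src0 via the broadcast ratios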
+ struct fastdiv_values mm_div_ne12_ne1;
+ struct fastdiv_values mm_div_ne1;
+ struct fastdiv_values mm_div_r2;
+ struct fastdiv_values mm_div_r3;
};
// vdelta control to replicate first 4x fp32 values across lanes
HVX_Vector v6_7 = vptr[3]; // ...
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+ const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
// Convert uint4 to int4 (i.e. x - 8)
- const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
- v0 = Q6_Vb_vsub_VbVb(v0, i8);
- v1 = Q6_Vb_vsub_VbVb(v1, i8);
- v2 = Q6_Vb_vsub_VbVb(v2, i8);
- v3 = Q6_Vb_vsub_VbVb(v3, i8);
- v4 = Q6_Vb_vsub_VbVb(v4, i8);
- v5 = Q6_Vb_vsub_VbVb(v5, i8);
- v6 = Q6_Vb_vsub_VbVb(v6, i8);
- v7 = Q6_Vb_vsub_VbVb(v7, i8);
+ v0 = Q6_Vb_vsub_VbVb(v0, i8);
+ v1 = Q6_Vb_vsub_VbVb(v1, i8);
+ v2 = Q6_Vb_vsub_VbVb(v2, i8);
+ v3 = Q6_Vb_vsub_VbVb(v3, i8);
+ v4 = Q6_Vb_vsub_VbVb(v4, i8);
+ v5 = Q6_Vb_vsub_VbVb(v5, i8);
+ v6 = Q6_Vb_vsub_VbVb(v6, i8);
+ v7 = Q6_Vb_vsub_VbVb(v7, i8);
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
return r;
HVX_Vector v6_7 = vptr[3]; // ...
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
- HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
- v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
- v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
- v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
- v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
- v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
- v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
- v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
- v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
+ v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+ v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+ v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
+ v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
+ v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
+ v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
+ v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
+ v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
return r;
return r;
}
-static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
- const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
-
- HVX_Vector v0 = vptr[0]; // first 64 vals
- HVX_Vector v1 = vptr[1]; // second 64 vals
- HVX_Vector v2 = vptr[2]; // third 64 vals
- HVX_Vector v3 = vptr[3]; // forth 64 vals
-
- HVX_Vector_x4 r = { v0, v1, v2, v3 };
- return r;
-}
-
-static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
- const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
-
- HVX_VectorPair v0 = vptr[0]; // first 64 vals
- HVX_VectorPair v1 = vptr[1]; // second 64 vals
- HVX_VectorPair v2 = vptr[2]; // third 64 vals
- HVX_VectorPair v3 = vptr[3]; // forth 64 vals
-
- HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
- HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
- HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
- HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
- HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
- HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
- HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
- HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
-
- HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
- HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
- HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
- HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
-
- // vcombine does a shuffle, use vdeal to undo
-
- HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
- return r;
-}
-
// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
// Accumulate each block into a single int32 value.
// Return a single HVX vector with 32x int32 accumulators.
return hvx_vec_rmpy_x8_n(x, y, 1024);
}
-static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
- assert((unsigned long) vx % 128 == 0);
- assert((unsigned long) vy % 128 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t x_qblk_size = qk / 2; // int4
- const uint32_t x_qrow_size = n / 2; // int4 (not padded)
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk / 2; // int4
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t y_qblk_size = qk; // int8
- const uint32_t y_qrow_size = n; // int8 (not padded)
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
- hvx_vec_store_u(&s[0], 4, r0_sum);
+ hvx_vec_store_u(s0, 4, r0_sum);
}
-static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
- float * restrict s,
- const void * restrict vx,
- uint32_t vx_row_size,
- const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
- assert((unsigned long) vx % 128 == 0);
- assert((unsigned long) vy % 128 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t x_qblk_size = qk / 2; // int4
- const uint32_t x_qrow_size = n / 2; // int4 (not padded)
-
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t y_qblk_size = qk; // int8
- const uint32_t y_qrow_size = n; // int8 (not padded)
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk / 2; // int4
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
- const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
- const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
- const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
- const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
- hvx_vec_store_u(&s[0], 8, rsum);
+ hvx_vec_store_u(s0, 8, rsum);
}
-static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ assert(n % 32 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+ assert((unsigned long) vy1 % 128 == 0);
+
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk / 2; // int4
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ const uint32_t nb = n / qk; // num full blocks
+ const uint32_t nloe = n % qk; // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ // Load src1 columns (reused across both src0 rows)
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+ // Load src0 rows (reused across both src1 columns)
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+ // Load scales
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ // Compute combined scales
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Apply scales and accumulate
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Zero out unused scales
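+        // nloe/8 bytes: each int32 accumulator lane covers 32 elements (4 bytes per lane)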
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+ hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
+ hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
+}
+
+static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
- assert((unsigned long) vx % 128 == 0);
- assert((unsigned long) vy % 128 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
- hvx_vec_store_u(&s[0], 4, r0_sum);
+ hvx_vec_store_u(s0, 4, r0_sum);
}
-static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
- float * restrict s,
- const void * restrict vx,
- uint32_t vx_row_size,
- const void * restrict vy) {
+static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
- assert((unsigned long) vx % 128 == 0);
- assert((unsigned long) vy % 128 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t x_qblk_size = qk; // int8
- const uint32_t x_qrow_size = n; // int8 (not padded)
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk; // int8
+ const uint32_t x_qrow_size = n; // int8 (not padded)
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t y_qblk_size = qk; // int8
- const uint32_t y_qrow_size = n; // int8 (not padded)
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
- const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
- const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
- const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
- const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
-
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (qf32)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
- hvx_vec_store_u(&s[0], 8, rsum);
+ hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ assert(n % 32 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+ assert((unsigned long) vy1 % 128 == 0);
+
+ const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk; // int8
+ const uint32_t x_qrow_size = n; // int8 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ const uint32_t nb = n / qk; // num full blocks
+ const uint32_t nloe = n % qk; // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ // Load src1 columns (reused across both src0 rows)
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+ // Load src0 rows (reused across both src1 columns)
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+ // Load scales
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ // Compute combined scales
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Apply scales and accumulate
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
+    hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
-static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
- float * restrict s,
- const void * restrict vx,
- const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
- assert((unsigned long) vx % 128 == 0);
- assert((unsigned long) vy % 128 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_MXFP4x4x2 * 4;
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
- const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
- const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
- hvx_vec_store_u(&s[0], 4, r0_sum);
+ hvx_vec_store_u(s0, 4, r0_sum);
}
-static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
- float * restrict s,
- const void * restrict vx,
- uint32_t vx_row_size,
- const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
- assert((unsigned long) vx % 128 == 0);
- assert((unsigned long) vy % 128 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_MXFP4x4x2 * 4;
- const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
- const uint32_t x_qblk_size = qk / 2; // fp4
- const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
+ const uint32_t x_qblk_size = qk / 2; // fp4
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
- const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
- const uint32_t y_qblk_size = qk; // int8
- const uint32_t y_qrow_size = n; // int8 (not padded)
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
- const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
- const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
- const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
- const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
-
- const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
- const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
// Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
- hvx_vec_store_u(&s[0], 8, rsum);
+ hvx_vec_store_u(s0, 8, rsum);
}
-static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ assert(n % 32 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+ assert((unsigned long) vy1 % 128 == 0);
+
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
+ const uint32_t x_qblk_size = qk / 2; // fp4
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ const uint32_t nb = n / qk; // num full blocks
+ const uint32_t nloe = n % qk; // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ // Load src1 columns (reused across both src0 rows)
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+ // Load src0 rows (reused across both src1 columns)
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+ // Load scales
+ HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
+ HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying the 0.5 factor used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+ vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
+ vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+ vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
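+        // Shifting the e8m0 byte into the fp32 exponent field yields 2^(e8m0 - 127) exactly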
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
+
+ // Compute combined scales
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+ // Apply scales and accumulate
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
+ HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying the 0.5 factor used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+ vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
+ vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+ vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
+
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
+    hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
+}
+
+static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
hvx_vec_store_u(&s[0], 4, rsum);
}
-static void vec_dot_f16_f16_aa_rx2(const int n,
- float * restrict s,
- const void * restrict vx,
- uint32_t vx_row_size,
- const void * restrict vy) {
- const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
- const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
- const HVX_Vector * restrict y = (const HVX_Vector *) vy;
+static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
- hvx_vec_store_u(&s[0], 8, rsum);
+ hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+ const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
+ const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
+
+ uint32_t nvec = n / VLEN_FP16;
+ uint32_t nloe = n % VLEN_FP16;
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector r0_hf = x0[i];
+ HVX_Vector r1_hf = x1[i];
+ HVX_Vector c0_hf = y0[i];
+ HVX_Vector c1_hf = y1[i];
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
+ HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
+ HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
+ HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
+
+ HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
+ HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
+ HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
+ HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+ }
+
+ if (nloe) {
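+        // mask covers nloe*2 bytes (2 bytes per fp16 element); zero the tail of each input before multiplying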
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+
+ HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
+ HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
+ HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
+ HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
+
+ HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
+ HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
+ HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
+ HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
+
+ HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
+ HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
+ HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
+ HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+    }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
+    hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
-static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
hvx_vec_store_u(&s[0], 4, rsum);
}
-static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
-#define htp_matmul_preamble \
- htp_matmul_tensors_preamble; \
- dma_queue *dma_queue = octx->ctx->dma[ith]; \
- uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+#define htp_matmul_preamble \
+ struct htp_matmul_context * mmctx = data; \
+ struct htp_ops_context * octx = mmctx->octx; \
+ htp_matmul_tensors_preamble; \
+ dma_queue *dma_queue = octx->ctx->dma[ith]; \
+ uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
// *** matmul with support for 4d tensors and full broadcasting
-static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
uint64_t t1, t2;
for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
- const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
- const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
+ const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
+ const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
// broadcast src0 into src1
- const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
- const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
+ const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
+ const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
const uint32_t i1 = i11;
const uint32_t i2 = i12;
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
- mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
+ mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
}
}
}
}
// src1 tensor is already in VTCM spad
-static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
// Per-thread VTCM scratchpads for all tensors
// Note that the entire src1 tensor is already in VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
uint8_t * restrict src1_data = src1_spad->data;
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- #pragma unroll(2)
- for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+ // Process src1 columns in pairs (2×2 tiling)
+ uint32_t ir1 = 0;
+ for (; ir1 + 1 < src1_nrows; ir1 += 2) {
+ const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
+ const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
+ float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
+ float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
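+            // ss0 holds two consecutive src0 rows in the VTCM spad (second row at ss0 + src0_stride)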
+ mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
+ }
+
+ // Handle remaining src1 rows (fallback to 2×1)
+ for (; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
}
// Prefetch next (n + spad_nrows) row
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
- mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
}
t2 = HAP_perf_get_qtimer_count();
- FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+ FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
// q8x4x2 src1 tensor is already in VTCM spad
-static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
const uint32_t src0_nrows = ne01;
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
+ mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
+ mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
}
hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
t2 = HAP_perf_get_qtimer_count();
- FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+ FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
};
// src1 tensor is already in VTCM spad
-static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_tensor * restrict ids = &octx->src2;
const int rm2 = row_mapping.i2; // token idx
const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
- const uint8_t * restrict src1_col =
- (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
}
// Prefetch next (n + spad_nrows) row
const int rm2 = row_mapping.i2; // token idx
const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
- const uint8_t * restrict src1_col =
- (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
- mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
}
}
t2 = HAP_perf_get_qtimer_count();
- FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+ FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
// src1 tensor is already in VTCM spad
-static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_tensor * restrict ids = &octx->src2;
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
}
t2 = HAP_perf_get_qtimer_count();
- FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+ FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
hvx_copy_f16_ua(y_d, t_d, nb * 8);
}
-static void quantize_f32_q8x4x2(const struct htp_tensor * src,
- uint8_t * restrict dst,
- struct htp_spad * spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t nrows_per_thread) {
+static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
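+    // worker_pool callback signature: all inputs come from the matmul/ops context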
+ struct htp_matmul_context * mmctx = data;
+ struct htp_ops_context * octx = mmctx->octx;
+
+ const struct htp_tensor * src = &octx->src1;
+ uint8_t * restrict dst = octx->src1_spad.data;
+ struct htp_spad * spad = &octx->src0_spad;
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
uint64_t t1 = HAP_perf_get_qtimer_count();
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
- uint32_t nrows_per_thread, uint32_t dst_stride) {
+static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_matmul_context * mmctx = data;
+ struct htp_ops_context * octx = mmctx->octx;
+
+ const struct htp_tensor * src = &octx->src1;
+ uint8_t * restrict dst = octx->src1_spad.data;
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+ uint32_t dst_stride = octx->src1_spad.stride;
uint64_t t1 = HAP_perf_get_qtimer_count();
}
// TODO just a plain copy that should be done via the DMA during the Op setup
-static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
- uint32_t nrows_per_thread, uint32_t dst_stride) {
+static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_matmul_context * mmctx = data;
+ struct htp_ops_context * octx = mmctx->octx;
+
+ const struct htp_tensor * src = &octx->src1;
+ uint8_t * restrict dst = octx->src1_spad.data;
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+ uint32_t dst_stride = octx->src1_spad.stride;
uint64_t t1 = HAP_perf_get_qtimer_count();
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
- quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
-}
-
-static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
- quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
-}
-
-static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
- quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
-}
-
-// ** matmul/matvec callbacks for worker_pool
-
-static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
- matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
- matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q8x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
- matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q8x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
- matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "mxfp4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
- matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "mxfp4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
- matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "f16-f16";
- mt.vec_dot = vec_dot_f16_f16_aa;
- mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
-
- matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "f16-f16";
- mt.vec_dot = vec_dot_f16_f16_aa;
- mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
-
- matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "f16-f32";
- mt.vec_dot = vec_dot_f16_f32_uu;
-
- matmul_4d(&mt, octx, n, i);
-}
-
-static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
- struct htp_matmul_type mt;
- mt.type = "f16-f16";
- mt.vec_dot = vec_dot_f16_f16_uu;
-
- matmul_4d(&mt, octx, n, i);
-}
-
-// ** matmul-id callbacks for worker_pool
-
-static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
- matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
- matmul_id(&mt, octx, n, i);
-}
-
-static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q8x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
- matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "q8x4x2-q8x4x2";
- mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
- matmul_id(&mt, octx, n, i);
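+// Strides of a contiguous (non-permuted) tensor are non-decreasing: nb[0] <= nb[1] <= nb[2] <= nb[3]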
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
+ return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
}
-static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "mxfp4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
- matvec_id(&mt, octx, n, i);
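+// Select the type string and the 1x1/2x1/2x2 vec-dot kernels for the given src0 data type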
+static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
+ switch (type) {
+ case HTP_TYPE_Q4_0:
+ mmctx->type = "q4x4x2-f32";
+ mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
+ return 0;
+ case HTP_TYPE_Q8_0:
+ mmctx->type = "q8x4x2-f32";
+ mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
+ return 0;
+ case HTP_TYPE_MXFP4:
+ mmctx->type = "mxfp4x4x2-f32";
+ mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
+ return 0;
+ default:
+ return -1;
+ }
}
-static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = data;
-
- struct htp_matmul_type mt;
- mt.type = "mxfp4x4x2-q8x4x2";
- mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
- mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
- matmul_id(&mt, octx, n, i);
-}
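+// Compute the VTCM scratchpad sizes for the quantized paths:
+// the entire src1 tensor is placed into the VTCM, while for the other tensors
+// we allocate N rows per thread, padded to HVX vector size.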
+static void htp_mminit_spad(struct htp_ops_context * octx,
+ size_t dst_row_size,
+ size_t src0_row_size_padded,
+ size_t src1_row_size,
+ uint32_t src1_nrows,
+ size_t src2_spad_size_per_thread) {
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+
+ if (src2_spad_size_per_thread > 0) {
+ octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
+ octx->src2_spad.size = octx->src2_spad.size_per_thread;
+ }
-// ** main matmul entry point
+ // src0 spad is also used by the dynamic quantizer to store padded src1 rows
+ size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
+ }
-static inline bool htp_is_permuted(const struct htp_tensor * t) {
- return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
}
int op_matmul(struct htp_ops_context * octx) {
htp_matmul_tensors_preamble;
- const char * op_type;
+ struct htp_matmul_context mmctx_struct = {0};
+ struct htp_matmul_context * mmctx = &mmctx_struct;
+ mmctx->octx = octx;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
const uint32_t src1_nrows = ne11 * ne12 * ne13;
+ // Compute src0_nrows_per_thread; keep it even since the 2x1/2x2 kernels consume src0 rows in pairs
+ mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+ mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
+
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
size_t src1_row_size = nb11;
size_t src1_row_size_padded;
worker_callback_t quant_job_func;
- worker_callback_t matmul_job_func;
+ worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
- switch (src0->type) {
- case HTP_TYPE_Q4_0:
- op_type = "q4x4x2-f32";
- quant_job_func = htp_quantize_f32_q8x4x2;
- if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
- } else {
- matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
- }
-
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
-
- // Entire src1 tensor is placed into the VTCM
- // For other tensors we allocate N rows per thread, padded to HVX vector size
-
- octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
- octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
- src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
- octx->src0_spad.size_per_thread = src1_row_size_padded;
- }
+ if (src0->type == HTP_TYPE_F16) {
+ // Try optimized f16-f16 path first (src1 in VTCM)
+ const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
+ const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
+ const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+ const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- break;
+ const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
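+ // The optimized path is taken only if all three spads fit into the reserved VTCM (checked below)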
- case HTP_TYPE_Q8_0:
- op_type = "q8x4x2-f32";
- quant_job_func = htp_quantize_f32_q8x4x2;
- if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
- } else {
- matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
- }
+ // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+ // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
+ const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
+ if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+ // Optimized path
+ quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
+ mmctx->type = "f16-f16";
+ mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
- // Entire src1 tensor is placed into the VTCM
- // For other tensors we allocate N rows per thread, padded to HVX vector size
+ src1_row_size = f16_src1_row_size; // row size after conversion to f16
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
- src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
- octx->src0_spad.size_per_thread = src1_row_size_padded;
- }
-
octx->src1_spad.size = octx->src1_spad.size_per_thread;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- break;
-
- case HTP_TYPE_MXFP4:
- op_type = "mxfp4x4x2-f32";
- quant_job_func = htp_quantize_f32_q8x4x2;
- if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
+ } else {
+ // Fall back to the f16/f32 (DDR) path if src1 doesn't fit in VTCM or broadcasting is required
+ quant_job_func = NULL;
+ if (src1->type == HTP_TYPE_F32) {
+ mmctx->type = "f16-f32";
+ mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
+ matmul_job_func = matmul_4d;
} else {
- matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
+ mmctx->type = "f16-f16";
+ mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
+ matmul_job_func = matmul_4d;
}
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
-
- // Entire src1 tensor is placed into the VTCM
- // For other tensors we allocate N rows per thread, padded to HVX vector size
+ src1_row_size = nb11; // original row size in DDR
octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
- octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+ octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
- src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
- octx->src0_spad.size_per_thread = src1_row_size_padded;
- }
-
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- break;
- case HTP_TYPE_F16:
- {
- // Try optimized f16-f16 path first (src1 in VTCM)
- const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
- const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
- const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
- const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
-
- const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
-
- // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
- // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
- const bool is_batched = (ne02 > 1) || (ne03 > 1);
- const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
-
- if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
- // Optimized path
- op_type = "f16-f16";
- quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_f32_f16 : htp_quantize_f16_f16;
- if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_2d_f16_f16;
- } else {
- matmul_job_func = htp_matvec_2d_f16_f16;
- }
-
- src1_row_size = f16_src1_row_size; // row size post quantization
-
- octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
- octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- } else {
- // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
- quant_job_func = NULL;
- if (src1->type == HTP_TYPE_F32) {
- op_type = "f16-f32";
- matmul_job_func = htp_matmul_4d_f16_f32;
- } else {
- op_type = "f16-f16";
- matmul_job_func = htp_matmul_4d_f16_f16;
- }
-
- src1_row_size = nb11; // original row size in DDR
-
- octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
- octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
-
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
-
- // Init fastdiv for matmul_4d (supports broadcasting)
- octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
- octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
- octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
- octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
-
- need_quant = false;
- }
- }
- break;
+ // Init fastdiv for matmul_4d (supports broadcasting)
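+ // (fastdiv precomputes multiply+shift constants so the kernel avoids integer division by these sizes)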
+ mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+ mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
+ mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+ mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
- default:
+ need_quant = false;
+ }
+ } else {
+ if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
return HTP_STATUS_NO_SUPPORT;
+ }
+
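+ // Quantized src0 path: src1 is dynamically quantized to q8x4x2 and the entire quantized tensor is kept in VTCM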
+ quant_job_func = quantize_f32_q8x4x2;
+ src1_row_size = q8x4x2_row_size(ne10);
+ htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
}
// VTCM scratchpads for all tensors
size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
- FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+ FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
- FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
+ FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
- FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
+ FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
octx->ctx->vtcm_size, spad_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
- octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
- octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
-
octx->src0_spad.stride = src0_row_size_padded;
octx->src1_spad.stride = src1_row_size;
if (need_quant) {
- // Run quant jobs
- const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
- octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
+ mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
}
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- // Run matmul jobs
const uint32_t n_matmul_jobs = octx->n_threads;
- worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
}
return HTP_STATUS_OK;
}
-// ** main matmul-id entry point
-
int op_matmul_id(struct htp_ops_context * octx) {
htp_matmul_tensors_preamble;
- struct htp_tensor * restrict ids = &octx->src2;
-
- const char * op_type;
+ struct htp_matmul_context mmctx_struct = {0};
+ struct htp_matmul_context * mmctx = &mmctx_struct;
+ mmctx->octx = octx;
- worker_callback_t quant_job_func;
- worker_callback_t matmul_id_job_func;
+ struct htp_tensor * restrict ids = &octx->src2;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01; // per expert
const uint32_t src1_nrows = ne11 * ne12 * ne13;
+ worker_callback_t quant_job_func;
+ worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
+
+ // Compute src0_nrows_per_thread; keep it even since the 2x1/2x2 kernels consume src0 rows in pairs
+ mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+ mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
+
size_t src1_row_size;
size_t src1_row_size_padded;
size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
- switch (src0->type) {
- case HTP_TYPE_Q4_0:
- op_type = "q4x2x2-f32";
- quant_job_func = htp_quantize_f32_q8x4x2;
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
- if (src1_nrows > 1) {
- matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
- } else {
- matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
- }
-
- // Entire src1 tensor is placed into the VTCM
- // For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
- octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
- octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
- src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
- octx->src0_spad.size_per_thread = src1_row_size_padded;
- }
-
- octx->src2_spad.size = octx->src2_spad.size_per_thread;
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- break;
-
- case HTP_TYPE_Q8_0:
- op_type = "q8x2x2-f32";
- quant_job_func = htp_quantize_f32_q8x4x2;
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
- if (src1_nrows > 1) {
- matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
- } else {
- matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
- }
-
- // Entire src1 tensor is placed into the VTCM
- // For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
- octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
- octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
- src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
- octx->src0_spad.size_per_thread = src1_row_size_padded;
- }
-
- octx->src2_spad.size = octx->src2_spad.size_per_thread;
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- break;
-
- case HTP_TYPE_MXFP4:
- op_type = "mxfp4x2x2-f32";
- quant_job_func = htp_quantize_f32_q8x4x2;
- src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
- if (src1_nrows > 1) {
- matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
- } else {
- matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
- }
-
- // Entire src1 tensor is placed into the VTCM
- // For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
- octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
- octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
- // src0 spad is also used in dynamic quantizer to store padded src1 rows
- src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
- if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
- octx->src0_spad.size_per_thread = src1_row_size_padded;
- }
+ if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
- octx->src2_spad.size = octx->src2_spad.size_per_thread;
- octx->src1_spad.size = octx->src1_spad.size_per_thread;
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
- break;
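+ // All supported matmul-id src0 types use dynamically quantized (q8x4x2) src1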
+ quant_job_func = quantize_f32_q8x4x2;
+ src1_row_size = q8x4x2_row_size(ne10);
- default:
- return HTP_STATUS_NO_SUPPORT;
- }
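+ // src2 spad holds the per-expert row counts and the mmid_row_mapping table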
+ const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+ htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
- FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+ FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
- FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
+ FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
src1->data, dst->data);
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
- FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
- octx->ctx->vtcm_size, spad_size);
+ FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
- octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
- octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
+ octx->src0_spad.stride = src0_row_size_padded;
+ octx->src1_spad.stride = src1_row_size;
if (src1_nrows > 1) {
// initialize matrix_row_counts and map
// group rows by src0 matrix
for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
- const uint32_t i02 =
- *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+ const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
assert(i02 >= 0 && i02 < n_as);
// Setup worker pool callbacks
if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
- // Run quant jobs
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
- octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+ mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
}
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- // Run matmul-id jobs
const uint32_t n_matmul_jobs = octx->n_threads;
- worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
}
return HTP_STATUS_OK;