]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
hexagon: further optimization and tuning of matmul and dot kernels (llama/19407)
authorMax Krasnyansky <redacted>
Thu, 12 Feb 2026 07:04:27 +0000 (23:04 -0800)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
* ggml-hexagon: implement 2x2 matmul kernel

* hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4

* hexagon: fix editor config failures

* hexagon: refactor matmul ops to use context struct and remove wrappers

Also implement vec_dot_f16 2x2

* hexagon: refactor dyn quantizers to use mmctx

* hexagon: remove mm fastdiv from op_ctx

* hexagon: refactor matmul entry point to reduce code duplication

---------

Co-authored-by: Trivikram Reddy <redacted>
src/ggml-hexagon/htp/htp-ops.h
src/ggml-hexagon/htp/matmul-ops.c

index c0d72587ce5d3831261f2ceb64397a40227b571a..f1ad24dbfaa677f4846775f4a1ba2514283ab835 100644 (file)
@@ -64,25 +64,12 @@ struct htp_ops_context {
     struct fastdiv_values broadcast_rv2;
     struct fastdiv_values broadcast_rv3;
 
-    struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
-    struct fastdiv_values mm_div_ne1;      // fastdiv values for ne1
-    struct fastdiv_values mm_div_r2;       // fastdiv values for ne12 / ne02
-    struct fastdiv_values mm_div_r3;       // fastdiv values for ne13 / ne03
-
     struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
     struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
 
     struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
     struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
 
-    struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
-    struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
-    struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
-
-    struct fastdiv_values cpy_rshp_div_n0;       // fastdiv values for ne00
-    struct fastdiv_values cpy_rshp_div_n1n0;     // fastdiv values for ne00*ne01
-    struct fastdiv_values cpy_rshp_div_n2n1n0;   // fastdiv values for ne00*ne01*ne02
-
     uint32_t flags;
 };
 
index d251eeed33a97af93d5299935f0b2b87b4681398..c360abe8dae2d308e4c4d85d0692e9574ecb6dba 100644 (file)
 #define MM_SPAD_SRC1_NROWS 16
 #define MM_SPAD_DST_NROWS  2
 
-struct htp_matmul_type {
+struct htp_matmul_context {
     const char * type;
-    void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-    void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
+    struct htp_ops_context * octx;
+
+    void (*vec_dot_1x1)(const int n, float * restrict s0,
+         const void * restrict vx0,
+         const void * restrict vy0);
+
+    void (*vec_dot_2x1)(const int n, float * restrict s0,
+         const void * restrict vx0, const void * restrict vx1,
+         const void * restrict vy0);
+
+    void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
+         const void * restrict vx0, const void * restrict vx1,
+         const void * restrict vy0, const void * restrict vy1);
+
+    // Precomputed values
+    uint32_t src0_nrows_per_thread;
+    uint32_t src1_nrows_per_thread;
+
+    struct fastdiv_values mm_div_ne12_ne1;
+    struct fastdiv_values mm_div_ne1;
+    struct fastdiv_values mm_div_r2;
+    struct fastdiv_values mm_div_r3;
 };
 
 // vdelta control to replicate first 4x fp32 values across lanes
@@ -122,6 +142,7 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
     HVX_Vector v6_7 = vptr[3];  // ...
 
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
 
     HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
     HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
@@ -133,15 +154,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
     HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
 
     // Convert uint4 to int4 (i.e. x - 8)
-    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
-    v0                  = Q6_Vb_vsub_VbVb(v0, i8);
-    v1                  = Q6_Vb_vsub_VbVb(v1, i8);
-    v2                  = Q6_Vb_vsub_VbVb(v2, i8);
-    v3                  = Q6_Vb_vsub_VbVb(v3, i8);
-    v4                  = Q6_Vb_vsub_VbVb(v4, i8);
-    v5                  = Q6_Vb_vsub_VbVb(v5, i8);
-    v6                  = Q6_Vb_vsub_VbVb(v6, i8);
-    v7                  = Q6_Vb_vsub_VbVb(v7, i8);
+    v0 = Q6_Vb_vsub_VbVb(v0, i8);
+    v1 = Q6_Vb_vsub_VbVb(v1, i8);
+    v2 = Q6_Vb_vsub_VbVb(v2, i8);
+    v3 = Q6_Vb_vsub_VbVb(v3, i8);
+    v4 = Q6_Vb_vsub_VbVb(v4, i8);
+    v5 = Q6_Vb_vsub_VbVb(v5, i8);
+    v6 = Q6_Vb_vsub_VbVb(v6, i8);
+    v7 = Q6_Vb_vsub_VbVb(v7, i8);
 
     HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
     return r;
@@ -156,6 +176,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
     HVX_Vector v6_7 = vptr[3];  // ...
 
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+    const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
 
     HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
     HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
@@ -166,15 +187,14 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
     HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
     HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
 
-    HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
-    v0             = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
-    v1             = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
-    v2             = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
-    v3             = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
-    v4             = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
-    v5             = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
-    v6             = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
-    v7             = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
+    v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+    v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+    v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
+    v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
+    v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
+    v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
+    v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
+    v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
 
     HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
     return r;
@@ -196,46 +216,6 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
     return r;
 }
 
-static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
-    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
-
-    HVX_Vector v0 = vptr[0];  // first  64 vals
-    HVX_Vector v1 = vptr[1];  // second 64 vals
-    HVX_Vector v2 = vptr[2];  // third  64 vals
-    HVX_Vector v3 = vptr[3];  // forth  64 vals
-
-    HVX_Vector_x4 r = { v0, v1, v2, v3 };
-    return r;
-}
-
-static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
-    const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
-
-    HVX_VectorPair v0 = vptr[0];  // first  64 vals
-    HVX_VectorPair v1 = vptr[1];  // second 64 vals
-    HVX_VectorPair v2 = vptr[2];  // third  64 vals
-    HVX_VectorPair v3 = vptr[3];  // forth  64 vals
-
-    HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
-    HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
-    HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
-    HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
-    HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
-    HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
-    HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
-    HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
-
-    HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
-    HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
-    HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
-    HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
-
-    // vcombine does a shuffle, use vdeal to undo
-
-    HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
-    return r;
-}
-
 // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
 // Accumulate each block into a single int32 value.
 // Return a single HVX vector with 32x int32 accumulators.
@@ -300,26 +280,26 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y,
     return hvx_vec_rmpy_x8_n(x, y, 1024);
 }
 
-static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
-    const uint32_t x_qblk_size = qk / 2;                                     // int4
-    const uint32_t x_qrow_size = n / 2;                                      // int4 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                      // int4
+    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                         // int8
-    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size);  // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 
     // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
@@ -372,36 +352,34 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
 
     r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
-    hvx_vec_store_u(&s[0], 4, r0_sum);
+    hvx_vec_store_u(s0, 4, r0_sum);
 }
 
-static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
-                                      float * restrict s,
-                                      const void * restrict vx,
-                                      uint32_t vx_row_size,
-                                      const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t x_qblk_size = qk / 2;                                                           // int4
-    const uint32_t x_qrow_size = n / 2;                                                            // int4 (not padded)
-
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                                               // int8
-    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                      // int4
+    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 
     // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
@@ -468,13 +446,143 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
     }
 
     HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
-    hvx_vec_store_u(&s[0], 8, rsum);
+    hvx_vec_store_u(s0, 8, rsum);
 }
 
-static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                        const void * restrict vx0, const void * restrict vx1,
+                                        const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                      // int4
+    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(s0, 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(s1, 8, r0_r1_c1_sum);  // row0,col1 row1,col1
+}
+
+static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
@@ -486,11 +594,11 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
     const uint32_t y_qblk_size = qk;                                         // int8
     const uint32_t y_qrow_size = n;                                          // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales
 
     // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
@@ -543,36 +651,34 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
 
     r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
-    hvx_vec_store_u(&s[0], 4, r0_sum);
+    hvx_vec_store_u(s0, 4, r0_sum);
 }
 
-static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
-                                      float * restrict s,
-                                      const void * restrict vx,
-                                      uint32_t vx_row_size,
-                                      const void * restrict vy) {
+static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t x_qblk_size = qk;                                                               // int8
-    const uint32_t x_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk;                                          // int8
+    const uint32_t x_qrow_size = n;                                           // int8 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                                               // int8
-    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 
-    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
-
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 
     // Row sum (qf32)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
@@ -639,16 +745,143 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
     }
 
     HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
-    hvx_vec_store_u(&s[0], 8, rsum);
+    hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                        const void * restrict vx0, const void * restrict vx1,
+                                        const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk;                                          // int8
+    const uint32_t x_qrow_size = n;                                           // int8 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q  = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q  = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
 }
 
-static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
-                                     float * restrict s,
-                                     const void * restrict vx,
-                                     const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_MXFP4x4x2 * 4;
 
@@ -660,11 +893,11 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
     const uint32_t y_qblk_size = qk;                                         // int8
     const uint32_t y_qrow_size = n;                                          // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales
 
     // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
@@ -747,36 +980,34 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
 
     r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
-    hvx_vec_store_u(&s[0], 4, r0_sum);
+    hvx_vec_store_u(s0, 4, r0_sum);
 }
 
-static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
-                                         float * restrict s,
-                                         const void * restrict vx,
-                                         uint32_t vx_row_size,
-                                         const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_MXFP4x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 1;                                                        // 32x e8m0
-    const uint32_t x_qblk_size = qk / 2;                                                           // fp4
-    const uint32_t x_qrow_size = n / 2;                                                            // fp4 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;                                      // fp4
+    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                                               // int8
-    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 
-    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
-
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0;               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size;     // then scales
 
     // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
@@ -879,10 +1110,180 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
     }
 
     HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
-    hvx_vec_store_u(&s[0], 8, rsum);
+    hvx_vec_store_u(s0, 8, rsum);
 }
 
-static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                        const void * restrict vx0, const void * restrict vx1,
+                                        const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;                                      // fp4
+    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);
+        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);
+        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);
+        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
+        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
+        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q  = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q  = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);
+        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);
+        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);
+        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
+        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
+        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
+}
+
+static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const HVX_Vector * restrict x = (const HVX_Vector *) vx;
     const HVX_Vector * restrict y = (const HVX_Vector *) vy;
 
@@ -912,14 +1313,12 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
-static void vec_dot_f16_f16_aa_rx2(const int n,
-                                float * restrict s,
-                                const void * restrict vx,
-                                uint32_t vx_row_size,
-                                const void * restrict vy) {
-    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
-    const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
-    const HVX_Vector * restrict y  = (const HVX_Vector *) vy;
+static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
+                                const void * restrict vx0, const void * restrict vx1,
+                                const void * restrict vy0) {
+    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+    const HVX_Vector * restrict y  = (const HVX_Vector *) vy0;
 
     uint32_t nvec = n / VLEN_FP16;
     uint32_t nloe = n % VLEN_FP16;
@@ -953,10 +1352,86 @@ static void vec_dot_f16_f16_aa_rx2(const int n,
     }
 
     HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
-    hvx_vec_store_u(&s[0], 8, rsum);
+    hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
+                                const void * restrict vx0, const void * restrict vx1,
+                                const void * restrict vy0, const void * restrict vy1) {
+    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+    const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
+    const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
+
+    uint32_t nvec = n / VLEN_FP16;
+    uint32_t nloe = n % VLEN_FP16;
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    #pragma unroll(2)
+    for (i = 0; i < nvec; i++) {
+        HVX_Vector r0_hf = x0[i];
+        HVX_Vector r1_hf = x1[i];
+        HVX_Vector c0_hf = y0[i];
+        HVX_Vector c1_hf = y1[i];
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
+        HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
+        HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
+        HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
+
+        HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
+        HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
+        HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
+        HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+    }
+
+    if (nloe) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+
+        HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
+        HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
+        HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
+        HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
+
+        HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
+        HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
+        HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
+        HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
+
+        HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
+        HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
+        HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
+        HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
 }
 
-static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const HVX_UVector * restrict x = (const HVX_UVector *) vx;
     const HVX_UVector * restrict y = (const HVX_UVector *) vy;
 
@@ -986,7 +1461,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
-static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
     const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
 
@@ -1083,14 +1558,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-#define htp_matmul_preamble            \
-    htp_matmul_tensors_preamble;       \
-    dma_queue *dma_queue           = octx->ctx->dma[ith];         \
-    uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+#define htp_matmul_preamble                                     \
+    struct htp_matmul_context * mmctx = data;                   \
+    struct htp_ops_context * octx  = mmctx->octx;               \
+    htp_matmul_tensors_preamble;                                \
+    dma_queue *dma_queue           = octx->ctx->dma[ith];       \
+    uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
 
 // *** matmul with support for 4d tensors and full broadcasting
 
-static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     uint64_t t1, t2;
@@ -1136,13 +1613,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
             for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
-                const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
-                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
+                const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
+                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
                 const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
 
                 // broadcast src0 into src1
-                const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
-                const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
+                const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
+                const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
 
                 const uint32_t i1 = i11;
                 const uint32_t i2 = i12;
@@ -1155,7 +1632,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
                 const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
                 for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
                     const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
-                    mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
+                    mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
                 }
             }
         }
@@ -1170,7 +1647,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
 }
 
 // src1 tensor is already in VTCM spad
-static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
@@ -1195,7 +1672,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     // Per-thread VTCM scratchpads for all tensors
     // Note that the entire src1 tensor is already in VTCM
     // For other tensors we allocate N rows per thread, padded to HVX vector size
-    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
+    uint8_t * restrict spad_dst  = dst_spad->data  + dst_spad->size_per_thread  * ith;
     uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
     uint8_t * restrict src1_data = src1_spad->data;
 
@@ -1219,11 +1696,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
 
-        #pragma unroll(2)
-        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+        // Process src1 columns in pairs (2×2 tiling)
+        uint32_t ir1 = 0;
+        for (; ir1 + 1 < src1_nrows; ir1 += 2) {
+            const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
+            const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
+            float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
+            float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
+            mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
+        }
+
+        // Handle remaining src1 rows (fallback to 2×1)
+        for (; ir1 < src1_nrows; ++ir1) {
             const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
-            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
+            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
         }
 
         // Prefetch next (n + spad_nrows) row
@@ -1247,20 +1734,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
         for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
             const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
-            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
         }
     }
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
          src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
 // q8x4x2 src1 tensor is already in VTCM spad
-static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     const uint32_t src0_nrows = ne01;
@@ -1311,7 +1798,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     // Process src0 rows
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-        mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
+        mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
 
         // Prefetch next (n + spad_nrows) row
         const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -1329,14 +1816,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
         dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
                        src0_stride, src0_row_size, 1);
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-        mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
+        mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
     }
 
     hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
          src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
@@ -1350,7 +1837,7 @@ struct mmid_row_mapping {
 };
 
 // src1 tensor is already in VTCM spad
-static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     struct htp_tensor * restrict     ids = &octx->src2;
@@ -1423,11 +1910,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
                 const int               rm2         = row_mapping.i2;  // token idx
 
                 const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
-                const uint8_t * restrict src1_col =
-                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
                 float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
 
-                mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+                mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
             }
 
             // Prefetch next (n + spad_nrows) row
@@ -1453,25 +1939,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
                 const int               rm2         = row_mapping.i2;  // token idx
 
                 const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
-                const uint8_t * restrict src1_col =
-                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
                 float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
 
-                mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+                mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
             }
         }
     }
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
          ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
          src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
          dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
 // src1 tensor is already in VTCM spad
-static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     struct htp_tensor * restrict     ids = &octx->src2;
@@ -1531,7 +2016,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
         // Process src0 rows
         for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
             const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
 
             // Prefetch next (n + spad_nrows) row
             const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -1549,13 +2034,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
             dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
                            src0_row_size_padded, src0_row_size, 1);
             const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
         }
     }
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
          ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
          src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
          dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
@@ -1754,12 +2239,14 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui
     hvx_copy_f16_ua(y_d, t_d, nb * 8);
 }
 
-static void quantize_f32_q8x4x2(const struct htp_tensor * src,
-                                 uint8_t * restrict dst,
-                                 struct htp_spad * spad,
-                                 uint32_t          nth,
-                                 uint32_t          ith,
-                                 uint32_t          nrows_per_thread) {
+static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    struct htp_spad * spad = &octx->src0_spad;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
 
     uint64_t t1 = HAP_perf_get_qtimer_count();
 
@@ -1799,8 +2286,14 @@ static void quantize_f32_q8x4x2(const struct htp_tensor * src,
          ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
-                              uint32_t nrows_per_thread, uint32_t dst_stride) {
+static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+    uint32_t dst_stride = octx->src1_spad.stride;
 
     uint64_t t1 = HAP_perf_get_qtimer_count();
 
@@ -1835,8 +2328,14 @@ static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict d
 }
 
 // TODO just a plain copy that should be done via the DMA during the Op setup
-static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
-                              uint32_t nrows_per_thread, uint32_t dst_stride) {
+static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+    uint32_t dst_stride = octx->src1_spad.stride;
 
     uint64_t t1 = HAP_perf_get_qtimer_count();
 
@@ -1870,213 +2369,76 @@ static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict d
         ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
-}
-
-static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
-}
-
-static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
-}
-
-// ** matmul/matvec callbacks for worker_pool
-
-static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f16";
-    mt.vec_dot     = vec_dot_f16_f16_aa;
-    mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f16";
-    mt.vec_dot     = vec_dot_f16_f16_aa;
-    mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f32";
-    mt.vec_dot     = vec_dot_f16_f32_uu;
-
-    matmul_4d(&mt, octx, n, i);
-}
-
-static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
 
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f16";
-    mt.vec_dot     = vec_dot_f16_f16_uu;
-
-    matmul_4d(&mt, octx, n, i);
-}
-
-// ** matmul-id callbacks for worker_pool
-
-static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matmul_id(&mt, octx, n, i);
-}
-
-static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matmul_id(&mt, octx, n, i);
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
+    return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
 }
 
-static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matvec_id(&mt, octx, n, i);
+static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
+    switch (type) {
+        case HTP_TYPE_Q4_0:
+            mmctx->type        = "q4x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
+            return 0;
+        case HTP_TYPE_Q8_0:
+            mmctx->type        = "q8x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
+            return 0;
+        case HTP_TYPE_MXFP4:
+            mmctx->type        = "mxfp4x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
+            return 0;
+        default:
+            return -1;
+    }
 }
 
-static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matmul_id(&mt, octx, n, i);
-}
+static void htp_mminit_spad(struct htp_ops_context * octx,
+                                 size_t dst_row_size,
+                                 size_t src0_row_size_padded,
+                                 size_t src1_row_size,
+                                 uint32_t src1_nrows,
+                                 size_t src2_spad_size_per_thread) {
+    octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+    octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+
+    if (src2_spad_size_per_thread > 0) {
+        octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
+        octx->src2_spad.size            = octx->src2_spad.size_per_thread;
+    }
 
-// ** main matmul entry point
+    // src0 spad is also used in dynamic quantizer to store padded src1 rows
+    size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+    if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+        octx->src0_spad.size_per_thread = src1_row_size_padded;
+    }
 
-static inline bool htp_is_permuted(const struct htp_tensor * t) {
-    return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+    octx->src1_spad.size = octx->src1_spad.size_per_thread;
+    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
 }
 
 int op_matmul(struct htp_ops_context * octx) {
     htp_matmul_tensors_preamble;
 
-    const char * op_type;
+    struct htp_matmul_context mmctx_struct = {0};
+    struct htp_matmul_context * mmctx = &mmctx_struct;
+    mmctx->octx = octx;
 
     const uint32_t src0_nrows = ne01 * ne02 * ne03;
     const uint32_t src1_nrows = ne11 * ne12 * ne13;
 
+    // Compute src0_nrows_per_thread
+    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
+
     const size_t src0_row_size = nb01;
     const size_t dst_row_size  = nb1;
     size_t       src1_row_size = nb11;
@@ -2085,181 +2447,95 @@ int op_matmul(struct htp_ops_context * octx) {
     size_t       src1_row_size_padded;
 
     worker_callback_t quant_job_func;
-    worker_callback_t matmul_job_func;
+    worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
 
     bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
 
-    switch (src0->type) {
-        case HTP_TYPE_Q4_0:
-            op_type        = "q4x4x2-f32";
-            quant_job_func = htp_quantize_f32_q8x4x2;
-            if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
-            } else {
-                matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
-            }
-
-            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-
-            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
+    if (src0->type == HTP_TYPE_F16) {
+        // Try optimized f16-f16 path first (src1 in VTCM)
+        const size_t f16_src1_row_size  = hex_round_up(ne10 * 2, 128);
+        const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
+        const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+        const size_t f16_dst_spad_size  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
 
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
+        const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
 
-        case HTP_TYPE_Q8_0:
-            op_type        = "q8x4x2-f32";
-            quant_job_func = htp_quantize_f32_q8x4x2;
-            if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
-            } else {
-                matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
-            }
+        // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+        // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+        const bool is_batched  = (ne02 > 1) || (ne03 > 1);
+        const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
 
-            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
+        if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+            // Optimized path
+            quant_job_func     = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
+            mmctx->type        = "f16-f16";
+            mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
 
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
+            src1_row_size = f16_src1_row_size;  // row size post quantization
 
             octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
             octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
 
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
             octx->src1_spad.size = octx->src1_spad.size_per_thread;
             octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
             octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_MXFP4:
-            op_type        = "mxfp4x4x2-f32";
-            quant_job_func = htp_quantize_f32_q8x4x2;
-            if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
+        } else {
+            // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
+            quant_job_func = NULL;
+            if (src1->type == HTP_TYPE_F32) {
+                mmctx->type        = "f16-f32";
+                mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
+                matmul_job_func    = matmul_4d;
             } else {
-                matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
+                mmctx->type        = "f16-f16";
+                mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
+                matmul_job_func    = matmul_4d;
             }
 
-            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
+            src1_row_size = nb11;  // original row size in DDR
 
             octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
 
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
             octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
             octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
 
-        case HTP_TYPE_F16:
-            {
-                // Try optimized f16-f16 path first (src1 in VTCM)
-                const size_t f16_src1_row_size  = hex_round_up(ne10 * 2, 128);
-                const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
-                const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
-                const size_t f16_dst_spad_size  = hex_round_up(MM_SPAD_DST_NROWS  * dst_row_size, 256) * octx->n_threads;
-
-                const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
-
-                // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
-                // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
-                const bool is_batched  = (ne02 > 1) || (ne03 > 1);
-                const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
-
-                if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
-                    // Optimized path
-                    op_type        = "f16-f16";
-                    quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_f32_f16 : htp_quantize_f16_f16;
-                    if (src1_nrows > 1) {
-                        matmul_job_func = htp_matmul_2d_f16_f16;
-                    } else {
-                        matmul_job_func = htp_matvec_2d_f16_f16;
-                    }
-
-                    src1_row_size = f16_src1_row_size; // row size post quantization
-
-                    octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-                    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-                    octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-
-                    octx->src1_spad.size = octx->src1_spad.size_per_thread;
-                    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-                    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-                } else {
-                    // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
-                    quant_job_func  = NULL;
-                    if (src1->type == HTP_TYPE_F32) {
-                        op_type         = "f16-f32";
-                        matmul_job_func = htp_matmul_4d_f16_f32;
-                    } else {
-                        op_type         = "f16-f16";
-                        matmul_job_func = htp_matmul_4d_f16_f16;
-                    }
-
-                    src1_row_size = nb11; // original row size in DDR
-
-                    octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-                    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
-                    octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
-
-                    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-                    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
-                    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-
-                    // Init fastdiv for matmul_4d (supports broadcasting)
-                    octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
-                    octx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);
-                    octx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
-                    octx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
-
-                    need_quant = false;
-                }
-            }
-            break;
+            // Init fastdiv for matmul_4d (supports broadcasting)
+            mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+            mmctx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);
+            mmctx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+            mmctx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
 
-        default:
+            need_quant = false;
+        }
+    } else {
+        if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
             return HTP_STATUS_NO_SUPPORT;
+        }
+
+        quant_job_func = quantize_f32_q8x4x2;
+        src1_row_size  = q8x4x2_row_size(ne10);
+        htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
     }
 
     // VTCM scratchpads for all tensors
     size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
 
-    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
          octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
 
-    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
+    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
          src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
          dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
 
     // Make sure the reserved vtcm size is sufficient
     if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
+        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
              octx->ctx->vtcm_size, spad_size);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
@@ -2268,39 +2544,31 @@ int op_matmul(struct htp_ops_context * octx) {
     octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
     octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
-    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
-
     octx->src0_spad.stride = src0_row_size_padded;
     octx->src1_spad.stride = src1_row_size;
 
     if (need_quant) {
-        // Run quant jobs
-        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
-        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+        const uint32_t n_quant_jobs  = MIN(src1_nrows, octx->n_threads);
+        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
     }
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        // Run matmul jobs
         const uint32_t n_matmul_jobs = octx->n_threads;
-        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
     }
 
     return HTP_STATUS_OK;
 }
 
-// ** main matmul-id entry point
-
 int op_matmul_id(struct htp_ops_context * octx) {
     htp_matmul_tensors_preamble;
 
-    struct htp_tensor * restrict ids = &octx->src2;
-
-    const char * op_type;
+    struct htp_matmul_context mmctx_struct = {0};
+    struct htp_matmul_context * mmctx = &mmctx_struct;
+    mmctx->octx = octx;
 
-    worker_callback_t quant_job_func;
-    worker_callback_t matmul_id_job_func;
+    struct htp_tensor * restrict ids = &octx->src2;
 
     const size_t src0_row_size = nb01;
     const size_t dst_row_size  = nb1;
@@ -2310,6 +2578,13 @@ int op_matmul_id(struct htp_ops_context * octx) {
     const uint32_t src0_nrows = ne01;  // per expert
     const uint32_t src1_nrows = ne11 * ne12 * ne13;
 
+    worker_callback_t quant_job_func;
+    worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
+
+    // Compute src0_nrows_per_thread
+    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
+
     size_t src1_row_size;
     size_t src1_row_size_padded;
 
@@ -2320,112 +2595,29 @@ int op_matmul_id(struct htp_ops_context * octx) {
     size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
     size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
 
-    switch (src0->type) {
-        case HTP_TYPE_Q4_0:
-            op_type        = "q4x2x2-f32";
-            quant_job_func = htp_quantize_f32_q8x4x2;
-            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
-            if (src1_nrows > 1) {
-                matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
-            } else {
-                matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
-            }
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-            octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src2_spad.size = octx->src2_spad.size_per_thread;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_Q8_0:
-            op_type        = "q8x2x2-f32";
-            quant_job_func = htp_quantize_f32_q8x4x2;
-            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
-            if (src1_nrows > 1) {
-                matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
-            } else {
-                matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
-            }
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-            octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src2_spad.size = octx->src2_spad.size_per_thread;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_MXFP4:
-            op_type        = "mxfp4x2x2-f32";
-            quant_job_func = htp_quantize_f32_q8x4x2;
-            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
-            if (src1_nrows > 1) {
-                matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
-            } else {
-                matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
-            }
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
-            octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
+    if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
 
-            octx->src2_spad.size = octx->src2_spad.size_per_thread;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
+    quant_job_func = quantize_f32_q8x4x2;
+    src1_row_size  = q8x4x2_row_size(ne10);
 
-        default:
-            return HTP_STATUS_NO_SUPPORT;
-    }
+    const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+    htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
 
     size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
 
-    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
          octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
 
-    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
+    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
          ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
          src1->data, dst->data);
 
     // Make sure the reserved vtcm size is sufficient
     if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
-             octx->ctx->vtcm_size, spad_size);
+        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
@@ -2434,8 +2626,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
     octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
     octx->dst_spad.data  = octx->src2_spad.data + octx->src2_spad.size;
 
-    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
-    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
+    octx->src0_spad.stride = src0_row_size_padded;
+    octx->src1_spad.stride = src1_row_size;
 
     if (src1_nrows > 1) {
         // initialize matrix_row_counts and map
@@ -2447,8 +2639,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
         // group rows by src0 matrix
         for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {  // token idx
             for (uint32_t id = 0; id < n_ids; ++id) {         // expert idx
-                const uint32_t i02 =
-                    *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+                const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
 
                 assert(i02 >= 0 && i02 < n_as);
 
@@ -2460,16 +2651,14 @@ int op_matmul_id(struct htp_ops_context * octx) {
 
     // Setup worker pool callbacks
     if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
-        // Run quant jobs
         const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
-        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
     }
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        // Run matmul-id jobs
         const uint32_t n_matmul_jobs = octx->n_threads;
-        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
     }
 
     return HTP_STATUS_OK;